#ifndef __VMML__TUCK3_IMPORTER__HPP__
#define __VMML__TUCK3_IMPORTER__HPP__

#include <vmmlib/qtucker3_tensor.hpp>
#include <cstring>
#include <vector>

namespace vmml
{

template< size_t R1, size_t R2, size_t R3, size_t I1, size_t I2, size_t I3, typename T_value = float, typename T_coeff = float >
class tucker3_importer
{
public:

    typedef float T_internal;

    typedef tucker3_tensor< R1, R2, R3, I1, I2, I3, T_value, T_coeff > tucker3_type;
    typedef qtucker3_tensor< R1, R2, R3, I1, I2, I3, T_value, T_coeff > qtucker3_type;

    typedef typename tucker3_type::t3_core_type t3_core_type;
    typedef typename t3_core_type::iterator t3_core_iterator;
    typedef typename t3_core_type::const_iterator t3_core_const_iterator;

    typedef typename tucker3_type::u1_type u1_type;
    typedef typename u1_type::iterator u1_iterator;
    typedef typename u1_type::const_iterator u1_const_iterator;

    typedef typename tucker3_type::u2_type u2_type;
    typedef typename u2_type::iterator u2_iterator;
    typedef typename u2_type::const_iterator u2_const_iterator;

    typedef typename tucker3_type::u3_type u3_type;
    typedef typename u3_type::iterator u3_iterator;
    typedef typename u3_type::const_iterator u3_const_iterator;

    typedef matrix< R1, R2, T_coeff > front_core_slice_type; // forward-cyclic frontal core slice
    typedef typename qtucker3_type::t3_core_signs_type t3_core_signs_type;

    // Import an unquantized Tucker3 decomposition from a linear value array.
    template< typename T >
    static void import_from( const std::vector< T >& data_, tucker3_type& tuck3_data_ );

    // Import a quantized Tucker3 decomposition (basis matrices and core stored as T_coeff).
    static void import_quantized_from( const std::vector<unsigned char>& data_in_, qtucker3_type& tuck3_data_ );

    // Import a quantized Tucker3 decomposition with a byte-encoded core
    // (sign bit plus 7-bit value) and a separately stored hottest core value.
    static void import_hot_quantized_from( const std::vector<unsigned char>& data_in_, qtucker3_type& tuck3_data_ );

    // Import a TTM-style quantized Tucker3 decomposition: byte-encoded,
    // logarithmically quantized core without a separate hottest value.
    static void import_ttm_quantized_from( const std::vector<unsigned char>& data_in_, qtucker3_type& tuck3_data_ );

}; // class tucker3_importer
#define VMML_TEMPLATE_STRING    template< size_t R1, size_t R2, size_t R3, size_t I1, size_t I2, size_t I3, typename T_value, typename T_coeff >
#define VMML_TEMPLATE_CLASSNAME tucker3_importer< R1, R2, R3, I1, I2, I3, T_value, T_coeff >

VMML_TEMPLATE_STRING
template< typename T >
void
VMML_TEMPLATE_CLASSNAME::import_from( const std::vector< T >& data_, tucker3_type& tuck3_data_ )
{
    size_t i = 0; // linear read position in data_
    size_t data_size = (size_t) data_.size();

    if ( data_size != tuck3_data_.SIZE )
        VMMLIB_ERROR( "import_from: the input data must have the size R1xR2xR3 + R1xI1 + R2xI2 + R3xI3", VMMLIB_HERE );

    u1_type* u1 = new u1_type;
    u2_type* u2 = new u2_type;
    u3_type* u3 = new u3_type;
    t3_core_type core;

    tuck3_data_.get_u1( *u1 );
    tuck3_data_.get_u2( *u2 );
    tuck3_data_.get_u3( *u3 );
    tuck3_data_.get_core( core );

    u1_iterator it = u1->begin(),
        it_end = u1->end();
    for( ; it != it_end; ++it, ++i )
    {
        *it = static_cast< T >( data_.at( i ));
    }

    u2_iterator u2_it = u2->begin(),
        u2_it_end = u2->end();
    for( ; u2_it != u2_it_end; ++u2_it, ++i )
    {
        *u2_it = static_cast< T >( data_.at( i ));
    }

    u3_iterator u3_it = u3->begin(),
        u3_it_end = u3->end();
    for( ; u3_it != u3_it_end; ++u3_it, ++i )
    {
        *u3_it = static_cast< T >( data_.at( i ));
    }

    t3_core_iterator it_core = core.begin(),
        it_core_end = core.end();
    for( ; it_core != it_core_end; ++it_core, ++i )
    {
        *it_core = static_cast< T >( data_.at( i ));
    }

    tuck3_data_.set_u1( *u1 );
    tuck3_data_.set_u2( *u2 );
    tuck3_data_.set_u3( *u3 );
    tuck3_data_.set_core( core );

    tuck3_data_.cast_comp_members();

    delete u1; delete u2; delete u3;
}
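
// Worked size example for import_from (illustrative): the vector is consumed
// linearly as U1, then U2, then U3, then the core, and must contain exactly
// tuck3_data_.SIZE = R1*R2*R3 + R1*I1 + R2*I2 + R3*I3 values. With assumed
// ranks R1=R2=R3=4 and sizes I1=I2=I3=64 that is 64 + 3*256 = 832 values:
//
//     typedef vmml::tucker3_importer< 4, 4, 4, 64, 64, 64 > importer_t;
//     importer_t::tucker3_type t3;
//     std::vector< float > linear( 832 );   // U1 | U2 | U3 | core, in iterator order
//     // ... fill 'linear' from the application's own storage ...
//     importer_t::import_from( linear, t3 );
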
VMML_TEMPLATE_STRING
void
VMML_TEMPLATE_CLASSNAME::import_quantized_from( const std::vector<unsigned char>& data_in_, qtucker3_type& tuck3_data_ )
{
    size_t end_data = 0; // current read offset into the byte stream
    size_t len_t_comp = sizeof( T_internal );
    size_t len_export_data = tuck3_data_.SIZE * sizeof( T_coeff ) + 8 * len_t_comp;
    unsigned char* data = new unsigned char[ len_export_data ];
    for( size_t byte = 0; byte < len_export_data; ++byte )
    {
        data[byte] = data_in_.at( byte );
    }

    // read the quantization ranges of the basis matrices and the core
    T_internal u_min = 0; T_internal u_max = 0;
    memcpy( &u_min, data, len_t_comp ); end_data = len_t_comp;
    memcpy( &u_max, data + end_data, len_t_comp ); end_data += len_t_comp;

    T_internal core_min = 0; T_internal core_max = 0;
    memcpy( &core_min, data + end_data, len_t_comp ); end_data += len_t_comp;
    memcpy( &core_max, data + end_data, len_t_comp ); end_data += len_t_comp;

    u1_type* u1 = new u1_type;
    u2_type* u2 = new u2_type;
    u3_type* u3 = new u3_type;
    t3_core_type core;

    tuck3_data_.get_u1( *u1 );
    tuck3_data_.get_u2( *u2 );
    tuck3_data_.get_u3( *u3 );
    tuck3_data_.get_core( core );

    // copy data to u1
    size_t len_u1 = I1 * R1 * sizeof( T_coeff );
    memcpy( *u1, data + end_data, len_u1 ); end_data += len_u1;

    // copy data to u2
    size_t len_u2 = I2 * R2 * sizeof( T_coeff );
    memcpy( *u2, data + end_data, len_u2 ); end_data += len_u2;

    // copy data to u3
    size_t len_u3 = I3 * R3 * sizeof( T_coeff );
    memcpy( *u3, data + end_data, len_u3 ); end_data += len_u3;

    // copy data to the core, one frontal slice at a time
    size_t len_core_slice = R1 * R2 * sizeof( T_coeff );
    front_core_slice_type* slice = new front_core_slice_type();
    for ( size_t r3 = 0; r3 < R3; ++r3 )
    {
        memcpy( slice, data + end_data, len_core_slice );
        core.set_frontal_slice_fwd( r3, *slice );
        end_data += len_core_slice;
    }
    delete slice;

    tuck3_data_.set_u1( *u1 );
    tuck3_data_.set_u2( *u2 );
    tuck3_data_.set_u3( *u3 );
    tuck3_data_.set_core( core );

    tuck3_data_.dequantize_basis_matrices( u_min, u_max, u_min, u_max, u_min, u_max );
    tuck3_data_.dequantize_core( core_min, core_max );

    delete[] data;
    delete u1; delete u2; delete u3;
}
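
// Byte-layout sketch for import_quantized_from, read back from the memcpy
// sequence above (offsets assume T_internal = float, i.e. 4 bytes; each
// coefficient occupies sizeof(T_coeff) bytes):
//
//     [ 0 ..  3]  u_min     (T_internal)
//     [ 4 ..  7]  u_max     (T_internal)
//     [ 8 .. 11]  core_min  (T_internal)
//     [12 .. 15]  core_max  (T_internal)
//     [16 .. ]    U1 (I1*R1 coefficients), then U2 (I2*R2), then U3 (I3*R3)
//     [   .. ]    core: R3 frontal slices of R1*R2 coefficients each
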
VMML_TEMPLATE_STRING
void
VMML_TEMPLATE_CLASSNAME::import_hot_quantized_from( const std::vector<unsigned char>& data_in_, qtucker3_type& tuck3_data_ )
{
    tuck3_data_.enable_quantify_hot();

    size_t end_data = 0; // current read offset into the byte stream
    size_t len_t_comp = sizeof( T_internal );
    size_t len_export_data = R1*R2*R3 + ( R1*I1 + R2*I2 + R3*I3 ) * sizeof( T_coeff ) + 4 * len_t_comp;
    unsigned char* data = new unsigned char[ len_export_data ];
    for( size_t byte = 0; byte < len_export_data; ++byte )
    {
        data[byte] = data_in_.at( byte );
    }

    // read the quantization ranges; only u_min, u_max, core_max and the
    // hottest core value are stored, core_min stays zero
    T_internal u_min = 0; T_internal u_max = 0;
    memcpy( &u_min, data, len_t_comp ); end_data = len_t_comp;
    memcpy( &u_max, data + end_data, len_t_comp ); end_data += len_t_comp;

    T_internal core_min = 0; T_internal core_max = 0;
    memcpy( &core_max, data + end_data, len_t_comp ); end_data += len_t_comp;

    T_internal hottest_value = 0;
    memcpy( &hottest_value, data + end_data, len_t_comp ); end_data += len_t_comp;
    tuck3_data_.set_hottest_value( hottest_value );

    u1_type* u1 = new u1_type;
    u2_type* u2 = new u2_type;
    u3_type* u3 = new u3_type;
    t3_core_type core;
    t3_core_signs_type signs;

    tuck3_data_.get_u1( *u1 );
    tuck3_data_.get_u2( *u2 );
    tuck3_data_.get_u3( *u3 );
    tuck3_data_.get_core( core );
    tuck3_data_.get_core_signs( signs );

    // copy data to u1
    size_t len_u1 = I1 * R1 * sizeof( T_coeff );
    memcpy( *u1, data + end_data, len_u1 ); end_data += len_u1;

    // copy data to u2
    size_t len_u2 = I2 * R2 * sizeof( T_coeff );
    memcpy( *u2, data + end_data, len_u2 ); end_data += len_u2;

    // copy data to u3
    size_t len_u3 = I3 * R3 * sizeof( T_coeff );
    memcpy( *u3, data + end_data, len_u3 ); end_data += len_u3;

    // copy data to the core: one byte per element, sign bit plus 7-bit value
    size_t len_core_el = 1;

    unsigned char core_el;
    for ( size_t r3 = 0; r3 < R3; ++r3 )
    {
        for ( size_t r2 = 0; r2 < R2; ++r2 )
        {
            for ( size_t r1 = 0; r1 < R1; ++r1 )
            {
                memcpy( &core_el, data + end_data, len_core_el );
                signs.at( r1, r2, r3 ) = ( core_el & 0x80 ) / 128;
                core.at( r1, r2, r3 ) = core_el & 0x7f;
                ++end_data;
            }
        }
    }

    tuck3_data_.set_u1( *u1 );
    tuck3_data_.set_u2( *u2 );
    tuck3_data_.set_u3( *u3 );
    tuck3_data_.set_core( core );
    tuck3_data_.set_core_signs( signs );

    tuck3_data_.dequantize_basis_matrices( u_min, u_max, u_min, u_max, u_min, u_max );
    tuck3_data_.dequantize_core( core_min, core_max );

    delete[] data;
    delete u1; delete u2; delete u3;
}
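
// Worked example of the per-byte core decoding above (illustrative value):
//
//     unsigned char core_el = 0x85;             // binary 1000 0101
//     int sign  = ( core_el & 0x80 ) / 128;     // 1 (high bit set)
//     int value = core_el & 0x7f;               // 5 (low seven bits)
//
// The sign bit goes into the signs tensor and the 7-bit value into the core;
// dequantize_core( core_min, core_max ) then maps the quantized core back to
// floating point.
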
VMML_TEMPLATE_STRING
void
VMML_TEMPLATE_CLASSNAME::import_ttm_quantized_from( const std::vector<unsigned char>& data_in_, qtucker3_type& tuck3_data_ )
{
    tuck3_data_.enable_quantify_log();

    size_t end_data = 0; // current read offset into the byte stream
    size_t len_t_comp = sizeof( T_internal );
    size_t len_export_data = R1*R2*R3 + ( R1*I1 + R2*I2 + R3*I3 ) * sizeof( T_coeff ) + 3 * len_t_comp;
    unsigned char* data = new unsigned char[ len_export_data ];
    for( size_t byte = 0; byte < len_export_data; ++byte )
    {
        data[byte] = data_in_.at( byte );
    }

    // read the quantization ranges; only u_min, u_max and core_max are stored,
    // core_min stays zero
    T_internal u_min = 0; T_internal u_max = 0;
    memcpy( &u_min, data, len_t_comp ); end_data = len_t_comp;
    memcpy( &u_max, data + end_data, len_t_comp ); end_data += len_t_comp;

    T_internal core_min = 0; T_internal core_max = 0;
    memcpy( &core_max, data + end_data, len_t_comp ); end_data += len_t_comp;

    u1_type* u1 = new u1_type;
    u2_type* u2 = new u2_type;
    u3_type* u3 = new u3_type;
    t3_core_type core;
    t3_core_signs_type signs;

    tuck3_data_.get_u1( *u1 );
    tuck3_data_.get_u2( *u2 );
    tuck3_data_.get_u3( *u3 );
    tuck3_data_.get_core( core );
    tuck3_data_.get_core_signs( signs );

    // copy data to u1
    size_t len_u1 = I1 * R1 * sizeof( T_coeff );
    memcpy( *u1, data + end_data, len_u1 ); end_data += len_u1;

    // copy data to u2
    size_t len_u2 = I2 * R2 * sizeof( T_coeff );
    memcpy( *u2, data + end_data, len_u2 ); end_data += len_u2;

    // copy data to u3
    size_t len_u3 = I3 * R3 * sizeof( T_coeff );
    memcpy( *u3, data + end_data, len_u3 ); end_data += len_u3;

    // copy data to the core: one byte per element, sign bit plus 7-bit value,
    // traversed in r2/r3/r1 order to match the ttm export
    size_t len_core_el = 1;

    unsigned char core_el;
    for ( size_t r2 = 0; r2 < R2; ++r2 )
    {
        for ( size_t r3 = 0; r3 < R3; ++r3 )
        {
            for ( size_t r1 = 0; r1 < R1; ++r1 )
            {
                memcpy( &core_el, data + end_data, len_core_el );
                signs.at( r1, r2, r3 ) = ( core_el & 0x80 ) / 128;
                core.at( r1, r2, r3 ) = core_el & 0x7f;
                ++end_data;
            }
        }
    }

    tuck3_data_.set_u1( *u1 );
    tuck3_data_.set_u2( *u2 );
    tuck3_data_.set_u3( *u3 );
    tuck3_data_.set_core( core );
    tuck3_data_.set_core_signs( signs );

    tuck3_data_.dequantize_basis_matrices( u_min, u_max, u_min, u_max, u_min, u_max );
    tuck3_data_.dequantize_core( core_min, core_max );

    delete[] data;
    delete u1; delete u2; delete u3;
}

#undef VMML_TEMPLATE_STRING
#undef VMML_TEMPLATE_CLASSNAME
} // namespace vmml
#endif