42 #ifndef __VMML__T3_TTM__HPP__
43 #define __VMML__T3_TTM__HPP__
45 #include <vmmlib/tensor3.hpp>
46 #include <vmmlib/blas_dgemm.hpp>
47 #ifdef VMMLIB_USE_OPENMP
// Declaration: generic (element type T) overload of the full
// tensor-times-matrix routine.
//   t3_in_  : input, tensor3< J1, J2, J3, T > (read-only)
//   U1      : matrix< I1, J1, T > (read-only)
//   U2      : matrix< I2, J2, T > (read-only)
//   U3      : matrix< I3, J3, T > (read-only)
//   t3_res_ : output, tensor3< I1, I2, I3, T >, passed by non-const reference
// The definitions visible in this file chain three single-mode multiplies
// (one per factor matrix) to produce the result.
61 template<
size_t I1,
size_t I2,
size_t I3,
size_t J1,
size_t J2,
size_t J3,
typename T >
62 static void full_tensor3_matrix_multiplication(
const tensor3< J1, J2, J3, T >& t3_in_,
const matrix< I1, J1, T >& U1,
const matrix< I2, J2, T >& U2,
const matrix< I3, J3, T >& U3,
tensor3< I1, I2, I3, T >& t3_res_ );
// Declaration: variant of the full tensor-times-matrix routine with the same
// parameter list as full_tensor3_matrix_multiplication (generic T).
// Its definition in this file works on an unfolded matrix of the input and a
// Kronecker product built from U2 and U3 instead of slice-wise multiplies.
//   t3_in_  : input, tensor3< J1, J2, J3, T >
//   U1/U2/U3: matrix< I1, J1, T >, matrix< I2, J2, T >, matrix< I3, J3, T >
//   t3_res_ : output, tensor3< I1, I2, I3, T >
64 template<
size_t I1,
size_t I2,
size_t I3,
size_t J1,
size_t J2,
size_t J3,
typename T >
65 static void full_tensor3_matrix_kronecker_mult(
const tensor3< J1, J2, J3, T >& t3_in_,
const matrix< I1, J1, T >& U1,
const matrix< I2, J2, T >& U2,
const matrix< I3, J3, T >& U3,
tensor3< I1, I2, I3, T >& t3_res_ );
// Declaration: generic (element type T) single-mode multiply, "horizontal
// bwd" variant.
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I3, J3, T >
//   t3_res_   : output, tensor3< J1, J2, I3, T > (third size parameter J3 is
//               replaced by I3 in the result type)
68 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3,
typename T >
69 static void multiply_horizontal_bwd(
const tensor3< J1, J2, J3, T >& t3_in_,
const matrix< I3, J3, T >& in_slice_,
tensor3< J1, J2, I3, T >& t3_res_ );
// Declaration: generic (element type T) single-mode multiply, "lateral bwd"
// variant.
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I1, J1, T >
//   t3_res_   : output, tensor3< I1, J2, J3, T > (first size parameter J1 is
//               replaced by I1 in the result type)
71 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3,
typename T >
72 static void multiply_lateral_bwd(
const tensor3< J1, J2, J3, T >& t3_in_,
const matrix< I1, J1, T >& in_slice_,
tensor3< I1, J2, J3, T >& t3_res_ );
// Declaration: generic (element type T) single-mode multiply, "frontal bwd"
// variant.
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I2, J2, T >
//   t3_res_   : output, tensor3< J1, I2, J3, T > (second size parameter J2 is
//               replaced by I2 in the result type)
74 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3,
typename T >
75 static void multiply_frontal_bwd(
const tensor3< J1, J2, J3, T >& t3_in_,
const matrix< I2, J2, T >& in_slice_,
tensor3< J1, I2, J3, T >& t3_res_ );
// Declaration: T_blas overload of the full tensor-times-matrix routine.
// Same parameter list as the generic overload, but every tensor/matrix uses
// the fixed element type T_blas instead of a template parameter.
// NOTE(review): T_blas is not defined in the visible lines — presumably a
// typedef of the enclosing class; confirm.
80 template<
size_t I1,
size_t I2,
size_t I3,
size_t J1,
size_t J2,
size_t J3 >
81 static void full_tensor3_matrix_multiplication(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
const matrix< I1, J1, T_blas >& U1,
const matrix< I2, J2, T_blas >& U2,
const matrix< I3, J3, T_blas >& U3,
tensor3< I1, I2, I3, T_blas >& t3_res_ );
// Declaration: T_blas overload of multiply_horizontal_bwd.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I3, J3, T_blas >
//   t3_res_   : output, tensor3< J1, J2, I3, T_blas >
84 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3 >
85 static void multiply_horizontal_bwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
const matrix< I3, J3, T_blas >& in_slice_,
tensor3< J1, J2, I3, T_blas >& t3_res_ );
// Declaration: T_blas overload of multiply_lateral_bwd.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I1, J1, T_blas >
//   t3_res_   : output, tensor3< I1, J2, J3, T_blas >
87 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3 >
88 static void multiply_lateral_bwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
const matrix< I1, J1, T_blas >& in_slice_,
tensor3< I1, J2, J3, T_blas >& t3_res_ );
// Declaration: T_blas overload of multiply_frontal_bwd.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I2, J2, T_blas >
//   t3_res_   : output, tensor3< J1, I2, J3, T_blas >
90 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3 >
91 static void multiply_frontal_bwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
const matrix< I2, J2, T_blas >& in_slice_,
tensor3< J1, I2, J3, T_blas >& t3_res_ );
// Declaration: T_blas overload of multiply_horizontal_fwd.
// Unlike multiply_horizontal_bwd (which takes an I3 x J3-typed matrix), the
// fwd variant takes matrix< I2, J2 > and replaces J2 by I2 in the result.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I2, J2, T_blas >
//   t3_res_   : output, tensor3< J1, I2, J3, T_blas >
94 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3 >
95 static void multiply_horizontal_fwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
const matrix< I2, J2, T_blas >& in_slice_,
tensor3< J1, I2, J3, T_blas >& t3_res_ );
// Declaration: T_blas overload of multiply_lateral_fwd.
// Unlike multiply_lateral_bwd (which takes an I1 x J1-typed matrix), the fwd
// variant takes matrix< I3, J3 > and replaces J3 by I3 in the result.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I3, J3, T_blas >
//   t3_res_   : output, tensor3< J1, J2, I3, T_blas >
97 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3 >
98 static void multiply_lateral_fwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
const matrix< I3, J3, T_blas >& in_slice_,
tensor3< J1, J2, I3, T_blas >& t3_res_ );
// Declaration: T_blas overload of multiply_frontal_fwd.
// Unlike multiply_frontal_bwd (which takes an I2 x J2-typed matrix), the fwd
// variant takes matrix< I1, J1 > and replaces J1 by I1 in the result.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I1, J1, T_blas >
//   t3_res_   : output, tensor3< I1, J2, J3, T_blas >
100 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3 >
101 static void multiply_frontal_fwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
const matrix< I1, J1, T_blas >& in_slice_,
tensor3< I1, J2, J3, T_blas >& t3_res_ );
// Declaration: generic (element type T) overload of multiply_horizontal_fwd.
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I2, J2, T >
//   t3_res_   : output, tensor3< J1, I2, J3, T >
103 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3,
typename T >
104 static void multiply_horizontal_fwd(
const tensor3< J1, J2, J3, T >& t3_in_,
const matrix< I2, J2, T >& in_slice_,
tensor3< J1, I2, J3, T >& t3_res_ );
// Declaration: generic (element type T) overload of multiply_lateral_fwd.
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I3, J3, T >
//   t3_res_   : output, tensor3< J1, J2, I3, T >
106 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3,
typename T >
107 static void multiply_lateral_fwd(
const tensor3< J1, J2, J3, T >& t3_in_,
const matrix< I3, J3, T >& in_slice_,
tensor3< J1, J2, I3, T >& t3_res_ );
// Declaration: generic (element type T) overload of multiply_frontal_fwd.
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I1, J1, T >
//   t3_res_   : output, tensor3< I1, J2, J3, T >
109 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3,
typename T >
110 static void multiply_frontal_fwd(
const tensor3< J1, J2, J3, T >& t3_in_,
const matrix< I1, J1, T >& in_slice_,
tensor3< I1, J2, J3, T >& t3_res_ );
117 #define VMML_TEMPLATE_CLASSNAME t3_ttm
// NOTE(review): the function signature between this template header and the
// statements below is not visible in this excerpt. The template parameter
// list and the calls match the generic full_tensor3_matrix_multiplication
// declaration — presumably this is its definition; confirm.
//
// Two alternative three-step chains are visible. How one of them is selected
// (and where t3_result_1 / t3_result_2 are declared) is not visible here.
123 template<
size_t I1,
size_t I2,
size_t I3,
size_t J1,
size_t J2,
size_t J3,
typename T >
// "bwd" chain: U1 via lateral, U2 via frontal, U3 via horizontal.
140 multiply_lateral_bwd( t3_in_, U1, t3_result_1 );
141 multiply_frontal_bwd( t3_result_1, U2, t3_result_2 );
142 multiply_horizontal_bwd( t3_result_2, U3, t3_res_ );
// "fwd" chain: U1 via frontal, U2 via horizontal, U3 via lateral.
146 multiply_frontal_fwd( t3_in_, U1, t3_result_1 );
147 multiply_horizontal_fwd( t3_result_1, U2, t3_result_2 );
148 multiply_lateral_fwd( t3_result_2, U3, t3_res_ );
// Definition of full_tensor3_matrix_kronecker_mult (generic element type T).
// Visible steps:
//   1. unfold t3_in_ into a matrix< J1, J2*J3 > via lateral_unfolding_bwd();
//   2. tmp1->multiply( U1, unfolded )            -> matrix< I1, J2*J3 >;
//   3. U2.kronecker_product( U3, kron_prod )     -> matrix< I2*I3, J2*J3 >;
//   4. res_unfolded->multiply( tmp1, transpose( kron_prod ) )
//                                                -> matrix< I1, I2*I3 >;
//   5. copy the columns of res_unfolded into t3_res_ with set_column().
// NOTE(review): parts of the body (closing of the signature, the loop body
// around set_column, declaration/update of i2 and i3) are not visible in this
// excerpt.
152 template<
size_t I1,
size_t I2,
size_t I3,
size_t J1,
size_t J2,
size_t J3,
typename T >
154 VMML_TEMPLATE_CLASSNAME::full_tensor3_matrix_kronecker_mult(
const tensor3< J1, J2, J3, T >& t3_in_,
155 const matrix< I1, J1, T >& U1,
156 const matrix< I2, J2, T >& U2,
157 const matrix< I3, J3, T >& U3,
158 tensor3< I1, I2, I3, T >& t3_res_
// All intermediate matrices are heap-allocated with new.
163 matrix< J1, J2*J3, T>* core_unfolded =
new matrix< J1, J2*J3, T>;
164 t3_in_.lateral_unfolding_bwd( *core_unfolded );
// presumably tmp1 = U1 * core_unfolded (the template sizes I1 x J1 and
// J1 x J2*J3 giving I1 x J2*J3 are consistent with a matrix product).
165 matrix< I1, J2*J3, T>* tmp1 =
new matrix< I1, J2*J3, T>;
166 tmp1->multiply( U1, *core_unfolded );
// Kronecker product of U2 with U3, written into kron_prod.
168 matrix< I2*I3, J2*J3, T>* kron_prod =
new matrix< I2*I3, J2*J3, T>;
169 U2.kronecker_product( U3, *kron_prod );
// The transposed Kronecker matrix is used as the right-hand operand.
171 matrix< I1, I2*I3, T>* res_unfolded =
new matrix< I1, I2*I3, T>;
172 res_unfolded->multiply( *tmp1, transpose(*kron_prod) );
// Walk all I2*I3 columns of res_unfolded; i2 advances together with i.
// Column i is stored at position (i2, i3) of the result tensor.
178 for(
size_t i = 0; i < (I2*I3); ++i, ++i2 )
185 t3_res_.set_column( i2, i3, res_unfolded->get_column(i));
// NOTE(review): only the delete for core_unfolded is visible here; confirm
// that tmp1, kron_prod and res_unfolded are released as well.
188 delete core_unfolded;
// Definition: generic-T overload of multiply_horizontal_bwd.
// Builds T_blas-typed versions of the arguments, calls multiply_horizontal_bwd
// on them, and writes the outcome into t3_res_ with cast_from().
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I3, J3, T >
//   t3_res_   : output, tensor3< J1, J2, I3, T >
199 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3,
typename T >
201 VMML_TEMPLATE_CLASSNAME::multiply_horizontal_bwd(
const tensor3< J1, J2, J3, T >& t3_in_,
202 const matrix< I3, J3, T >& in_slice_,
203 tensor3< J1, J2, I3, T >& t3_res_ )
205 typedef matrix< I3, J3, T_blas > slice_t;
// T_blas-typed objects constructed from the T-typed inputs; the T_blas result
// buffer is zeroed before use.
207 tensor3< J1, J2, J3, T_blas > t3_in( t3_in_ );
208 slice_t* in_slice =
new slice_t( in_slice_ );
209 tensor3< J1, J2, I3, T_blas > t3_res; t3_res.zero();
// All arguments are T_blas-typed here — presumably this resolves to the
// T_blas overload rather than recursing into this template; confirm.
211 multiply_horizontal_bwd( t3_in, *in_slice, t3_res );
212 t3_res_.cast_from( t3_res );
// NOTE(review): in_slice is allocated with new; its matching delete is not
// visible in this excerpt — confirm it exists.
// Definition: generic-T overload of multiply_lateral_bwd.
// Builds T_blas-typed versions of the arguments, calls multiply_lateral_bwd
// on them, and writes the outcome into t3_res_ with cast_from().
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I1, J1, T >
//   t3_res_   : output, tensor3< I1, J2, J3, T >
218 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3,
typename T >
220 VMML_TEMPLATE_CLASSNAME::multiply_lateral_bwd(
const tensor3< J1, J2, J3, T >& t3_in_,
221 const matrix< I1, J1, T >& in_slice_,
222 tensor3< I1, J2, J3, T >& t3_res_ )
224 typedef matrix< I1, J1, T_blas > slice_t;
// T_blas-typed objects constructed from the T-typed inputs; the T_blas result
// buffer is zeroed before use.
226 tensor3< J1, J2, J3, T_blas > t3_in( t3_in_ );
227 slice_t* in_slice =
new slice_t( in_slice_ );
228 tensor3< I1, J2, J3, T_blas > t3_res; t3_res.zero();
// All arguments are T_blas-typed here — presumably this resolves to the
// T_blas overload rather than recursing into this template; confirm.
230 multiply_lateral_bwd( t3_in, *in_slice, t3_res );
231 t3_res_.cast_from( t3_res );
// NOTE(review): in_slice is allocated with new; its matching delete is not
// visible in this excerpt — confirm it exists.
// Definition: generic-T overload of multiply_frontal_bwd.
// Builds T_blas-typed versions of the arguments, calls multiply_frontal_bwd
// on them, and writes the outcome into t3_res_ with cast_from().
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I2, J2, T >
//   t3_res_   : output, tensor3< J1, I2, J3, T >
238 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3,
typename T >
240 VMML_TEMPLATE_CLASSNAME::multiply_frontal_bwd(
const tensor3< J1, J2, J3, T >& t3_in_,
241 const matrix< I2, J2, T >& in_slice_,
242 tensor3< J1, I2, J3, T >& t3_res_ )
244 typedef matrix< I2, J2, T_blas > slice_t;
// T_blas-typed objects constructed from the T-typed inputs; the T_blas result
// buffer is zeroed before use.
246 tensor3< J1, J2, J3, T_blas > t3_in( t3_in_ );
247 slice_t* in_slice =
new slice_t( in_slice_ );
248 tensor3< J1, I2, J3, T_blas > t3_res; t3_res.zero();
// All arguments are T_blas-typed here — presumably this resolves to the
// T_blas overload rather than recursing into this template; confirm.
250 multiply_frontal_bwd( t3_in, *in_slice, t3_res );
251 t3_res_.cast_from( t3_res );
// NOTE(review): in_slice is allocated with new; its matching delete is not
// visible in this excerpt — confirm it exists.
// Definition: generic-T overload of multiply_horizontal_fwd.
// Builds T_blas-typed versions of the arguments, calls multiply_horizontal_fwd
// on them, and writes the outcome into t3_res_ with cast_from().
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I2, J2, T >
//   t3_res_   : output, tensor3< J1, I2, J3, T >
258 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3,
typename T >
260 VMML_TEMPLATE_CLASSNAME::multiply_horizontal_fwd(
const tensor3< J1, J2, J3, T >& t3_in_,
261 const matrix< I2, J2, T >& in_slice_,
262 tensor3< J1, I2, J3, T >& t3_res_ )
264 typedef matrix< I2, J2, T_blas > slice_t;
// T_blas-typed objects constructed from the T-typed inputs; the T_blas result
// buffer is zeroed before use.
266 tensor3< J1, J2, J3, T_blas > t3_in( t3_in_ );
267 slice_t* in_slice =
new slice_t( in_slice_ );
268 tensor3< J1, I2, J3, T_blas > t3_res; t3_res.zero();
// All arguments are T_blas-typed here — presumably this resolves to the
// T_blas overload rather than recursing into this template; confirm.
270 multiply_horizontal_fwd( t3_in, *in_slice, t3_res );
271 t3_res_.cast_from( t3_res );
// NOTE(review): in_slice is allocated with new; its matching delete is not
// visible in this excerpt — confirm it exists.
// Definition: generic-T overload of multiply_lateral_fwd.
// Builds T_blas-typed versions of the arguments, calls multiply_lateral_fwd
// on them, and writes the outcome into t3_res_ with cast_from().
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I3, J3, T >
//   t3_res_   : output, tensor3< J1, J2, I3, T >
277 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3,
typename T >
279 VMML_TEMPLATE_CLASSNAME::multiply_lateral_fwd(
const tensor3< J1, J2, J3, T >& t3_in_,
280 const matrix< I3, J3, T >& in_slice_,
281 tensor3< J1, J2, I3, T >& t3_res_ )
283 typedef matrix< I3, J3, T_blas > slice_t;
// T_blas-typed objects constructed from the T-typed inputs; the T_blas result
// buffer is zeroed before use.
285 tensor3< J1, J2, J3, T_blas > t3_in( t3_in_ );
286 slice_t* in_slice =
new slice_t( in_slice_ );
287 tensor3< J1, J2, I3, T_blas > t3_res; t3_res.zero();
// All arguments are T_blas-typed here — presumably this resolves to the
// T_blas overload rather than recursing into this template; confirm.
289 multiply_lateral_fwd( t3_in, *in_slice, t3_res );
290 t3_res_.cast_from( t3_res );
// NOTE(review): in_slice is allocated with new; its matching delete is not
// visible in this excerpt — confirm it exists.
// Definition: generic-T overload of multiply_frontal_fwd.
// Builds T_blas-typed versions of the arguments, calls multiply_frontal_fwd
// on them, and writes the outcome into t3_res_ with cast_from().
//   t3_in_    : input, tensor3< J1, J2, J3, T >
//   in_slice_ : matrix< I1, J1, T >
//   t3_res_   : output, tensor3< I1, J2, J3, T >
297 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3,
typename T >
299 VMML_TEMPLATE_CLASSNAME::multiply_frontal_fwd(
const tensor3< J1, J2, J3, T >& t3_in_,
300 const matrix< I1, J1, T >& in_slice_,
301 tensor3< I1, J2, J3, T >& t3_res_ )
303 typedef matrix< I1, J1, T_blas > slice_t;
// T_blas-typed objects constructed from the T-typed inputs; the T_blas result
// buffer is zeroed before use.
305 tensor3< J1, J2, J3, T_blas > t3_in( t3_in_ );
306 slice_t* in_slice =
new slice_t( in_slice_ );
307 tensor3< I1, J2, J3, T_blas > t3_res; t3_res.zero();
// All arguments are T_blas-typed here — presumably this resolves to the
// T_blas overload rather than recursing into this template; confirm.
309 multiply_frontal_fwd( t3_in, *in_slice, t3_res );
310 t3_res_.cast_from( t3_res );
// NOTE(review): in_slice is allocated with new; its matching delete is not
// visible in this excerpt — confirm it exists.
// Definition: T_blas overload of full_tensor3_matrix_multiplication.
// Applies U1, U2 and U3 one after another through two intermediate tensors:
//   t3_in_ (J1,J2,J3) -> t3_result_1 (I1,J2,J3) -> t3_result_2 (I1,I2,J3)
//   -> t3_res_ (I1,I2,I3).
// Two alternative call chains ("bwd" and "fwd") are visible.
// NOTE(review): how one chain is selected is not visible in this excerpt
// (the closing of the signature and surrounding lines are missing).
316 template<
size_t I1,
size_t I2,
size_t I3,
size_t J1,
size_t J2,
size_t J3 >
318 VMML_TEMPLATE_CLASSNAME::full_tensor3_matrix_multiplication(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
319 const matrix< I1, J1, T_blas >& U1,
320 const matrix< I2, J2, T_blas >& U2,
321 const matrix< I3, J3, T_blas >& U3,
322 tensor3< I1, I2, I3, T_blas >& t3_res_
// Intermediate results after applying U1, and after applying U1 and U2.
325 tensor3< I1, J2, J3, T_blas > t3_result_1;
326 tensor3< I1, I2, J3, T_blas > t3_result_2;
// "bwd" chain: U1 via lateral, U2 via frontal, U3 via horizontal.
331 multiply_lateral_bwd( t3_in_, U1, t3_result_1 );
332 multiply_frontal_bwd( t3_result_1, U2, t3_result_2 );
333 multiply_horizontal_bwd( t3_result_2, U3, t3_res_ );
// "fwd" chain: U1 via frontal, U2 via horizontal, U3 via lateral.
337 multiply_frontal_fwd( t3_in_, U1, t3_result_1 );
338 multiply_horizontal_fwd( t3_result_1, U2, t3_result_2 );
339 multiply_lateral_fwd( t3_result_2, U3, t3_res_ );
// Definition: T_blas overload of multiply_horizontal_bwd.
// For every i1 in [0, J1): reads horizontal slice i1 of t3_in_ (bwd layout)
// into a matrix< J3, J2 >, runs blas_t::compute( in_slice_, slice, slice_new )
// and stores the matrix< I3, J2 > result as horizontal slice i1 of t3_res_.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I3, J3, T_blas >
//   t3_res_   : output, tensor3< J1, J2, I3, T_blas >
346 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3 >
348 VMML_TEMPLATE_CLASSNAME::multiply_horizontal_bwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
349 const matrix< I3, J3, T_blas >& in_slice_,
350 tensor3< J1, J2, I3, T_blas >& t3_res_ )
352 typedef matrix< J3, J2, T_blas > slice_t;
353 typedef matrix< I3, J2, T_blas > slice_new_t;
354 typedef blas_dgemm< I3, J3, J2, T_blas > blas_t;
// The loop carries an `omp parallel for` annotation. The index is a signed
// int (J1 is cast) — presumably to satisfy OpenMP loop-variable rules; confirm.
356 #pragma omp parallel for
357 for (
int i1 = 0; i1 < (int)J1; ++i1 )
// Temporaries are allocated inside the loop, once per iteration.
// NOTE(review): matching deletes are not visible in this excerpt — confirm.
359 slice_t* slice =
new slice_t;
360 slice_new_t* slice_new =
new slice_new_t;
362 blas_t* multiplier =
new blas_t;
363 t3_in_.get_horizontal_slice_bwd( i1, *slice );
// presumably slice_new = in_slice_ * slice (blas_dgemm is defined elsewhere).
365 multiplier->compute( in_slice_, *slice, *slice_new );
367 t3_res_.set_horizontal_slice_bwd( i1, *slice_new );
// Definition: T_blas overload of multiply_lateral_bwd.
// For every i2 in [0, J2): reads lateral slice i2 of t3_in_ (bwd layout)
// into a matrix< J1, J3 >, runs blas_t::compute( in_slice_, slice, slice_new )
// and stores the matrix< I1, J3 > result as lateral slice i2 of t3_res_.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I1, J1, T_blas >
//   t3_res_   : output, tensor3< I1, J2, J3, T_blas >
376 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3 >
378 VMML_TEMPLATE_CLASSNAME::multiply_lateral_bwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
379 const matrix< I1, J1, T_blas >& in_slice_,
380 tensor3< I1, J2, J3, T_blas >& t3_res_ )
382 typedef matrix< J1, J3, T_blas > slice_t;
383 typedef matrix< I1, J3, T_blas > slice_new_t;
384 typedef blas_dgemm< I1, J1, J3, T_blas > blas_t;
// The loop carries an `omp parallel for` annotation. The index is a signed
// int (J2 is cast) — presumably to satisfy OpenMP loop-variable rules; confirm.
386 #pragma omp parallel for
387 for (
int i2 = 0; i2 < (int)J2; ++i2 )
// Temporaries are allocated inside the loop, once per iteration.
// NOTE(review): matching deletes are not visible in this excerpt — confirm.
389 slice_t* slice =
new slice_t;
390 slice_new_t* slice_new =
new slice_new_t;
392 blas_t* multiplier =
new blas_t;
393 t3_in_.get_lateral_slice_bwd( i2, *slice );
// presumably slice_new = in_slice_ * slice (blas_dgemm is defined elsewhere).
395 multiplier->compute( in_slice_, *slice, *slice_new );
397 t3_res_.set_lateral_slice_bwd( i2, *slice_new );
// Definition: T_blas overload of multiply_frontal_bwd.
// For every i3 in [0, J3): reads frontal slice i3 of t3_in_ (bwd layout)
// into a matrix< J2, J1 >, runs blas_t::compute( in_slice_, slice, slice_new )
// and stores the matrix< I2, J1 > result as frontal slice i3 of t3_res_.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I2, J2, T_blas >
//   t3_res_   : output, tensor3< J1, I2, J3, T_blas >
407 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3 >
409 VMML_TEMPLATE_CLASSNAME::multiply_frontal_bwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
410 const matrix< I2, J2, T_blas >& in_slice_,
411 tensor3< J1, I2, J3, T_blas >& t3_res_ )
413 typedef matrix< J2, J1, T_blas > slice_t;
414 typedef matrix< I2, J1, T_blas > slice_new_t;
415 typedef blas_dgemm< I2, J2, J1, T_blas > blas_t;
// The loop carries an `omp parallel for` annotation. The index is a signed
// int (J3 is cast) — presumably to satisfy OpenMP loop-variable rules; confirm.
417 #pragma omp parallel for
418 for (
int i3 = 0; i3 < (int)J3; ++i3 )
// Temporaries are allocated inside the loop, once per iteration.
// NOTE(review): matching deletes are not visible in this excerpt — confirm.
420 slice_t* slice =
new slice_t;
421 slice_new_t* slice_new =
new slice_new_t;
423 blas_t* multiplier =
new blas_t;
424 t3_in_.get_frontal_slice_bwd( i3, *slice );
// presumably slice_new = in_slice_ * slice (blas_dgemm is defined elsewhere).
426 multiplier->compute( in_slice_, *slice, *slice_new );
428 t3_res_.set_frontal_slice_bwd( i3, *slice_new );
// Definition: T_blas overload of multiply_horizontal_fwd.
// For every i1 in [0, J1): reads horizontal slice i1 of t3_in_ (fwd layout)
// into a matrix< J2, J3 >, runs blas_t::compute( in_slice_, slice, slice_new )
// and stores the matrix< I2, J3 > result as horizontal slice i1 of t3_res_.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I2, J2, T_blas >
//   t3_res_   : output, tensor3< J1, I2, J3, T_blas >
440 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3 >
442 VMML_TEMPLATE_CLASSNAME::multiply_horizontal_fwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
443 const matrix< I2, J2, T_blas >& in_slice_,
444 tensor3< J1, I2, J3, T_blas >& t3_res_ )
446 typedef matrix< J2, J3, T_blas > slice_t;
447 typedef matrix< I2, J3, T_blas > slice_new_t;
448 typedef blas_dgemm< I2, J2, J3, T_blas > blas_t;
// The loop carries an `omp parallel for` annotation. The index is a signed
// int (J1 is cast) — presumably to satisfy OpenMP loop-variable rules; confirm.
450 #pragma omp parallel for
451 for (
int i1 = 0; i1 < (int)J1; ++i1 )
// Temporaries are allocated inside the loop, once per iteration.
// NOTE(review): matching deletes are not visible in this excerpt — confirm.
453 slice_t* slice =
new slice_t;
454 slice_new_t* slice_new =
new slice_new_t;
456 blas_t* multiplier =
new blas_t;
457 t3_in_.get_horizontal_slice_fwd( i1, *slice );
// presumably slice_new = in_slice_ * slice (blas_dgemm is defined elsewhere).
459 multiplier->compute( in_slice_, *slice, *slice_new );
461 t3_res_.set_horizontal_slice_fwd( i1, *slice_new );
// Definition: T_blas overload of multiply_lateral_fwd.
// For every i2 in [0, J2): reads lateral slice i2 of t3_in_ (fwd layout)
// into a matrix< J3, J1 >, runs blas_t::compute( in_slice_, slice, slice_new )
// and stores the matrix< I3, J1 > result as lateral slice i2 of t3_res_.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I3, J3, T_blas >
//   t3_res_   : output, tensor3< J1, J2, I3, T_blas >
470 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3 >
472 VMML_TEMPLATE_CLASSNAME::multiply_lateral_fwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
473 const matrix< I3, J3, T_blas >& in_slice_,
474 tensor3< J1, J2, I3, T_blas >& t3_res_ )
476 typedef matrix< J3, J1, T_blas > slice_t;
477 typedef matrix< I3, J1, T_blas > slice_new_t;
478 typedef blas_dgemm< I3, J3, J1, T_blas > blas_t;
// The loop carries an `omp parallel for` annotation. The index is a signed
// int (J2 is cast) — presumably to satisfy OpenMP loop-variable rules; confirm.
480 #pragma omp parallel for
481 for (
int i2 = 0; i2 < (int)J2; ++i2 )
// Temporaries are allocated inside the loop, once per iteration.
// NOTE(review): matching deletes are not visible in this excerpt — confirm.
483 slice_t* slice =
new slice_t;
484 slice_new_t* slice_new =
new slice_new_t;
486 blas_t* multiplier =
new blas_t;
487 t3_in_.get_lateral_slice_fwd( i2, *slice );
// presumably slice_new = in_slice_ * slice (blas_dgemm is defined elsewhere).
489 multiplier->compute( in_slice_, *slice, *slice_new );
491 t3_res_.set_lateral_slice_fwd( i2, *slice_new );
// Definition: T_blas overload of multiply_frontal_fwd.
// For every i3 in [0, J3): reads frontal slice i3 of t3_in_ (fwd layout)
// into a matrix< J1, J2 >, runs blas_t::compute( in_slice_, slice, slice_new )
// and stores the matrix< I1, J2 > result as frontal slice i3 of t3_res_.
//   t3_in_    : input, tensor3< J1, J2, J3, T_blas >
//   in_slice_ : matrix< I1, J1, T_blas >
//   t3_res_   : output, tensor3< I1, J2, J3, T_blas >
501 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3 >
503 VMML_TEMPLATE_CLASSNAME::multiply_frontal_fwd(
const tensor3< J1, J2, J3, T_blas >& t3_in_,
504 const matrix< I1, J1, T_blas >& in_slice_,
505 tensor3< I1, J2, J3, T_blas >& t3_res_ )
507 typedef matrix< J1, J2, T_blas > slice_t;
508 typedef matrix< I1, J2, T_blas > slice_new_t;
510 typedef blas_dgemm< I1, J1, J2, T_blas > blas_t;
// The loop carries an `omp parallel for` annotation. The index is a signed
// int (J3 is cast) — presumably to satisfy OpenMP loop-variable rules; confirm.
512 #pragma omp parallel for
513 for (
int i3 = 0; i3 < (int)J3; ++i3 )
// Temporaries are allocated inside the loop, once per iteration.
// NOTE(review): matching deletes are not visible in this excerpt — confirm.
515 slice_t* slice =
new slice_t;
516 slice_new_t* slice_new =
new slice_new_t;
518 blas_t* multiplier =
new blas_t;
519 t3_in_.get_frontal_slice_fwd( i3, *slice );
// presumably slice_new = in_slice_ * slice (blas_dgemm is defined elsewhere).
521 multiplier->compute( in_slice_, *slice, *slice_new );
523 t3_res_.set_frontal_slice_fwd( i3, *slice_new );
534 #undef VMML_TEMPLATE_CLASSNAME