42 #ifndef __VMML__T4_TTM__HPP__
43 #define __VMML__T4_TTM__HPP__
45 #include <vmmlib/tensor4.hpp>
46 #include <vmmlib/t3_ttm.hpp>
47 #include <vmmlib/blas_dgemm.hpp>
48 #ifdef VMMLIB_USE_OPENMP
// Template parameter list with eight compile-time sizes (I1..I4, J1..J4) and an
// element type T. The declaration it introduces is not visible in this excerpt.
// NOTE(review): presumably this heads the full four-mode multiply that chains the
// mode1..mode4 helpers declared below — confirm against the complete header.
59 template<
size_t I1,
size_t I2,
size_t I3,
size_t I4,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
// Mode-1 forward tensor-times-matrix product.
//   t4_in_    : input tensor of size J1 x J2 x J3 x J4 (read-only).
//   in_slice_ : I1 x J1 matrix applied along the first dimension (read-only).
//   t4_res_   : output tensor of size I1 x J2 x J3 x J4; only the first
//               dimension changes (J1 -> I1).
68 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
69 static void mode1_multiply_fwd(
const tensor4< J1, J2, J3, J4, T >& t4_in_,
const matrix< I1, J1, T >& in_slice_,
tensor4< I1, J2, J3, J4, T >& t4_res_ );
// Mode-2 forward tensor-times-matrix product.
//   t4_in_    : input tensor of size J1 x J2 x J3 x J4 (read-only).
//   in_slice_ : I2 x J2 matrix applied along the second dimension (read-only).
//   t4_res_   : output tensor of size J1 x I2 x J3 x J4; only the second
//               dimension changes (J2 -> I2).
71 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
72 static void mode2_multiply_fwd(
const tensor4< J1, J2, J3, J4, T >& t4_in_,
const matrix< I2, J2, T >& in_slice_,
tensor4< J1, I2, J3, J4, T >& t4_res_ );
// Mode-3 forward tensor-times-matrix product.
//   t4_in_    : input tensor of size J1 x J2 x J3 x J4 (read-only).
//   in_slice_ : I3 x J3 matrix applied along the third dimension (read-only).
//   t4_res_   : output tensor of size J1 x J2 x I3 x J4; only the third
//               dimension changes (J3 -> I3).
74 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
75 static void mode3_multiply_fwd(
const tensor4< J1, J2, J3, J4, T >& t4_in_,
const matrix< I3, J3, T >& in_slice_,
tensor4< J1, J2, I3, J4, T >& t4_res_ );
// Mode-4 forward tensor-times-matrix product.
//   t4_in_    : input tensor of size J1 x J2 x J3 x J4 (read-only).
//   in_slice_ : I4 x J4 matrix applied along the fourth dimension (read-only).
//   t4_res_   : output tensor of size J1 x J2 x J3 x I4; only the fourth
//               dimension changes (J4 -> I4).
77 template<
size_t I4,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
78 static void mode4_multiply_fwd(
const tensor4< J1, J2, J3, J4, T >& t4_in_,
const matrix< I4, J4, T >& in_slice_,
tensor4< J1, J2, J3, I4, T >& t4_res_ );
84 #define VMML_TEMPLATE_CLASSNAME t4_ttm
// Chained four-mode product: the input is pushed through the mode-1, mode-2,
// mode-3 and mode-4 helpers in that order, each step feeding the next through an
// intermediate tensor (t4_result_1..3); the last step writes into t4_res_.
// NOTE(review): the function signature and the declarations of U1..U4 and the
// t4_result_* intermediates are not visible in this excerpt — confirm their
// types and dimensions against the complete header.
86 template<
size_t I1,
size_t I2,
size_t I3,
size_t I4,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
// Step 1: first dimension, t4_in_ -> t4_result_1.
100 mode1_multiply_fwd( t4_in_, U1, t4_result_1 );
// Step 2: second dimension, t4_result_1 -> t4_result_2.
101 mode2_multiply_fwd( t4_result_1, U2, t4_result_2 );
// Step 3: third dimension, t4_result_2 -> t4_result_3.
102 mode3_multiply_fwd( t4_result_2, U3, t4_result_3 );
// Step 4: fourth dimension, t4_result_3 -> final result.
103 mode4_multiply_fwd( t4_result_3, U4, t4_res_ );
// Mode-1 forward product, computed one tensor3 at a time.
// For every index l in [0, J4) the tensor3 returned by get_tensor3(l) is taken
// from t4_in_, handed to t3_ttm::multiply_frontal_fwd together with in_slice_
// (I1 x J1), and the resulting I1 x J2 x J3 tensor3 is stored at l in t4_res_.
108 template<
size_t I1,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
110 VMML_TEMPLATE_CLASSNAME::mode1_multiply_fwd(
const tensor4< J1, J2, J3, J4, T >& t4_in_,
const matrix< I1, J1, T >& in_slice_, tensor4< I1, J2, J3, J4, T >& t4_res_ ) {
111 for (
size_t l = 0; l < J4; ++l) {
// Local by-value copy of the l-th input tensor3.
112 tensor3< J1, J2, J3, T > temp_input = t4_in_.get_tensor3(l);
// Output buffer seeded from the current contents of t4_res_ at l.
// NOTE(review): whether multiply_frontal_fwd overwrites or accumulates into
// this buffer is not visible here — confirm in t3_ttm.
113 tensor3< I1, J2, J3, T > temp_output = t4_res_.get_tensor3(l);
114 t3_ttm::multiply_frontal_fwd(temp_input, in_slice_, temp_output);
// Write the computed tensor3 back into the result at index l.
115 t4_res_.set_tensor3(l,temp_output);
// Mode-2 forward product, computed one tensor3 at a time.
// For every index l in [0, J4) the tensor3 returned by get_tensor3(l) is taken
// from t4_in_, handed to t3_ttm::multiply_horizontal_fwd together with in_slice_
// (I2 x J2), and the resulting J1 x I2 x J3 tensor3 is stored at l in t4_res_.
119 template<
size_t I2,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
121 VMML_TEMPLATE_CLASSNAME::mode2_multiply_fwd(
const tensor4< J1, J2, J3, J4, T >& t4_in_,
const matrix< I2, J2, T >& in_slice_, tensor4< J1, I2, J3, J4, T >& t4_res_ ) {
122 for (
size_t l = 0; l < J4; ++l) {
// Local by-value copy of the l-th input tensor3.
123 tensor3< J1, J2, J3, T > temp_input = t4_in_.get_tensor3(l);
// Output buffer seeded from the current contents of t4_res_ at l.
// NOTE(review): whether multiply_horizontal_fwd overwrites or accumulates into
// this buffer is not visible here — confirm in t3_ttm.
124 tensor3< J1, I2, J3, T > temp_output = t4_res_.get_tensor3(l);
125 t3_ttm::multiply_horizontal_fwd(temp_input, in_slice_, temp_output);
// Write the computed tensor3 back into the result at index l.
126 t4_res_.set_tensor3(l,temp_output);
// Mode-3 forward product, computed one tensor3 at a time.
// For every index l in [0, J4) the tensor3 returned by get_tensor3(l) is taken
// from t4_in_, handed to t3_ttm::multiply_lateral_fwd together with in_slice_
// (I3 x J3), and the resulting J1 x J2 x I3 tensor3 is stored at l in t4_res_.
130 template<
size_t I3,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
132 VMML_TEMPLATE_CLASSNAME::mode3_multiply_fwd(
const tensor4< J1, J2, J3, J4, T >& t4_in_,
const matrix< I3, J3, T >& in_slice_, tensor4< J1, J2, I3, J4, T >& t4_res_ ) {
133 for (
size_t l = 0; l < J4; ++l) {
// Local by-value copy of the l-th input tensor3.
134 tensor3< J1, J2, J3, T > temp_input = t4_in_.get_tensor3(l);
// Output buffer seeded from the current contents of t4_res_ at l.
// NOTE(review): whether multiply_lateral_fwd overwrites or accumulates into
// this buffer is not visible here — confirm in t3_ttm.
135 tensor3< J1, J2, I3, T > temp_output = t4_res_.get_tensor3(l);
136 t3_ttm::multiply_lateral_fwd(temp_input, in_slice_, temp_output);
// Write the computed tensor3 back into the result at index l.
137 t4_res_.set_tensor3(l,temp_output);
// Mode-4 forward product, computed element by element with explicit loops
// rather than by delegating to t3_ttm as the mode-1..3 helpers do.
// For each (i, j, k) in J1 x J2 x J3 and each output index newL in [0, I4),
// the products t4_in_.at(i,j,k,l) * in_slice_.at(newL,l) are accumulated over
// l in [0, J4) and the total is written to t4_res_.at(i,j,k,newL).
// NOTE(review): the declaration and initialisation of `sum` are not visible in
// this excerpt — presumably it is a T reset to zero for every newL; confirm.
141 template<
size_t I4,
size_t J1,
size_t J2,
size_t J3,
size_t J4,
typename T >
143 VMML_TEMPLATE_CLASSNAME::mode4_multiply_fwd(
const tensor4< J1, J2, J3, J4, T >& t4_in_,
const matrix< I4, J4, T >& in_slice_, tensor4< J1, J2, J3, I4, T >& t4_res_ ) {
// Outer three loops walk the dimensions that this product leaves unchanged.
145 for (
size_t i = 0; i < J1; ++i) {
146 for (
size_t j = 0; j < J2; ++j) {
147 for (
size_t k = 0; k < J3; ++k) {
// newL indexes the fourth dimension of the result (size I4).
148 for (
size_t newL = 0; newL < I4; ++newL) {
// Contraction over the fourth dimension of the input (size J4).
150 for (
size_t l = 0; l < J4; ++l) {
151 sum += t4_in_.at(i,j,k,l)*in_slice_.at(newL,l);
// Store the accumulated value for this output element.
153 t4_res_.at(i,j,k,newL) = sum;
160 #undef VMML_TEMPLATE_CLASSNAME