44 #ifndef __VMML__T3_HOSVD__HPP__
45 #define __VMML__T3_HOSVD__HPP__
47 #include <vmmlib/tensor3.hpp>
48 #include <vmmlib/lapack_svd.hpp>
49 #include <vmmlib/lapack_sym_eigs.hpp>
50 #include <vmmlib/blas_dgemm.hpp>
51 #include <vmmlib/blas_daxpy.hpp>
// t3_hosvd: static helpers that compute one factor matrix per mode of a
// third-order tensor (t3_type, extents I1 x I2 x I3, element type T which
// defaults to float), either by SVD or by eigen-decomposition (hosvd_method).
// R1, R2, R3 are presumably the per-mode target ranks (the u*_type typedefs
// are not visible in this excerpt — confirm).
// NOTE(review): the class head and most member declarations are missing from
// this excerpt; only the pieces below are visible.
62 template<
size_t R1,
size_t R2,
size_t R3,
size_t I1,
size_t I2,
size_t I3,
typename T =
float >
// apply_modeK: compute the mode-K factor matrix of data_ into uK_ using the
// selected method_; the default method is eigs_e.
91 static void apply_mode1(
const t3_type& data_,
u1_type& u1_, hosvd_method method_ = eigs_e );
92 static void apply_mode2(
const t3_type& data_,
u2_type& u2_, hosvd_method method_ = eigs_e );
93 static void apply_mode3(
const t3_type& data_,
u3_type& u3_, hosvd_method method_ = eigs_e );
// Template header with the same <M, N, R> parameter list as the out-of-class
// get_svd_u_red definition (the declaration itself is not visible here).
102 template<
size_t M,
size_t N,
size_t R >
// Template header with the same <N, R> parameter list as the out-of-class
// get_eigs_u_red definition (the declaration itself is not visible here).
110 template<
size_t N,
size_t R >
// Shorthand macros for writing the out-of-class member definitions of
// t3_hosvd: the full template parameter list and the specialized class name.
// They are #undef'd again later in this file.
119 #define VMML_TEMPLATE_STRING template< size_t R1, size_t R2, size_t R3, size_t I1, size_t I2, size_t I3, typename T >
120 #define VMML_TEMPLATE_CLASSNAME t3_hosvd< R1, R2, R3, I1, I2, I3, T >
// hosvd: fills all three factor matrices (u1_, u2_, u3_) using the SVD path,
// one mode at a time via svd_mode1, svd_mode2 and svd_mode3.
// data_ is taken by const reference and is only read.
126 VMML_TEMPLATE_CLASSNAME::hosvd(
const t3_type& data_, u1_type& u1_, u2_type& u2_, u3_type& u3_ )
128 svd_mode1( data_, u1_ );
129 svd_mode2( data_, u2_ );
130 svd_mode3( data_, u3_ );
// hoeigs: fills all three factor matrices (u1_, u2_, u3_) using the
// eigen-decomposition path, one mode at a time via eigs_mode1, eigs_mode2
// and eigs_mode3. data_ is taken by const reference and is only read.
135 VMML_TEMPLATE_CLASSNAME::hoeigs(
const t3_type& data_, u1_type& u1_, u2_type& u2_, u3_type& u3_ )
137 eigs_mode1( data_, u1_ );
138 eigs_mode2( data_, u2_ );
139 eigs_mode3( data_, u3_ );
// apply_all: fills all three factor matrices (u1_, u2_, u3_) by calling
// apply_mode1, apply_mode2 and apply_mode3 with the same method_.
144 VMML_TEMPLATE_CLASSNAME::apply_all(
const t3_type& data_, u1_type& u1_, u2_type& u2_, u3_type& u3_, hosvd_method method_ )
146 apply_mode1( data_, u1_, method_ );
147 apply_mode2( data_, u2_, method_ );
148 apply_mode3( data_, u3_, method_ );
// apply_mode1: computes the mode-1 factor matrix u1_ of data_ with the
// method chosen by method_.
// NOTE(review): the code that dispatches on method_ (presumably a switch)
// is not visible in this excerpt. The three calls below appear to be an
// eigs branch, an SVD branch, and a fallback that uses eigs_mode1 — confirm
// against the full source.
153 VMML_TEMPLATE_CLASSNAME::apply_mode1(
const t3_type& data_, u1_type& u1_, hosvd_method method_ )
158 eigs_mode1( data_, u1_ );
161 svd_mode1( data_, u1_ );
164 eigs_mode1( data_, u1_ );
// apply_mode2: computes the mode-2 factor matrix u2_ of data_ with the
// method chosen by method_.
// NOTE(review): the code that dispatches on method_ is not visible in this
// excerpt. The calls below appear to be an eigs branch, an SVD branch, and
// a fallback to eigs_mode2 — confirm against the full source.
171 VMML_TEMPLATE_CLASSNAME::apply_mode2(
const t3_type& data_, u2_type& u2_, hosvd_method method_ )
176 eigs_mode2( data_, u2_ );
179 svd_mode2( data_, u2_ );
182 eigs_mode2( data_, u2_ );
// apply_mode3: computes the mode-3 factor matrix u3_ of data_ with the
// method chosen by method_.
// NOTE(review): the code that dispatches on method_ is not visible in this
// excerpt. The calls below appear to be an eigs branch, an SVD branch, and
// a fallback to eigs_mode3 — confirm against the full source.
190 VMML_TEMPLATE_CLASSNAME::apply_mode3(
const t3_type& data_, u3_type& u3_, hosvd_method method_ )
195 eigs_mode3( data_, u3_ );
198 svd_mode3( data_, u3_ );
201 eigs_mode3( data_, u3_ );
// svd_mode1: unfolds data_ with lateral_unfolding_bwd into a freshly
// new'd u1_unfolded_type, then reduces that matrix to u1_ via get_svd_u_red.
// NOTE(review): `u` is allocated with new and no matching delete is visible
// in this excerpt — confirm it is released (a std::unique_ptr would make
// this automatic and exception-safe).
209 VMML_TEMPLATE_CLASSNAME::svd_mode1(
const t3_type& data_, u1_type& u1_ )
211 u1_unfolded_type* u =
new u1_unfolded_type;
212 data_.lateral_unfolding_bwd( *u );
214 get_svd_u_red( *u, u1_ );
// svd_mode2: unfolds data_ with frontal_unfolding_bwd into a freshly new'd
// u2_unfolded_type, then reduces that matrix to u2_ via get_svd_u_red.
// NOTE(review): `u` is allocated with new and no matching delete is visible
// in this excerpt — confirm it is released.
221 VMML_TEMPLATE_CLASSNAME::svd_mode2(
const t3_type& data_, u2_type& u2_ )
223 u2_unfolded_type* u =
new u2_unfolded_type;
224 data_.frontal_unfolding_bwd( *u );
226 get_svd_u_red( *u, u2_ );
// svd_mode3: unfolds data_ with horizontal_unfolding_bwd into a freshly
// new'd u3_unfolded_type, then reduces that matrix to u3_ via get_svd_u_red.
// NOTE(review): `u` is allocated with new and no matching delete is visible
// in this excerpt — confirm it is released.
233 VMML_TEMPLATE_CLASSNAME::svd_mode3(
const t3_type& data_, u3_type& u3_ )
235 u3_unfolded_type* u =
new u3_unfolded_type;
236 data_.horizontal_unfolding_bwd( *u );
238 get_svd_u_red( *u, u3_ );
// eigs_mode1: computes u1_ from an eigen-decomposition of a product formed
// from the lateral unfolding of data_:
//   1. lateral_unfolding_bwd(data_) -> *m_lateral
//   2. *cov <- blas_dgemm<I1, I2*I3, I1, T>::compute or
//             blas_daxpy<I1, T>::compute_mmm applied to *m_lateral
//      (presumably the I1 x I1 product A * A^T — confirm in blas_dgemm)
//   3. get_eigs_u_red(*cov, u1_)
// NOTE(review): blas_cov is declared twice in the same scope (dgemm and
// daxpy variants). The preprocessor/branch lines that select exactly one of
// them are not visible in this excerpt — confirm only one is compiled.
// NOTE(review): m_lateral, cov and blas_cov are allocated with new; no
// delete is visible here — confirm they are released.
247 VMML_TEMPLATE_CLASSNAME::eigs_mode1(
const t3_type& data_, u1_type& u1_ )
250 u1_unfolded_type* m_lateral =
new u1_unfolded_type;
251 data_.lateral_unfolding_bwd( *m_lateral );
254 u1_cov_type* cov =
new u1_cov_type;
// Variant A: covariance via blas_dgemm.
256 blas_dgemm< I1, I2*I3, I1, T>* blas_cov =
new blas_dgemm< I1, I2*I3, I1, T>;
257 blas_cov->compute( *m_lateral, *cov );
// Variant B: covariance via blas_daxpy::compute_mmm.
259 blas_daxpy< I1, T>* blas_cov =
new blas_daxpy< I1, T>;
260 blas_cov->compute_mmm( *m_lateral, *cov );
266 get_eigs_u_red( *cov, u1_ );
// eigs_mode2: computes u2_ from an eigen-decomposition of a product formed
// from the frontal unfolding of data_:
//   1. frontal_unfolding_bwd(data_) -> *m_frontal
//   2. *cov <- blas_dgemm<I2, I1*I3, I2, T>::compute or
//             blas_daxpy<I2, T>::compute_mmm applied to *m_frontal
//   3. get_eigs_u_red(*cov, u2_)
// NOTE(review): blas_cov is declared twice (dgemm and daxpy variants); the
// selector between them is not visible in this excerpt — confirm.
// NOTE(review): m_frontal, cov and blas_cov are allocated with new; no
// delete is visible here — confirm they are released.
273 VMML_TEMPLATE_CLASSNAME::eigs_mode2(
const t3_type& data_, u2_type& u2_ )
276 u2_unfolded_type* m_frontal =
new u2_unfolded_type;
277 data_.frontal_unfolding_bwd( *m_frontal );
280 u2_cov_type* cov =
new u2_cov_type;
// Variant A: covariance via blas_dgemm.
282 blas_dgemm< I2, I1*I3, I2, T>* blas_cov =
new blas_dgemm< I2, I1*I3, I2, T>;
283 blas_cov->compute( *m_frontal, *cov );
// Variant B: covariance via blas_daxpy::compute_mmm.
285 blas_daxpy< I2, T>* blas_cov =
new blas_daxpy< I2, T>;
286 blas_cov->compute_mmm( *m_frontal, *cov );
293 get_eigs_u_red( *cov, u2_ );
// eigs_mode3: computes u3_ from an eigen-decomposition of a product formed
// from the horizontal unfolding of data_:
//   1. horizontal_unfolding_bwd(data_) -> *m_horizontal
//   2. *cov <- blas_dgemm<I3, I1*I2, I3, T>::compute or
//             blas_daxpy<I3, T>::compute_mmm applied to *m_horizontal
//   3. get_eigs_u_red(*cov, u3_)
// NOTE(review): blas_cov is declared twice (dgemm and daxpy variants); the
// selector between them is not visible in this excerpt — confirm.
// NOTE(review): m_horizontal, cov and blas_cov are allocated with new; no
// delete is visible here — confirm they are released.
300 VMML_TEMPLATE_CLASSNAME::eigs_mode3(
const t3_type& data_, u3_type& u3_)
303 u3_unfolded_type* m_horizontal =
new u3_unfolded_type;
304 data_.horizontal_unfolding_bwd( *m_horizontal );
307 u3_cov_type* cov =
new u3_cov_type;
// Variant A: covariance via blas_dgemm.
309 blas_dgemm< I3, I1*I2, I3, T>* blas_cov =
new blas_dgemm< I3, I1*I2, I3, T>;
310 blas_cov->compute( *m_horizontal, *cov );
// Variant B: covariance via blas_daxpy::compute_mmm.
312 blas_daxpy< I3, T>* blas_cov =
new blas_daxpy< I3, T>;
313 blas_cov->compute_mmm( *m_horizontal, *cov );
320 get_eigs_u_red( *cov, u3_ );
// get_eigs_u_red: reduces an N x N matrix data_ to an N x R factor u_.
// Steps:
//   1. Copy data_ (element type T) into a T_svd matrix via cast_from.
//   2. Call lapack_sym_eigs<N, T_svd>::compute_x, which writes R eigenvalues
//      (eigval_type) and an N x R eigenvector matrix (eigvec_type).
//   3. Inside the success branch, copy the eigenvectors into u_.
// Which R eigenpairs compute_x selects (e.g. the largest) is not visible
// here — confirm in lapack_sym_eigs.
// NOTE(review): eigxvalues, eigxvectors and data are allocated with new;
// no delete is visible in this excerpt — confirm they are released.
330 template<
size_t N,
size_t R >
332 VMML_TEMPLATE_CLASSNAME::get_eigs_u_red(
const matrix< N, N, T >& data_, matrix< N, R, T >& u_ )
334 typedef matrix< N, N, T_svd > cov_matrix_type;
335 typedef vector< R, T_svd > eigval_type;
336 typedef matrix< N, R, T_svd > eigvec_type;
340 eigval_type* eigxvalues =
new eigval_type;
341 eigvec_type* eigxvectors =
new eigvec_type;
343 lapack_sym_eigs< N, T_svd > eigs;
344 cov_matrix_type* data =
new cov_matrix_type;
345 data->cast_from( data_ );
346 if( eigs.compute_x( *data, *eigxvectors, *eigxvalues) ) {
// When sizeof(T) != 4, u_ is filled through cast_from. The branch taken for
// 4-byte T is not visible in this excerpt.
// NOTE(review): sizeof(T) != 4 is a fragile stand-in for a type test (it
// cannot tell float from a 4-byte integer); std::is_same<T, T_svd> would
// state the intent explicitly.
355 if (
sizeof( T ) != 4 ){
356 u_.cast_from( *eigxvectors );
// get_svd_u_red: reduces an M x N matrix data_ to an M x R factor u_.
// Steps:
//   1. Copy data_ (element type T) into a T_svd matrix *u_double.
//   2. Call lapack_svd<M, N, T_svd>::compute_and_overwrite_input, which
//      overwrites *u_double (presumably with the left singular vectors —
//      confirm in lapack_svd) and fills the N-element *lambdas.
//   3. Convert back to element type T in *u_out.
//   4. Extract u_ from *u_out with get_sub_matrix (presumably the leading R
//      columns — confirm get_sub_matrix semantics).
// Whether step 4 is inside the success branch of step 2 cannot be seen in
// this excerpt.
// NOTE(review): u_double, u_out, lambdas and svd are allocated with new; no
// delete is visible here — confirm they are released.
372 template<
size_t M,
size_t N,
size_t R >
374 VMML_TEMPLATE_CLASSNAME::get_svd_u_red(
const matrix< M, N, T >& data_, matrix< M, R, T >& u_ )
376 typedef matrix< M, N, T_svd > svd_m_type;
378 typedef matrix< M, N, T > m_type;
379 typedef vector< N, T_svd > lambdas_type;
381 svd_m_type* u_double =
new svd_m_type;
382 u_double->cast_from( data_ );
385 m_type* u_out =
new m_type;
387 lambdas_type* lambdas =
new lambdas_type;
388 lapack_svd< M, N, T_svd >* svd =
new lapack_svd< M, N, T_svd >();
389 if( svd->compute_and_overwrite_input( *u_double, *lambdas )) {
// When sizeof(T) != 4, *u_out is filled through cast_from. The branch taken
// for 4-byte T is not visible in this excerpt.
// NOTE(review): sizeof(T) != 4 is a fragile type test; std::is_same would be
// more explicit.
396 if (
sizeof( T ) != 4 ){
397 u_out->cast_from( *u_double );
402 u_out->get_sub_matrix( u_ );
// Remove the helper macros so they do not leak into code that includes
// this header.
415 #undef VMML_TEMPLATE_STRING
416 #undef VMML_TEMPLATE_CLASSNAME