47 #ifndef __VMML__T3_HOPM__HPP__
48 #define __VMML__T3_HOPM__HPP__
50 #include <vmmlib/t3_hosvd.hpp>
51 #include <vmmlib/matrix_pseudoinverse.hpp>
52 #include <vmmlib/blas_dgemm.hpp>
53 #include <vmmlib/blas_dot.hpp>
54 #include <vmmlib/validator.hpp>
55 #include <vmmlib/t3_ttv.hpp>
56 #include <vmmlib/tensor_stats.hpp>
// Class template header: rank R, the three mode sizes I1, I2, I3, and the element
// type T (defaults to float).
// NOTE(review): the class-head line and most member declarations are not visible in
// this excerpt; only the fragments below can be described.
60 template<
size_t R,
size_t I1,
size_t I2,
size_t I3,
typename T =
float >
// Mutable / read-only iterator aliases taken from lambda_type (declared outside this excerpt).
81 typedef typename lambda_type::iterator lvalue_iterator;
82 typedef typename lambda_type::const_iterator lvalue_const_iterator;
// A value of type T paired with a size_t; sort_dec() builds these from a lambda value
// and a running counter.
83 typedef std::pair< T, size_t > lambda_pair_type;
// Member-template header with one type parameter T_init. The declaration it introduces
// is not visible here; the out-of-class als() definition uses the same parameter name.
87 template<
typename T_init >
// Member-template header with three size parameters J, K, L. The declaration it introduces
// is not visible here; the out-of-class optimize() definition uses the same parameters.
132 template<
size_t J,
size_t K,
size_t L >
// Ordering predicate on lambda_pair_type: true when the absolute value of a.first is
// strictly greater than that of b.first. The size_t member (.second) is not consulted.
// NOTE(review): the enclosing struct and the closing brace are outside this excerpt.
145 inline bool operator()(
const lambda_pair_type& a,
const lambda_pair_type & b) {
146 return fabs(a.first) > fabs(b.first);
// Shorthand for the template parameter list and the fully specialised class name, used
// by the out-of-class member definitions that follow. Both macros are #undef'd at the
// end of the file.
154 #define VMML_TEMPLATE_STRING template< size_t R, size_t I1, size_t I2, size_t I3, typename T >
155 #define VMML_TEMPLATE_CLASSNAME t3_hopm< R, I1, I2, I3, T >
// als(): repeatedly updates the factor matrices u1_, u2_, u3_ and the weights lambdas_
// against data_ by calling optimize_mode1/2/3 in turn, for as long as the loop condition
// below holds.
//
//   data_            input tensor (read-only)
//   u1_, u2_, u3_    factor matrices, passed on by non-const reference to the mode updates
//   lambdas_         per-component weights, passed on by non-const reference
//   max_iterations_  upper bound on the loop counter i
//   tolerance_       the fit is only evaluated when tolerance_ > 0; the loop then stops once
//                    fitchange < tolerance_. A negative value disables the fit test.
//
// NOTE(review): this excerpt is missing lines (return type, the declaration of `init`,
// `i`, `fit`, `max_f_norm`, `val2`, `result`, `myTimer`, and closing braces). `init` is
// presumably a T_init parameter declared on a line not shown -- confirm. The comments
// below describe only the statements that are visible.
158 template<
typename T_init>
160 VMML_TEMPLATE_CLASSNAME::als(
const t3_type& data_,
161 u1_type& u1_, u2_type& u2_, u3_type& u3_,
162 lambda_type& lambdas_,
164 const size_t max_iterations_,
const float tolerance_) {
// Heap-allocated scratch tensors used by the reconstruction / residual step further down;
// both are deleted at the end of the visible body.
167 t3_type* approximated_data =
new t3_type;
168 t3_type* residual_data =
new t3_type;
169 residual_data->zero();
173 T fitchange, norm2, innerprod, fitold, normresidual;
// The Frobenius norm of the input is only needed when the fit is tracked.
174 if (tolerance_ > 0) {
175 max_f_norm = data_.frobenius_norm();
180 norm2 = innerprod = 0;
// The initialiser receives data_, u2_ and u3_ only; u1_ is not passed to it.
187 init(data_, u2_, u3_);
189 assert(validator::is_valid(u2_) && validator::is_valid(u3_));
190 assert(validator::is_valid(lambdas_));
// Banner on stdout. NOTE(review): `convert` is presumably filled on a line not shown here.
193 std::ostringstream convert;
195 std::cout << convert.str() +
"-R rank CP ALS: HOPM (for tensor3) " << std::endl;
// Continue while the iteration budget is not exhausted AND (the fit test is disabled by a
// negative tolerance OR the last change in fit is still at least tolerance_).
199 while (i < max_iterations_ && (tolerance_ < 0 || fitchange >= tolerance_))
// One sweep: update mode 1, then mode 2, then mode 3.
206 optimize_mode1(data_, u1_, u2_, u3_, lambdas_);
211 optimize_mode2(data_, u1_, u2_, u3_, lambdas_);
216 optimize_mode3(data_, u1_, u2_, u3_, lambdas_);
// Fit evaluation without forming the full approximation.
221 if (tolerance_ > 0) {
222 norm2 = norm_ktensor(u1_, u2_, u3_, lambdas_);
// For each component j: multiply data_ along the first mode with column j of u1_ (I2 x I3
// result), multiply that by column j of u3_, take the dot product with column j of u2_,
// weight by lambdas_[j], and accumulate into val2.
225 for(
size_t j = 0; j < lambdas_.size(); ++j)
227 matrix<I2, I3, T> res1;
228 t3_ttv::multiply_first_mode(data_, u1_.get_column(j), res1);
229 vector<I2, T> res2 = res1 * u3_.get_column(j);
230 val2 += lambdas_.at(j) * (res2.dot(u2_.get_column(j)));
232 innerprod = 2 * val2;
// NOTE(review): the square-root argument is a difference of accumulated floating-point
// terms and could come out slightly negative -- confirm whether a clamp is needed.
234 normresidual = sqrt(max_f_norm * max_f_norm + norm2 * norm2 - innerprod);
// NOTE(review): divides by max_f_norm; an all-zero input would divide by zero -- confirm.
235 fit = 1 - (normresidual / max_f_norm);
236 fitchange = fabs(fitold - fit);
// Progress line on stdout, with a marker when the fit decreased.
239 std::cout <<
"iteration '" << i
241 <<
", fitdelta: " << fit - fitold
242 <<
", normresidual: " << normresidual;
243 if (fit - fitold < 0) std::cout <<
" *fit is worsening*";
244 std::cout << std::endl;
// Explicit residual: rebuild the approximation from the factors, subtract it from the
// input, and take the Frobenius norm of the difference.
252 reconstruct(*approximated_data, u1_, u2_, u3_, lambdas_);
254 *residual_data = data_ - *approximated_data;
255 normresidual = residual_data->frobenius_norm();
// Elapsed seconds go to stderr. `myTimer` is declared outside this excerpt.
256 std::cerr << myTimer.get_seconds() <<
" ";
// Record the iteration count in `result` (declared outside this excerpt).
261 result.set_n_iterations(i);
// Reorder the components via sort_dec().
264 sort_dec(u1_, u2_, u3_, lambdas_);
// Release the scratch tensors allocated at the top.
266 delete residual_data;
267 delete approximated_data;
// optimize_mode1(): updates u1_ and lambdas_ from data_; u2_ and u3_ are read-only.
// Unfolds data_ with frontal_unfolding_fwd into a heap-allocated u1_unfolded_type and
// passes it to optimize() with u1_ as the matrix to be written.
// NOTE(review): no release of `unfolding` is visible in this excerpt -- confirm it is
// deleted on a line not shown.
273 VMML_TEMPLATE_CLASSNAME::optimize_mode1(
const t3_type& data_, u1_type& u1_,
const u2_type& u2_,
const u3_type& u3_, lambda_type& lambdas_) {
274 u1_unfolded_type* unfolding =
new u1_unfolded_type;
276 data_.frontal_unfolding_fwd(*unfolding);
// Debug-only sanity check on the two fixed factors.
278 assert(validator::is_valid(u2_) && validator::is_valid(u3_));
280 optimize(*unfolding, u1_, u2_, u3_, lambdas_);
// optimize_mode2(): updates u2_ and lambdas_ from data_; u1_ and u3_ are read-only.
// Unfolds data_ with frontal_unfolding_bwd into a heap-allocated u2_unfolded_type and
// passes it to optimize() with u2_ as the matrix to be written and (u1_, u3_) as the
// fixed pair.
// NOTE(review): no release of `unfolding` is visible in this excerpt -- confirm it is
// deleted on a line not shown.
287 VMML_TEMPLATE_CLASSNAME::optimize_mode2(
const t3_type& data_,
const u1_type& u1_, u2_type& u2_,
const u3_type& u3_, lambda_type& lambdas_) {
288 u2_unfolded_type* unfolding =
new u2_unfolded_type;
289 data_.frontal_unfolding_bwd(*unfolding);
// Debug-only sanity check on the two fixed factors.
292 assert(validator::is_valid(u1_) && validator::is_valid(u3_));
294 optimize(*unfolding, u2_, u1_, u3_, lambdas_);
// optimize_mode3(): updates u3_ and lambdas_ from data_; u1_ and u2_ are read-only.
// Unfolds data_ with lateral_unfolding_fwd into a heap-allocated u3_unfolded_type and
// passes it to optimize() with u3_ as the matrix to be written and (u1_, u2_) as the
// fixed pair.
// NOTE(review): no release of `unfolding` is visible in this excerpt -- confirm it is
// deleted on a line not shown.
301 VMML_TEMPLATE_CLASSNAME::optimize_mode3(
const t3_type& data_,
const u1_type& u1_,
const u2_type& u2_, u3_type& u3_, lambda_type& lambdas_) {
302 u3_unfolded_type* unfolding =
new u3_unfolded_type;
304 data_.lateral_unfolding_fwd(*unfolding);
// Debug-only sanity check on the two fixed factors.
306 assert(validator::is_valid(u1_) && validator::is_valid(u2_));
308 optimize(*unfolding, u3_, u1_, u2_, lambdas_);
// optimize(): mode-independent update step shared by optimize_mode1/2/3.
//
//   unfolding_  J x (K*L) unfolded data matrix (read-only)
//   uj_         J x R factor matrix that is overwritten
//   uk_, ul_    K x R and L x R factor matrices that are only read
//   lambdas_    length-R vector that is overwritten
//
// NOTE(review): this excerpt is missing lines (return type, closing of the parameter list,
// at least one statement between the visible ones, and any delete statements). The exact
// semantics of the blas_dgemm / matrix helpers are defined elsewhere; descriptions below
// that go beyond argument shapes are marked as assumptions.
314 template<
size_t J,
size_t K,
size_t L >
316 VMML_TEMPLATE_CLASSNAME::optimize(
317 const matrix< J, K*L, T >& unfolding_,
318 matrix< J, R, T >& uj_,
319 const matrix< K, R, T >& uk_,
const matrix< L, R, T >& ul_,
320 vector< R, T>& lambdas_
// Khatri-Rao product of ul_ with uk_, stored in a heap-allocated (K*L) x R matrix.
323 typedef matrix< K*L, R, T > krp_matrix_type;
324 krp_matrix_type* krp_prod =
new krp_matrix_type;
325 assert(validator::is_valid(uk_) && validator::is_valid(ul_));
327 ul_.khatri_rao_product(uk_, *krp_prod);
// u_new (J x R) is computed from the J x (K*L) unfolding and the (K*L) x R product --
// presumably their matrix product; confirm against blas_dgemm::compute.
329 matrix< J, R, T >* u_new =
new matrix< J, R, T >;
331 blas_dgemm< J, K*L, R, T> blas_dgemm1;
332 blas_dgemm1.compute(unfolding_, *krp_prod, *u_new);
// Two R x R matrices derived from uk_ and ul_ via compute_t -- presumably uk^T*uk and
// ul^T*ul; confirm against blas_dgemm::compute_t.
335 m_r2_type* uk_r =
new m_r2_type;
336 m_r2_type* ul_r =
new m_r2_type;
338 blas_dgemm< R, K, R, T> blas_dgemm2;
339 blas_dgemm2.compute_t(uk_, *uk_r);
340 assert(validator::is_valid(*uk_r));
342 blas_dgemm< R, L, R, T> blas_dgemm3;
343 blas_dgemm3.compute_t(ul_, *ul_r);
344 assert(validator::is_valid(*ul_r));
// Combine the two R x R matrices in place in uk_r (piecewise multiply).
346 uk_r->multiply_piecewise(*ul_r);
347 assert(validator::is_valid(*uk_r));
// Pseudoinverse of the combined R x R matrix, written to pinv_t.
349 m_r2_type* pinv_t =
new m_r2_type;
350 compute_pseudoinverse< m_r2_type > compute_pinv;
352 compute_pinv(*uk_r, *pinv_t);
// uj_ is overwritten from u_new and pinv_t -- presumably u_new times the transpose of
// pinv_t; confirm against blas_dgemm::compute_bt.
354 blas_dgemm< J, R, R, T> blas_dgemm4;
355 blas_dgemm4.compute_bt(*u_new, *pinv_t, uj_);
356 assert(validator::is_valid(uj_));
// Square u_new piecewise, sum each column into lambdas_, then take the element-wise
// square root.
// NOTE(review): lines between the assert above and this point are not visible; u_new may
// have been reassigned there -- confirm which matrix the column norms refer to.
359 u_new->multiply_piecewise(*u_new);
360 u_new->columnwise_sum(lambdas_);
362 assert(validator::is_valid(lambdas_));
364 lambdas_.sqrt_elementwise();
// NOTE(review): as visible here, tmp is default-constructed and then inverted; it is
// presumably assigned from lambdas_ on a line not shown -- confirm.
365 lambda_type* tmp =
new lambda_type;
367 tmp->reciprocal_safe();
369 assert(validator::is_valid(*tmp));
// Put tmp on the diagonal of an R x R matrix and apply it to a copy of uj_, writing the
// result back into uj_.
371 m_r2_type* diag_lambdas =
new m_r2_type;
372 diag_lambdas->diag(*tmp);
374 matrix< J, R, T >* tmp_uj =
new matrix< J, R, T > (uj_);
375 blas_dgemm4.compute(*tmp_uj, *diag_lambdas, uj_);
377 assert(validator::is_valid(uj_));
// reconstruct(): fills data_ from the factors u1_, u2_, u3_ and the weights lambdas_.
// Each factor is transposed into a heap-allocated temporary, and the transposed factors,
// lambdas_ and an R x (I2*I3) scratch matrix are handed to data_.reconstruct_CP().
// NOTE(review): no release of u1_t, u2_t, u3_t or temp is visible in this excerpt --
// confirm they are deleted on lines not shown.
392 VMML_TEMPLATE_CLASSNAME::reconstruct(t3_type& data_,
const u1_type& u1_,
const u2_type& u2_,
const u3_type& u3_,
const lambda_type& lambdas_) {
393 u1_inv_type* u1_t =
new u1_inv_type;
394 u2_inv_type* u2_t =
new u2_inv_type;
395 u3_inv_type* u3_t =
new u3_inv_type;
396 typedef matrix< R, I2 * I3, T > m_temp_type;
397 m_temp_type* temp =
new m_temp_type;
399 u1_.transpose_to(*u1_t);
400 u2_.transpose_to(*u2_t);
401 u3_.transpose_to(*u3_t);
403 data_.reconstruct_CP(lambdas_, *u1_t, *u2_t, *u3_t, *temp);
// norm_ktensor(): derives a scalar from the factors and weights without building the
// full tensor. als() squares the returned value when it evaluates the fit.
// Steps visible here:
//   1. coeff2_matrix (R x R) from compute_vv_outer(lambdas_, lambdas_);
//   2. cov_u1/2/3 (R x R each) from compute_t on u1_, u2_, u3_;
//   3. coeff2_matrix is multiplied piecewise by each cov_u*;
//   4. nrm = sum of all elements of coeff2_matrix.
// NOTE(review): the return type and return statement are not visible; given how als()
// uses the result, the function presumably returns the square root of nrm -- confirm.
// Only the delete of coeff2_matrix is visible; confirm the other seven allocations are
// released on lines not shown.
413 VMML_TEMPLATE_CLASSNAME::norm_ktensor(
const u1_type& u1_,
const u2_type& u2_,
const u3_type& u3_,
const lambda_type& lambdas_) {
414 m_r2_type* coeff2_matrix =
new m_r2_type;
415 m_r2_type* cov_u1 =
new m_r2_type;
416 m_r2_type* cov_u2 =
new m_r2_type;
417 m_r2_type* cov_u3 =
new m_r2_type;
// Outer product of the weight vector with itself.
419 blas_dgemm< R, 1, R, T >* blas_l2 =
new blas_dgemm< R, 1, R, T>;
420 blas_l2->compute_vv_outer(lambdas_, lambdas_, *coeff2_matrix);
// One R x R matrix per factor.
423 blas_dgemm< R, I1, R, T >* blas_u1cov =
new blas_dgemm< R, I1, R, T>;
424 blas_u1cov->compute_t(u1_, *cov_u1);
427 blas_dgemm< R, I2, R, T >* blas_u2cov =
new blas_dgemm< R, I2, R, T>;
428 blas_u2cov->compute_t(u2_, *cov_u2);
431 blas_dgemm< R, I3, R, T >* blas_u3cov =
new blas_dgemm< R, I3, R, T>;
432 blas_u3cov->compute_t(u3_, *cov_u3);
// Fold the three per-factor matrices into the weight matrix, in place.
435 coeff2_matrix->multiply_piecewise(*cov_u1);
436 coeff2_matrix->multiply_piecewise(*cov_u2);
437 coeff2_matrix->multiply_piecewise(*cov_u3);
// Accumulated in double regardless of T.
439 double nrm = coeff2_matrix->sum_elements();
441 delete coeff2_matrix;
// sort_dec(): permutes lambdas_ and the columns of u1_, u2_, u3_ consistently.
// Steps visible here:
//   1. copy u1_, u2_, u3_ to heap temporaries so columns can be read while being overwritten;
//   2. build lambda_permut, a vector of (lambda value, counter) pairs, one per entry of lambdas_;
//   3. walk lambda_permut in order, writing each pair's value back into lambdas_ and copying
//      column `pair.second` of every saved factor into column `counter` of the live factor.
// NOTE(review): the call that reorders lambda_permut is only partly visible (its first
// argument, lambda_permut.begin()); it is presumably a std::sort using the magnitude
// comparator declared in the class -- confirm. The declaration of `counter` and the release
// of orig_u1/orig_u2/orig_u3 are not visible. `sorted_lvalues` is not used in the visible lines.
451 VMML_TEMPLATE_CLASSNAME::sort_dec(u1_type& u1_, u2_type& u2_, u3_type& u3_, lambda_type& lambdas_) {
453 u1_type *orig_u1 =
new u1_type(u1_);
454 u2_type *orig_u2 =
new u2_type(u2_);
455 u3_type *orig_u3 =
new u3_type(u3_);
456 lambda_type sorted_lvalues;
459 std::vector< lambda_pair_type > lambda_permut;
461 lvalue_const_iterator it = lambdas_.begin(), it_end = lambdas_.end();
// Tag every lambda with the current counter value.
463 for (; it != it_end; ++it, ++counter) {
464 lambda_permut.push_back(lambda_pair_type(*it, counter));
468 lambda_permut.begin(),
474 typename std::vector< lambda_pair_type >::const_iterator it2 = lambda_permut.begin(), it2_end = lambda_permut.end();
475 lvalues_it = lambdas_.begin();
// counter is the destination slot, it2->second the source column.
476 for (counter = 0; it2 != it2_end; ++it2, ++counter, ++lvalues_it) {
477 *lvalues_it = it2->first;
478 u1_.set_column(counter, orig_u1->get_column(it2->second));
479 u2_.set_column(counter, orig_u2->get_column(it2->second));
480 u3_.set_column(counter, orig_u3->get_column(it2->second));
// Remove the two helper macros defined before the member definitions so they do not
// remain defined for code that includes this header.
492 #undef VMML_TEMPLATE_STRING
493 #undef VMML_TEMPLATE_CLASSNAME