32 #ifndef __VMML__VMMLIB_BLAS_DGEMM__HPP__
33 #define __VMML__VMMLIB_BLAS_DGEMM__HPP__
36 #include <vmmlib/matrix.hpp>
37 #include <vmmlib/tensor3.hpp>
38 #include <vmmlib/exception.hpp>
39 #include <vmmlib/blas_includes.hpp>
40 #include <vmmlib/blas_types.hpp>
// Declaration of cblas_dgemm. Arguments, in order: storage order, transpose
// flags for A and B, the dimensions M, N and K, the scalar alpha, array A with
// its leading dimension lda, array B with ldb, the scalar beta, and array C
// with ldc. All scalars and arrays are double; all sizes are blasint.
// NOTE(review): presumably the standard CBLAS routine computing
// C = alpha * op(A) * op(B) + beta * C — confirm against the linked BLAS.
// A and B are declared here as non-const pointers.
86 void cblas_dgemm(
enum CBLAS_ORDER Order,
enum CBLAS_TRANSPOSE TransA,
enum CBLAS_TRANSPOSE TransB,
87 blasint M, blasint N, blasint K,
88 double alpha,
double *A, blasint lda,
double *B, blasint ldb,
double beta,
double *C, blasint ldc);
// Parameter bundle for a dgemm call, templated on the scalar type float_t.
// Only part of the definition is visible in this excerpt: the two transpose
// flags and the stream-output operator. The operator reads the fields order,
// trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c and ldc.
92 template<
typename float_t >
96 CBLAS_TRANSPOSE trans_a;
97 CBLAS_TRANSPOSE trans_b;
// Stream output for debugging: writes one field per line, each prefixed with
// its 1-based position in the cblas_dgemm argument list
// (Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc).
110 friend std::ostream& operator << ( std::ostream& os,
114 <<
" (1)\torder " << p.order << std::endl
115 <<
" (2)\ttrans_a " << p.trans_a << std::endl
116 <<
" (3)\ttrans_b " << p.trans_b << std::endl
117 <<
" (4)\tm " << p.m << std::endl
// n is the 5th and k the 6th cblas_dgemm argument; the labels were previously
// printed swapped as "(6)" for n and "(5)" for k.
118 <<
" (5)\tn " << p.n << std::endl
119 <<
" (6)\tk " << p.k << std::endl
120 <<
" (7)\talpha " << p.alpha << std::endl
121 <<
" (8)\ta " << p.a << std::endl
122 <<
" (9)\tlda " << p.lda << std::endl
123 <<
" (10)\tb " << p.b << std::endl
124 <<
" (11)\tldb " << p.ldb << std::endl
125 <<
" (12)\tbeta " << p.beta << std::endl
126 <<
" (13)\tc " << p.c << std::endl
127 <<
" (14)\tldc " << p.ldc << std::endl
// Generic dgemm_call template for an arbitrary float_t. The visible part of
// its body raises VMMLIB_ERROR with the message "not implemented for this
// type." at the current source location (VMMLIB_HERE).
// NOTE(review): presumably the float and double variants that follow provide
// the real implementations — confirm.
136 template<
typename float_t >
140 VMMLIB_ERROR(
"not implemented for this type.", VMMLIB_HERE );
// dgemm_call variant taking dgemm_params< float >. Its return type and body
// are not visible in this excerpt.
// NOTE(review): presumably forwards the fields of p to the single-precision
// BLAS gemm routine — confirm.
146 dgemm_call( dgemm_params< float >& p )
// dgemm_call variant taking dgemm_params< double >. Its return type and body
// are not visible in this excerpt.
// NOTE(review): presumably forwards the fields of p to cblas_dgemm — confirm.
170 dgemm_call( dgemm_params< double >& p )
// Template header with three size parameters (M, K, N) and a scalar type.
// NOTE(review): presumably opens the blas_dgemm class template whose members
// are defined out of class further down in this file — confirm; the class
// header itself is not visible in this excerpt.
195 template<
size_t M,
size_t K,
size_t N,
typename float_t >
// Two member-template headers parameterized on I2 and I3. The out-of-class
// compute definitions that use I2/I3 take a tensor3< M, I2, I3, float_t >
// as their first argument.
215 template<
size_t I2,
size_t I3 >
218 template<
size_t I2,
size_t I3 >
// Member of the M/K/N/float_t template whose signature is not visible in this
// excerpt (presumably the blas_dgemm constructor — confirm).
// Sets p.order to CblasColMajor and both transpose flags to CblasNoTrans.
235 template<
size_t M,
size_t K,
size_t N,
typename float_t >
238 p.order = CblasColMajor;
239 p.trans_a = CblasNoTrans;
240 p.trans_b = CblasNoTrans;
// blas_dgemm::compute for a left matrix A_ and a right matrix B_. Any further
// parameters and the return type are not visible in this excerpt.
// Copies both inputs, then runs blas::dgemm_call< float_t > on p.
256 template<
size_t M,
size_t K,
size_t N,
typename float_t >
258 blas_dgemm< M, K, N, float_t >::compute(
259 const matrix_left_t& A_,
260 const matrix_right_t& B_,
// Heap-allocated copies of the const inputs.
// NOTE(review): the matching delete is not visible in this excerpt — confirm
// AA and BB are released on every path.
265 matrix_left_t* AA =
new matrix_left_t( A_ );
266 matrix_right_t* BB =
new matrix_right_t( B_ );
273 blas::dgemm_call< float_t >( p );
// blas_dgemm::compute for a tensor3< M, I2, I3, float_t > left operand and a
// right matrix B_. Any further parameters and the return type are not visible
// in this excerpt.
283 template<
size_t M,
size_t K,
size_t N,
typename float_t >
284 template<
size_t I2,
size_t I3 >
286 blas_dgemm< M, K, N, float_t >::compute(
287 const tensor3< M, I2, I3, float_t >& A_,
288 const matrix_right_t& B_,
// The tensor is copied into a local object; the right matrix is copied onto
// the heap.
// NOTE(review): the delete for BB is not visible in this excerpt — confirm.
293 tensor3< M, I2, I3, float_t > AA( A_ );
294 matrix_right_t* BB =
new matrix_right_t( B_ );
// The left operand handed to BLAS is the array of the tensor copy.
297 p.a = AA.get_array_ptr();
301 blas::dgemm_call< float_t >( p );
// blas_dgemm::compute taking a single input matrix A_ and an output matrix C_.
// Copies A_, flags the B operand as transposed and runs
// blas::dgemm_call< float_t > on p.
// NOTE(review): presumably computes A * A^T; the assignment of p.b is not
// visible in this excerpt — confirm.
311 template<
size_t M,
size_t K,
size_t N,
typename float_t >
313 blas_dgemm< M, K, N, float_t >::compute(
const matrix_left_t& A_, matrix_out_t& C_ )
// Heap-allocated copy of the const input.
// NOTE(review): the matching delete is not visible in this excerpt — confirm.
316 matrix_left_t* AA =
new matrix_left_t( A_ );
319 p.trans_b = CblasTrans;
325 blas::dgemm_call< float_t >( p );
// blas_dgemm::compute taking a single tensor3< M, I2, I3, float_t > input and
// an output matrix C_. Both BLAS operands point at the same copied tensor
// array, and the second one is flagged as transposed.
// NOTE(review): presumably this yields A * A^T — confirm.
334 template<
size_t M,
size_t K,
size_t N,
typename float_t >
335 template<
size_t I2,
size_t I3 >
337 blas_dgemm< M, K, N, float_t >::compute(
const tensor3< M, I2, I3, float_t >& A_, matrix_out_t& C_ )
// Local copy of the const input tensor.
340 tensor3< M, I2, I3, float_t > AA( A_ ) ;
343 p.trans_b = CblasTrans;
// The same array is passed as both the A and the B operand.
344 p.a = AA.get_array_ptr();
345 p.b = AA.get_array_ptr();
349 blas::dgemm_call< float_t >( p );
// blas_dgemm::compute_t taking a single input matrix B_ and an output matrix
// C_. Copies B_, flags the A operand as transposed and runs
// blas::dgemm_call< float_t > on p.
// NOTE(review): presumably computes B^T * B; the assignments of p.a and p.b
// are not visible in this excerpt — confirm.
356 template<
size_t M,
size_t K,
size_t N,
typename float_t >
358 blas_dgemm< M, K, N, float_t >::compute_t(
const matrix_right_t& B_, matrix_out_t& C_ )
// Heap-allocated copy of the const input.
// NOTE(review): the matching delete is not visible in this excerpt — confirm.
361 matrix_right_t* BB =
new matrix_right_t( B_ );
364 p.trans_a = CblasTrans;
370 blas::dgemm_call< float_t >( p );
// blas_dgemm::compute_bt for a left matrix A_ and a right operand Bt_ of type
// matrix_right_t_t. Any further parameters and the return type are not
// visible in this excerpt. Copies both inputs, flags the B operand as
// transposed and runs blas::dgemm_call< float_t > on p.
// NOTE(review): presumably Bt_ holds the right matrix in transposed layout, so
// the call evaluates A * (Bt)^T — confirm.
379 template<
size_t M,
size_t K,
size_t N,
typename float_t >
381 blas_dgemm< M, K, N, float_t >::compute_bt(
382 const matrix_left_t& A_,
383 const matrix_right_t_t& Bt_,
// Heap-allocated copies of the const inputs.
// NOTE(review): the matching delete is not visible in this excerpt — confirm.
387 matrix_left_t* AA =
new matrix_left_t( A_ );
388 matrix_right_t_t* BB =
new matrix_right_t_t( Bt_ );
391 p.trans_b = CblasTrans;
397 blas::dgemm_call< float_t >( p );
// blas_dgemm::compute_t for operands At_ (matrix_left_t_t) and Bt_
// (matrix_right_t_t). Any further parameters and the return type are not
// visible in this excerpt. Copies both inputs, flags both operands as
// transposed and runs blas::dgemm_call< float_t > on p.
// NOTE(review): presumably both inputs are stored in transposed layout —
// confirm.
407 template<
size_t M,
size_t K,
size_t N,
typename float_t >
409 blas_dgemm< M, K, N, float_t >::compute_t(
410 const matrix_left_t_t& At_,
411 const matrix_right_t_t& Bt_,
// Heap-allocated copies of the const inputs.
// NOTE(review): the matching delete is not visible in this excerpt — confirm.
415 matrix_left_t_t* AA =
new matrix_left_t_t( At_ );
416 matrix_right_t_t* BB =
new matrix_right_t_t( Bt_ );
419 p.trans_a = CblasTrans;
420 p.trans_b = CblasTrans;
427 blas::dgemm_call< float_t >( p );
// blas_dgemm::compute_vv_outer for a left vector A_ and a right vector B_.
// Any further parameters and the return type are not visible in this excerpt.
// Copies both vectors, flags the A operand as transposed and runs
// blas::dgemm_call< float_t > on p.
// NOTE(review): going by the name, this presumably forms the outer product of
// the two vectors — confirm.
437 template<
size_t M,
size_t K,
size_t N,
typename float_t >
439 blas_dgemm< M, K, N, float_t >::compute_vv_outer(
440 const vector_left_t& A_,
441 const vector_right_t& B_,
// Heap-allocated copies of the const inputs.
// NOTE(review): the matching delete is not visible in this excerpt — confirm.
445 vector_left_t* AA =
new vector_left_t( A_ );
446 vector_right_t* BB =
new vector_right_t( B_ );
449 p.trans_a = CblasTrans;
455 blas::dgemm_call< float_t >( p );