32 #ifndef __VMML__VMMLIB_BLAS_DAXPY__HPP__
33 #define __VMML__VMMLIB_BLAS_DAXPY__HPP__
36 #include <vmmlib/vector.hpp>
37 #include <vmmlib/matrix.hpp>
38 #include <vmmlib/exception.hpp>
39 #include <vmmlib/blas_includes.hpp>
40 #include <vmmlib/blas_types.hpp>
41 #ifdef VMMLIB_USE_OPENMP
// Prototype of cblas_daxpy with six parameters: N, alpha, a read-only
// pointer X, incX, a writable pointer Y, and incY.
// NOTE(review): presumably the standard CBLAS daxpy (Y := alpha * X + Y, with
// N elements and strides incX / incY) -- the surrounding context of this
// declaration is not visible in this excerpt; confirm.
78 void cblas_daxpy(
const int N,
const double alpha,
const double *X,
79 const int incX,
double *Y,
const int incY);
// NOTE(review): the declaration introduced by this template header is not
// visible in this excerpt; presumably it is the daxpy_params< float_t >
// parameter struct that daxpy_call receives -- confirm.
83 template<
typename float_t >
// Stream-output operator for a parameter object p. It writes one numbered,
// tab-labelled line per field, in this order: n, alpha, x, incX (p.inc_x),
// y, incY (p.inc_y). Every line is terminated with std::endl.
93 friend std::ostream& operator << ( std::ostream& os,
97 <<
" (1)\tn " << p.n << std::endl
98 <<
" (2)\talpha " << p.alpha << std::endl
99 <<
" (3)\tx " << p.x << std::endl
100 <<
" (4)\tincX " << p.inc_x << std::endl
101 <<
" (5)\ty " << p.y << std::endl
102 <<
" (6)\tincY " << p.inc_y << std::endl
// Generic daxpy_call template. The only visible statement reports
// "not implemented for this type." through VMMLIB_ERROR.
// NOTE(review): presumably this is the fallback for element types other than
// float and double -- the full signature is not visible; confirm.
111 template<
typename float_t >
115 VMMLIB_ERROR(
"not implemented for this type.", VMMLIB_HERE );
// daxpy_call variant taking daxpy_params< float >; body not visible here.
121 daxpy_call( daxpy_params< float >& p )
// daxpy_call variant taking daxpy_params< double >; body not visible here.
136 daxpy_call( daxpy_params< double >& p )
153 template<
size_t M,
typename float_t >
164 template<
size_t K,
size_t N >
182 template<
size_t M,
typename float_t >
// blas_daxpy< M, float_t >::compute
// Parameters: the scalar a_, a read-only vector B_ and a mutable vector C_.
// The visible statements copy B_ into a heap-allocated vector_t (BB) and then
// invoke blas::daxpy_call< float_t >( p ).
// NOTE(review): how p is populated from a_, BB and C_, and where BB is
// released, are not visible in this excerpt -- confirm that BB is deleted on
// every path (including when daxpy_call raises an error).
194 template<
size_t M,
typename float_t >
196 blas_daxpy< M, float_t >::compute(
const float_t a_,
const vector_t& B_, vector_t& C_ )
// Private copy of the const input B_ (created with the copy constructor).
199 vector_t* BB =
new vector_t( B_ );
207 blas::daxpy_call< float_t >( p );
// blas_daxpy< M, float_t >::compute_mmm (two-input overload)
// Takes an M x K matrix left_m_ and a K x N matrix right_m_, and writes
// columns into res_m_ (its declaration is not visible in this excerpt).
// For every column index n, a column is accumulated from K calls to
// compute(), each pairing the scalar right_m_.at( k, n ) with column k of
// left_m_; the accumulated column is then stored as column n of res_m_.
// NOTE(review): presumably this forms the product left_m_ * right_m_ column
// by column -- this depends on compute()'s exact result; confirm.
// NOTE(review): initialisation of final_col and the release of final_col,
// in_col and out_col are not visible here -- confirm they are zeroed/deleted.
216 template<
size_t M,
typename float_t >
217 template<
size_t K,
size_t N >
219 blas_daxpy< M, float_t >::compute_mmm(
const matrix< M, K, float_t >& left_m_,
220 const matrix< K, N, float_t >& right_m_,
// Outer loop: one iteration per result column n in [0, N).
223 for (
int n = 0; n < (int)N; ++n )
// Accumulator for result column n.
225 vector_t* final_col =
new vector_t;
// Inner loop: one compute() call per k in [0, K).
228 for (
int k = 0; k < (int)K; ++k )
230 vector_t* in_col =
new vector_t;
231 vector_t* out_col =
new vector_t;
// Scalar factor for this step: element (k, n) of the right matrix.
232 float_t a_val = right_m_.at( k, n );
// Fetch column k of the left matrix into in_col.
233 left_m_.get_column( k, *in_col );
235 compute( a_val, *in_col, *out_col );
// Add this step's result to the column accumulator.
237 *final_col += *out_col;
// Store the accumulated vector as column n of the result matrix.
243 res_m_.set_column( n, *final_col );
// blas_daxpy< M, float_t >::compute_mmm (single-input overload)
// Takes an M x K matrix left_m_ and writes columns into the M x M matrix
// res_m_. For every column index n, a column is accumulated from K calls to
// compute(), each pairing the scalar left_m_.at( n, k ) with column k of
// left_m_; the accumulated column is then stored as column n of res_m_.
// NOTE(review): presumably this forms left_m_ * transpose( left_m_ ) -- this
// depends on compute()'s exact result; confirm. The template header that
// introduces K is not visible in this excerpt.
// NOTE(review): both loops carry "#pragma omp parallel for". final_col is
// declared outside the inner loop and updated inside it with +=, without any
// visible synchronisation or reduction; if nested parallelism is active this
// looks like a data race. compute() also passes an object named p to
// daxpy_call; if p is state shared between threads, concurrent iterations of
// the outer loop would conflict as well -- verify before enabling OpenMP.
// NOTE(review): initialisation of final_col and the release of final_col,
// in_col and out_col are not visible here -- confirm they are zeroed/deleted.
253 template<
size_t M,
typename float_t >
256 blas_daxpy< M, float_t >::compute_mmm(
const matrix< M, K, float_t >& left_m_,
257 matrix< M, M, float_t >& res_m_ )
// Outer loop: one iteration per result column n in [0, M).
259 #pragma omp parallel for
260 for (
int n = 0; n < (int)M; ++n )
// Accumulator for result column n.
262 vector_t* final_col =
new vector_t;
// Inner loop: one compute() call per k in [0, K).
265 #pragma omp parallel for
266 for (
int k = 0; k < (int)K; ++k )
268 vector_t* in_col =
new vector_t;
269 vector_t* out_col =
new vector_t;
// Scalar factor for this step: element (n, k) of the same input matrix.
270 float_t a_val = left_m_.at( n,k );
// Fetch column k of the input matrix into in_col.
271 left_m_.get_column( k, *in_col );
273 compute( a_val, *in_col, *out_col );
// Add this step's result to the column accumulator.
275 *final_col += *out_col;
// Store the accumulated vector as column n of the result matrix.
281 res_m_.set_column( n, *final_col );