vmmlib  1.7.0
 All Classes Namespaces Functions Pages
blas_dgemm.hpp
1 /*
2  * Copyright (c) 2006-2014, Visualization and Multimedia Lab,
3  * University of Zurich <http://vmml.ifi.uzh.ch>,
4  * Eyescale Software GmbH,
5  * Blue Brain Project, EPFL
6  *
7  * This file is part of VMMLib <https://github.com/VMML/vmmlib/>
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions are met:
11  *
12  * Redistributions of source code must retain the above copyright notice, this
13  * list of conditions and the following disclaimer. Redistributions in binary
14  * form must reproduce the above copyright notice, this list of conditions and
15  * the following disclaimer in the documentation and/or other materials provided
16  * with the distribution. Neither the name of the Visualization and Multimedia
17  * Lab, University of Zurich nor the names of its contributors may be used to
18  * endorse or promote products derived from this software without specific prior
19  * written permission.
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
24  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30  * POSSIBILITY OF SUCH DAMAGE.
31  */
32 #ifndef __VMML__VMMLIB_BLAS_DGEMM__HPP__
33 #define __VMML__VMMLIB_BLAS_DGEMM__HPP__
34 
35 
36 #include <vmmlib/matrix.hpp>
37 #include <vmmlib/tensor3.hpp>
38 #include <vmmlib/exception.hpp>
39 #include <vmmlib/blas_includes.hpp>
40 #include <vmmlib/blas_types.hpp>
41 
77 namespace vmml
78 {
79 
80  namespace blas
81  {
82 
83 
84 #if 0
85  /* Subroutine */
86  void cblas_dgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
87  blasint M, blasint N, blasint K,
88  double alpha, double *A, blasint lda, double *B, blasint ldb, double beta, double *C, blasint ldc);
89 
90 #endif
91 
92  template< typename float_t >
93  struct dgemm_params
94  {
95  CBLAS_ORDER order;
96  CBLAS_TRANSPOSE trans_a;
97  CBLAS_TRANSPOSE trans_b;
98  blas_int m;
99  blas_int n;
100  blas_int k;
101  float_t alpha;
102  float_t* a;
103  blas_int lda; //leading dimension of input array matrix left
104  float_t* b;
105  blas_int ldb; //leading dimension of input array matrix right
106  float_t beta;
107  float_t* c;
108  blas_int ldc; //leading dimension of output array matrix right
109 
110  friend std::ostream& operator << ( std::ostream& os,
111  const dgemm_params< float_t >& p )
112  {
113  os
114  << " (1)\torder " << p.order << std::endl
115  << " (2)\ttrans_a " << p.trans_a << std::endl
116  << " (3)\ttrans_b " << p.trans_b << std::endl
117  << " (4)\tm " << p.m << std::endl
118  << " (6)\tn " << p.n << std::endl
119  << " (5)\tk " << p.k << std::endl
120  << " (7)\talpha " << p.alpha << std::endl
121  << " (8)\ta " << p.a << std::endl
122  << " (9)\tlda " << p.lda << std::endl
123  << " (10)\tb " << p.b << std::endl
124  << " (11)\tldb " << p.ldb << std::endl
125  << " (12)\tbeta " << p.beta << std::endl
126  << " (13)\tc " << p.c << std::endl
127  << " (14)\tldc " << p.ldc << std::endl
128  << std::endl;
129  return os;
130  }
131 
132  };
133 
134 
135 
136  template< typename float_t >
137  inline void
138  dgemm_call( dgemm_params< float_t >& )
139  {
140  VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
141  }
142 
143 
144  template<>
145  inline void
146  dgemm_call( dgemm_params< float >& p )
147  {
148  //std::cout << "calling blas sgemm (single precision) " << std::endl;
149  cblas_sgemm(
150  p.order,
151  p.trans_a,
152  p.trans_b,
153  p.m,
154  p.n,
155  p.k,
156  p.alpha,
157  p.a,
158  p.lda,
159  p.b,
160  p.ldb,
161  p.beta,
162  p.c,
163  p.ldc
164  );
165 
166  }
167 
168  template<>
169  inline void
170  dgemm_call( dgemm_params< double >& p )
171  {
172  //std::cout << "calling blas dgemm (double precision) " << std::endl;
173  cblas_dgemm(
174  p.order,
175  p.trans_a,
176  p.trans_b,
177  p.m,
178  p.n,
179  p.k,
180  p.alpha,
181  p.a,
182  p.lda,
183  p.b,
184  p.ldb,
185  p.beta,
186  p.c,
187  p.ldc
188  );
189  }
190 
191  } // namespace blas
192 
193 
194 
// Fixed-size wrapper around the BLAS gemm call: every compute*() method
// fills the parameter set p and forwards it to blas::dgemm_call< float_t >.
// M, K and N are used as gemm's m, k and n.
195  template< size_t M, size_t K, size_t N, typename float_t >
196  struct blas_dgemm
197  {
198 
// NOTE(review): listing lines 199-205 are absent from this view; the
// matrix_*_t / vector_*_t names used below are presumably typedef'd there -
// confirm against the original header.
206 
207  blas_dgemm();
208  ~blas_dgemm() {};
209 
// Each compute*() definition in this file zeroes C_ before the gemm call
// and returns true unconditionally.
// compute( A_, B_, C_ ): A_ is the left operand, B_ the right operand.
// compute( A_, C_ ): A_ is used for both operands, the right one with
// trans_b = CblasTrans.
210  bool compute( const matrix_left_t& A_, const matrix_right_t& B_, matrix_out_t& C_ );
211  bool compute( const matrix_left_t& A_, matrix_out_t& C_ );
212 
213  // dgemms with tensor3 input works for frontal tensor unfolding
214  //I2*I3 = K;
215  template< size_t I2, size_t I3 >
216  bool compute( const tensor3< M, I2, I3, float_t >& A_, const matrix_right_t& B_, matrix_out_t& C_ );
217  //I2*I3 = K;
218  template< size_t I2, size_t I3 >
219  bool compute( const tensor3< M, I2, I3, float_t >& A_, matrix_out_t& C_ );
220 
// compute_t( B_, C_ ): B_ is used for both operands, the left one with
// trans_a = CblasTrans.
// compute_bt: right operand passed with trans_b = CblasTrans.
// compute_t( A_, B_, C_ ): both operands passed with CblasTrans.
// compute_vv_outer: vector operands, left one passed with trans_a = CblasTrans.
221  bool compute_t( const matrix_right_t& B_, matrix_out_t& C_ );
222  bool compute_bt( const matrix_left_t& A_, const matrix_right_t_t& Bt_, matrix_out_t& C_ );
223  bool compute_t( const matrix_left_t_t& A_, const matrix_right_t_t& B_, matrix_out_t& C_ );
224  bool compute_vv_outer( const vector_left_t& A_, const vector_right_t& B_, matrix_out_t& C_ );
225 
226 
// NOTE(review): listing line 227 is absent from this view; presumably it
// declares the parameter set p that get_params() returns - confirm.
228 
// Read-only access to the parameter set as left by the constructor or the
// most recent compute*() call.
229  const blas::dgemm_params< float_t >& get_params(){ return p; };
230 
231 
232  }; // struct blas_dgemm
233 
234 
// Initial parameter set: column-major order, no transposition on either
// input, dimensions taken from the template arguments, alpha = 1, beta = 0
// and null data pointers (the compute*() methods fill those in).
// NOTE(review): the declarator line (listing line 236) is missing between
// the template header and the opening brace; presumably this is the body of
// blas_dgemm< M, K, N, float_t >::blas_dgemm() - confirm against the
// original header.
235  template< size_t M, size_t K, size_t N, typename float_t >
237  {
238  p.order = CblasColMajor; //
239  p.trans_a = CblasNoTrans;
240  p.trans_b = CblasNoTrans;
241  p.m = M;
242  p.n = N;
243  p.k = K;
244  p.alpha = 1;
245  p.a = 0;
246  p.lda = M;
247  p.b = 0;
248  p.ldb = K; //no transpose
249  p.beta = 0;
250  p.c = 0;
251  p.ldc = M;
252  }
253 
254 
255 
256  template< size_t M, size_t K, size_t N, typename float_t >
257  bool
258  blas_dgemm< M, K, N, float_t >::compute(
259  const matrix_left_t& A_,
260  const matrix_right_t& B_,
261  matrix_out_t& C_
262  )
263  {
264  // blas needs non-const data
265  matrix_left_t* AA = new matrix_left_t( A_ );
266  matrix_right_t* BB = new matrix_right_t( B_ );
267  C_.zero();
268 
269  p.a = AA->array;
270  p.b = BB->array;
271  p.c = C_.array;
272 
273  blas::dgemm_call< float_t >( p );
274 
275  //std::cout << p << std::endl; //debug
276 
277  delete AA;
278  delete BB;
279 
280  return true;
281  }
282 
283  template< size_t M, size_t K, size_t N, typename float_t >
284  template< size_t I2, size_t I3 >
285  bool
286  blas_dgemm< M, K, N, float_t >::compute(
287  const tensor3< M, I2, I3, float_t >& A_,
288  const matrix_right_t& B_,
289  matrix_out_t& C_
290  )
291  {
292  // blas needs non-const data
293  tensor3< M, I2, I3, float_t > AA( A_ );
294  matrix_right_t* BB = new matrix_right_t( B_ );
295  C_.zero();
296 
297  p.a = AA.get_array_ptr();
298  p.b = BB->array;
299  p.c = C_.array;
300 
301  blas::dgemm_call< float_t >( p );
302 
303  //std::cout << p << std::endl; //debug
304 
305  delete BB;
306 
307  return true;
308  }
309 
310 
311  template< size_t M, size_t K, size_t N, typename float_t >
312  bool
313  blas_dgemm< M, K, N, float_t >::compute( const matrix_left_t& A_, matrix_out_t& C_ )
314  {
315  // blas needs non-const data
316  matrix_left_t* AA = new matrix_left_t( A_ );
317  C_.zero();
318 
319  p.trans_b = CblasTrans;
320  p.a = AA->array;
321  p.b = AA->array;
322  p.ldb = N;
323  p.c = C_.array;
324 
325  blas::dgemm_call< float_t >( p );
326 
327  //std::cout << p << std::endl; //debug
328 
329  delete AA;
330 
331  return true;
332  }
333 
334  template< size_t M, size_t K, size_t N, typename float_t >
335  template< size_t I2, size_t I3 >
336  bool
337  blas_dgemm< M, K, N, float_t >::compute( const tensor3< M, I2, I3, float_t >& A_, matrix_out_t& C_ )
338  {
339  // blas needs non-const data
340  tensor3< M, I2, I3, float_t > AA( A_ ) ;
341  C_.zero();
342 
343  p.trans_b = CblasTrans;
344  p.a = AA.get_array_ptr();
345  p.b = AA.get_array_ptr();
346  p.ldb = N;
347  p.c = C_.array;
348 
349  blas::dgemm_call< float_t >( p );
350 
351  //std::cout << p << std::endl; //debug
352 
353  return true;
354  }
355 
356  template< size_t M, size_t K, size_t N, typename float_t >
357  bool
358  blas_dgemm< M, K, N, float_t >::compute_t( const matrix_right_t& B_, matrix_out_t& C_ )
359  {
360  // blas needs non-const data
361  matrix_right_t* BB = new matrix_right_t( B_ );
362  C_.zero();
363 
364  p.trans_a = CblasTrans;
365  p.a = BB->array;
366  p.b = BB->array;
367  p.lda = K;
368  p.c = C_.array;
369 
370  blas::dgemm_call< float_t >( p );
371 
372  //std::cout << p << std::endl; //debug
373 
374  delete BB;
375 
376  return true;
377  }
378 
379  template< size_t M, size_t K, size_t N, typename float_t >
380  bool
381  blas_dgemm< M, K, N, float_t >::compute_bt(
382  const matrix_left_t& A_,
383  const matrix_right_t_t& Bt_,
384  matrix_out_t& C_ )
385  {
386  // blas needs non-const data
387  matrix_left_t* AA = new matrix_left_t( A_ );
388  matrix_right_t_t* BB = new matrix_right_t_t( Bt_ );
389  C_.zero();
390 
391  p.trans_b = CblasTrans;
392  p.a = AA->array;
393  p.b = BB->array;
394  p.c = C_.array;
395  p.ldb = N;
396 
397  blas::dgemm_call< float_t >( p );
398 
399  //std::cout << p << std::endl; //debug
400 
401  delete AA;
402  delete BB;
403 
404  return true;
405  }
406 
407  template< size_t M, size_t K, size_t N, typename float_t >
408  bool
409  blas_dgemm< M, K, N, float_t >::compute_t(
410  const matrix_left_t_t& At_,
411  const matrix_right_t_t& Bt_,
412  matrix_out_t& C_ )
413  {
414  // blas needs non-const data
415  matrix_left_t_t* AA = new matrix_left_t_t( At_ );
416  matrix_right_t_t* BB = new matrix_right_t_t( Bt_ );
417  C_.zero();
418 
419  p.trans_a = CblasTrans;
420  p.trans_b = CblasTrans;
421  p.a = AA->array;
422  p.b = BB->array;
423  p.c = C_.array;
424  p.ldb = N;
425  p.lda = K;
426 
427  blas::dgemm_call< float_t >( p );
428 
429  //std::cout << p << std::endl; //debug
430 
431  delete AA;
432  delete BB;
433 
434  return true;
435  }
436 
437  template< size_t M, size_t K, size_t N, typename float_t >
438  bool
439  blas_dgemm< M, K, N, float_t >::compute_vv_outer(
440  const vector_left_t& A_,
441  const vector_right_t& B_,
442  matrix_out_t& C_ )
443  {
444  // blas needs non-const data
445  vector_left_t* AA = new vector_left_t( A_ );
446  vector_right_t* BB = new vector_right_t( B_ );
447  C_.zero();
448 
449  p.trans_a = CblasTrans;
450  p.a = AA->array;
451  p.b = BB->array;
452  p.c = C_.array;
453  p.lda = K;
454 
455  blas::dgemm_call< float_t >( p );
456 
457  //std::cout << p << std::endl; //debug
458 
459  delete AA;
460  delete BB;
461 
462  return true;
463  }
464 
465 
466 } // namespace vmml
467 
468 #endif