vmmlib  1.7.0
 All Classes Namespaces Functions Pages
blas_daxpy.hpp
1 /*
2  * Copyright (c) 2006-2014, Visualization and Multimedia Lab,
3  * University of Zurich <http://vmml.ifi.uzh.ch>,
4  * Eyescale Software GmbH,
5  * Blue Brain Project, EPFL
6  *
7  * This file is part of VMMLib <https://github.com/VMML/vmmlib/>
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions are met:
11  *
12  * Redistributions of source code must retain the above copyright notice, this
13  * list of conditions and the following disclaimer. Redistributions in binary
14  * form must reproduce the above copyright notice, this list of conditions and
15  * the following disclaimer in the documentation and/or other materials provided
16  * with the distribution. Neither the name of the Visualization and Multimedia
17  * Lab, University of Zurich nor the names of its contributors may be used to
18  * endorse or promote products derived from this software without specific prior
19  * written permission.
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
24  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30  * POSSIBILITY OF SUCH DAMAGE.
31  */
32 #ifndef __VMML__VMMLIB_BLAS_DAXPY__HPP__
33 #define __VMML__VMMLIB_BLAS_DAXPY__HPP__
34 
35 
#include <vector>

#include <vmmlib/vector.hpp>
#include <vmmlib/matrix.hpp>
#include <vmmlib/exception.hpp>
#include <vmmlib/blas_includes.hpp>
#include <vmmlib/blas_types.hpp>
#ifdef VMMLIB_USE_OPENMP
# include <omp.h>
#endif
44 
69 namespace vmml
70 {
71 
72  namespace blas
73  {
74 
75 
76 #if 0
77  /* Subroutine */
78  void cblas_daxpy(const int N, const double alpha, const double *X,
79  const int incX, double *Y, const int incY);
80 
81 #endif
82 
83  template< typename float_t >
84  struct daxpy_params
85  {
86  blas_int n;
87  float_t alpha;
88  float_t* x;
89  blas_int inc_x;
90  float_t* y;
91  blas_int inc_y;
92 
93  friend std::ostream& operator << ( std::ostream& os,
94  const daxpy_params< float_t >& p )
95  {
96  os
97  << " (1)\tn " << p.n << std::endl
98  << " (2)\talpha " << p.alpha << std::endl
99  << " (3)\tx " << p.x << std::endl
100  << " (4)\tincX " << p.inc_x << std::endl
101  << " (5)\ty " << p.y << std::endl
102  << " (6)\tincY " << p.inc_y << std::endl
103  << std::endl;
104  return os;
105  }
106 
107  };
108 
109 
110 
111  template< typename float_t >
112  inline void
113  daxpy_call( daxpy_params< float_t >& )
114  {
115  VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
116  }
117 
118 
119  template<>
120  inline void
121  daxpy_call( daxpy_params< float >& p )
122  {
123  //std::cout << "calling blas saxpy (single precision) " << std::endl;
124  cblas_saxpy(
125  p.n,
126  p.alpha,
127  p.x,
128  p.inc_x,
129  p.y,
130  p.inc_y
131  );
132  }
133 
134  template<>
135  inline void
136  daxpy_call( daxpy_params< double >& p )
137  {
138  //std::cout << "calling blas daxpy (double precision) " << std::endl;
139  cblas_daxpy(
140  p.n,
141  p.alpha,
142  p.x,
143  p.inc_x,
144  p.y,
145  p.inc_y
146  );
147  }
148 
149  } // namespace blas
150 
151 
152 
153  template< size_t M, typename float_t >
154  struct blas_daxpy
155  {
156 
158 
159  blas_daxpy();
160  ~blas_daxpy() {};
161 
162  bool compute( const float_t a_, const vector_t& B_, vector_t& C_ );
163 
164  template< size_t K, size_t N >
165  bool compute_mmm( const matrix< M, K, float_t >& left_m_,
166  const matrix< K, N, float_t >& right_m_,
167  matrix< M, N, float_t >& res_m_ );
168 
169  template< size_t K >
170  bool compute_mmm( const matrix< M, K, float_t >& left_m_,
171  matrix< M, M, float_t >& res_m_ );
172 
173 
175 
176  const blas::daxpy_params< float_t >& get_params(){ return p; };
177 
178 
179  }; // struct blas_daxpy
180 
181 
182  template< size_t M, typename float_t >
184  {
185  p.n = M;
186  p.alpha = 0;
187  p.x = 0;
188  p.inc_x = 1;
189  p.y = 0;
190  p.inc_y = 1;
191  }
192 
193 
194  template< size_t M, typename float_t >
195  bool
196  blas_daxpy< M, float_t >::compute( const float_t a_, const vector_t& B_, vector_t& C_ )
197  {
198  // blas needs non-const data
199  vector_t* BB = new vector_t( B_ );
200 
201  C_.set(0);
202 
203  p.alpha = a_;
204  p.x = BB->array;
205  p.y = C_.array;
206 
207  blas::daxpy_call< float_t >( p );
208 
209  //std::cout << p << std::endl; //debug
210 
211  delete BB;
212 
213  return true;
214  }
215 
216  template< size_t M, typename float_t >
217  template< size_t K, size_t N >
218  bool
219  blas_daxpy< M, float_t >::compute_mmm( const matrix< M, K, float_t >& left_m_,
220  const matrix< K, N, float_t >& right_m_,
221  matrix< M, N, float_t >& res_m_ )
222  {
223  for ( int n = 0; n < (int)N; ++n )
224  {
225  vector_t* final_col = new vector_t;
226  final_col->set(0);
227 
228  for ( int k = 0; k < (int)K; ++k )
229  {
230  vector_t* in_col = new vector_t;
231  vector_t* out_col = new vector_t;
232  float_t a_val = right_m_.at( k, n );
233  left_m_.get_column( k, *in_col );
234 
235  compute( a_val, *in_col, *out_col );
236 
237  *final_col += *out_col;
238 
239  delete in_col;
240  delete out_col;
241  }
242 
243  res_m_.set_column( n, *final_col );
244 
245  delete final_col;
246  }
247 
248 
249  return true;
250  }
251 
252 
253  template< size_t M, typename float_t >
254  template< size_t K >
255  bool
256  blas_daxpy< M, float_t >::compute_mmm( const matrix< M, K, float_t >& left_m_,
257  matrix< M, M, float_t >& res_m_ )
258  {
259 #pragma omp parallel for
260  for ( int n = 0; n < (int)M; ++n )
261  {
262  vector_t* final_col = new vector_t;
263  final_col->set(0);
264 
265 #pragma omp parallel for
266  for ( int k = 0; k < (int)K; ++k )
267  {
268  vector_t* in_col = new vector_t;
269  vector_t* out_col = new vector_t;
270  float_t a_val = left_m_.at( n,k ); //reversed (k,n), because take value from transposed matrix left_m_
271  left_m_.get_column( k, *in_col );
272 
273  compute( a_val, *in_col, *out_col );
274 
275  *final_col += *out_col;
276 
277  delete in_col;
278  delete out_col;
279  }
280 
281  res_m_.set_column( n, *final_col );
282 
283  delete final_col;
284  }
285 
286 
287  return true;
288  }
289 
290 
291 
292 } // namespace vmml
293 
294 #endif