vmmlib  1.7.0
 All Classes Namespaces Functions Pages
blas_dot.hpp
1 /*
2  * Copyright (c) 2006-2014, Visualization and Multimedia Lab,
3  * University of Zurich <http://vmml.ifi.uzh.ch>,
4  * Eyescale Software GmbH,
5  * Blue Brain Project, EPFL
6  *
7  * This file is part of VMMLib <https://github.com/VMML/vmmlib/>
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions are met:
11  *
12  * Redistributions of source code must retain the above copyright notice, this
13  * list of conditions and the following disclaimer. Redistributions in binary
14  * form must reproduce the above copyright notice, this list of conditions and
15  * the following disclaimer in the documentation and/or other materials provided
16  * with the distribution. Neither the name of the Visualization and Multimedia
17  * Lab, University of Zurich nor the names of its contributors may be used to
18  * endorse or promote products derived from this software without specific prior
19  * written permission.
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
24  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30  * POSSIBILITY OF SUCH DAMAGE.
31  */
32 #ifndef __VMML__VMMLIB_BLAS_DOT__HPP__
33 #define __VMML__VMMLIB_BLAS_DOT__HPP__
34 
35 
36 #include <vmmlib/vector.hpp>
37 #include <vmmlib/exception.hpp>
38 #include <vmmlib/blas_includes.hpp>
39 #include <vmmlib/blas_types.hpp>
40 
63 namespace vmml
64 {
65 namespace blas
66 {
67  template< typename float_t >
68  struct dot_params
69  {
70  blas_int n;
71  float_t* x;
72  blas_int inc_x;
73  float_t* y;
74  blas_int inc_y;
75 
76  friend std::ostream& operator << ( std::ostream& os,
77  const dot_params< float_t >& p )
78  {
79  os
80  << " (1)\tn " << p.n << std::endl
81  << " (2)\tx " << p.x << std::endl
82  << " (3)\tincX " << p.inc_x << std::endl
83  << " (4)\ty " << p.y << std::endl
84  << " (5)\tincY " << p.inc_y << std::endl
85  << std::endl;
86  return os;
87  }
88 
89  };
90 
91 
92 
93  template< typename float_t >
94  inline float_t
95  dot_call( dot_params< float_t >& )
96  {
97  VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
98  }
99 
100 
101  template<>
102  inline float
103  dot_call( dot_params< float >& p )
104  {
105  //std::cout << "calling blas sdot (single precision) " << std::endl;
106  float vvi = cblas_sdot(
107  p.n,
108  p.x,
109  p.inc_x,
110  p.y,
111  p.inc_y
112  );
113  return vvi;
114  }
115 
116  template<>
117  inline double
118  dot_call( dot_params< double >& p )
119  {
120  //std::cout << "calling blas ddot (double precision) " << std::endl;
121  double vvi = cblas_ddot(
122  p.n,
123  p.x,
124  p.inc_x,
125  p.y,
126  p.inc_y
127  );
128  return vvi;
129  }
130 
131 } // namespace blas
132 
133 
134 
135 template< size_t M, typename float_t >
136 struct blas_dot
137 {
138 
140 
141  blas_dot();
142  ~blas_dot() {};
143 
144  bool compute( const vector_t& A_, const vector_t& B_, float_t& dot_prod_ );
145 
146 
148 
149  const blas::dot_params< float_t >& get_params(){ return p; };
150 
151 
152 }; // struct blas_dot
153 
154 
155 template< size_t M, typename float_t >
157 {
158  p.n = M;
159  p.x = 0;
160  p.inc_x = 1;
161  p.y = 0;
162  p.inc_y = 1;
163 }
164 
165 
166 template< size_t M, typename float_t >
167 bool
168 blas_dot< M, float_t >::compute( const vector_t& A_, const vector_t& B_, float_t& dot_prod_ )
169 {
170  // blas needs non-const data
171  vector_t* AA = new vector_t( A_ );
172  vector_t* BB = new vector_t( B_ );
173 
174  p.x = AA->array;
175  p.y = BB->array;
176 
177  dot_prod_ = blas::dot_call< float_t >( p );
178 
179  //std::cout << dot_prod_ << std::endl; //debug
180 
181  delete AA;
182  delete BB;
183 
184  return true;
185 }
186 
187 } // namespace vmml
188 
189 #endif