vmmlib  1.7.0
 All Classes Namespaces Functions Pages
lapack_linear_least_squares.hpp
1 /*
2  * Copyright (c) 2006-2014, Visualization and Multimedia Lab,
3  * University of Zurich <http://vmml.ifi.uzh.ch>,
4  * Eyescale Software GmbH,
5  * Blue Brain Project, EPFL
6  *
7  * This file is part of VMMLib <https://github.com/VMML/vmmlib/>
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions are met:
11  *
12  * Redistributions of source code must retain the above copyright notice, this
13  * list of conditions and the following disclaimer. Redistributions in binary
14  * form must reproduce the above copyright notice, this list of conditions and
15  * the following disclaimer in the documentation and/or other materials provided
16  * with the distribution. Neither the name of the Visualization and Multimedia
17  * Lab, University of Zurich nor the names of its contributors may be used to
18  * endorse or promote products derived from this software without specific prior
19  * written permission.
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
24  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30  * POSSIBILITY OF SUCH DAMAGE.
31  */
32 #ifndef __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
33 #define __VMML__VMMLIB_LAPACK_LINEAR_LEAST_SQUARES__HPP__
34 
35 #include <vmmlib/matrix.hpp>
36 #include <vmmlib/vector.hpp>
37 #include <vmmlib/exception.hpp>
38 
39 #include <vmmlib/lapack_types.hpp>
40 #include <vmmlib/lapack_includes.hpp>
41 
42 #include <string>
43 
54 namespace vmml
55 {
56 
57 // XYYZZZ
58 // X = data type: S - float, D - double
59 // YY = matrix type, GE - general, TR - triangular
60 // ZZZ = function name
61 
62 namespace lapack
63 {
64 
65 //
66 //
67 // SGELS/DGELS
68 //
69 //
70 
71 
72 // parameter struct
73 template< typename float_t >
75 {
76  char trans; // 'N'->A, 'T'->Atransposed
77  lapack_int m; // number of rows, M >= 0
78  lapack_int n; // number of columns, N >= 0
79  lapack_int nrhs; // number of columns of B/X
80  float_t* a; // input A
81  lapack_int lda; // leading dimension of A (number of rows)
82  float_t* b; // input B, output X
83  lapack_int ldb; // leading dimension of b
84  float_t* work; // workspace
85  lapack_int lwork; // workspace size
86  lapack_int info; // 'return' value
87 
88  friend std::ostream& operator << ( std::ostream& os,
90  {
91  os
92  << " m " << p.m
93  << " n " << p.n
94  << " nrhs " << p.nrhs
95  << " lda " << p.lda
96  << " ldb " << p.ldb
97  << " lwork " << p.lwork
98  << " info " << p.info
99  << std::endl;
100  return os;
101  }
102 
103 };
104 
105 // call wrappers
106 
107 #if 0
108 void dgels_(const char *trans, const int *M, const int *N, const int *nrhs,
109  double *A, const int *lda, double *b, const int *ldb, double *work,
110  const int * lwork, int *info);
111 #endif
112 
113 template< typename float_t >
114 inline void
115 llsq_call_xgels( llsq_params_xgels< float_t >& )
116 {
117  VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
118 }
119 
120 template<>
121 inline void
122 llsq_call_xgels( llsq_params_xgels< float >& p )
123 {
124  sgels_(
125  &p.trans,
126  &p.m,
127  &p.n,
128  &p.nrhs,
129  p.a,
130  &p.lda,
131  p.b,
132  &p.ldb,
133  p.work,
134  &p.lwork,
135  &p.info
136  );
137 }
138 
139 template<>
140 inline void
141 llsq_call_xgels( llsq_params_xgels< double >& p )
142 {
143  dgels_(
144  &p.trans,
145  &p.m,
146  &p.n,
147  &p.nrhs,
148  p.a,
149  &p.lda,
150  p.b,
151  &p.ldb,
152  p.work,
153  &p.lwork,
154  &p.info
155  );
156 
157 }
158 
159 
160 template< size_t M, size_t N, typename float_t >
162 {
163  bool compute(
164  const matrix< M, N, float_t >& A,
165  const vector< M, float_t >& B,
167 
170 
171  const lapack::llsq_params_xgels< float_t >& get_params(){ return p; };
172 
173  matrix< M, N, float_t >& get_factorized_A() { return _A; }
174 
175 protected:
178 
180 
181 };
182 
183 
184 
185 template< size_t M, size_t N, typename float_t >
186 bool
188  const matrix< M, N, float_t >& A,
189  const vector< M, float_t >& B,
191 {
192  _A = A;
193  _b = B;
194 
195  llsq_call_xgels( p );
196 
197  // success
198  if ( p.info == 0 )
199  {
200  for( size_t index = 0; index < N; ++index )
201  {
202  x( index ) = _b( index );
203  }
204 
205  return true;
206  }
207  if ( p.info < 0 )
208  {
209  VMMLIB_ERROR( "xGELS - invalid argument.", VMMLIB_HERE );
210  }
211  else
212  {
213  std::cout << "A\n" << A << std::endl;
214  std::cout << "B\n" << B << std::endl;
215 
216  VMMLIB_ERROR( "least squares solution could not be computed.",
217  VMMLIB_HERE );
218  }
219  return false;
220 }
221 
222 
223 
224 template< size_t M, size_t N, typename float_t >
225 linear_least_squares_xgels< M, N, float_t >::
226 linear_least_squares_xgels()
227 {
228  p.trans = 'N';
229  p.m = M;
230  p.n = N;
231  p.nrhs = 1;
232  p.a = _A.array;
233  p.lda = M;
234  p.b = _b.array;
235  p.ldb = M;
236  p.work = new float_t();
237  p.lwork = -1;
238 
239  // workspace query
240  llsq_call_xgels( p );
241 
242  p.lwork = static_cast< lapack_int > ( p.work[0] );
243  delete p.work;
244 
245  p.work = new float_t[ p.lwork ];
246 }
247 
248 
249 
250 template< size_t M, size_t N, typename float_t >
251 linear_least_squares_xgels< M, N, float_t >::
252 ~linear_least_squares_xgels()
253 {
254  delete[] p.work;
255 }
256 
257 
258 
259 //
260 //
261 // SGESV/DGESV
262 //
263 //
264 
265 template< typename float_t >
267 {
268  lapack_int n; // order of matrix A = M * N
269  lapack_int nrhs; // number of columns of B
270  float_t* a; // input A, output P*L*U
271  lapack_int lda; // leading dimension of A (for us: number of rows)
272  lapack_int* ipiv; // pivot indices, integer array of size N
273  float_t* b; // input b, output X
274  lapack_int ldb; // leading dimension of b
275  lapack_int info;
276 
277  friend std::ostream& operator << ( std::ostream& os,
279  {
280  os
281  << "n " << p.n
282  << " nrhs " << p.nrhs
283  << " lda " << p.lda
284  << " ldb " << p.ldvt
285  << " info " << p.info
286  << std::endl;
287  return os;
288  }
289 
290 };
291 
292 
293 #if 0
294 /* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer
295  *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info);
296 #endif
297 
298 
299 template< typename float_t >
300 inline void
301 llsq_call_xgesv( llsq_params_xgesv< float_t >& )
302 {
303  VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
304 }
305 
306 
307 template<>
308 inline void
309 llsq_call_xgesv( llsq_params_xgesv< float >& p )
310 {
311  sgesv_(
312  &p.n,
313  &p.nrhs,
314  p.a,
315  &p.lda,
316  p.ipiv,
317  p.b,
318  &p.ldb,
319  &p.info
320  );
321 
322 }
323 
324 
325 template<>
326 inline void
327 llsq_call_xgesv( llsq_params_xgesv< double >& p )
328 {
329  dgesv_(
330  &p.n,
331  &p.nrhs,
332  p.a,
333  &p.lda,
334  p.ipiv,
335  p.b,
336  &p.ldb,
337  &p.info
338  );
339 }
340 
341 
342 template< size_t M, size_t N, typename float_t >
344 {
345  // computes x ( Ax = b ). x replaces b on output.
346  void compute(
349  );
350 
353 
354  const lapack::llsq_params_xgesv< float_t >& get_params() { return p; }
355 
357 
358 }; // struct lapack_linear_least_squares
359 
360 
361 template< size_t M, size_t N, typename float_t >
362 void
364 compute(
367  )
368 {
369  p.a = A.array;
370  p.b = b.array;
371 
372  lapack::llsq_call_xgesv( p );
373 
374  if ( p.info != 0 )
375  {
376  if ( p.info < 0 )
377  VMMLIB_ERROR( "invalid value in input matrix", VMMLIB_HERE );
378  else
379  VMMLIB_ERROR( "factor U is exactly singular, solution could not be computed.", VMMLIB_HERE );
380  }
381 }
382 
383 
384 
385 template< size_t M, size_t N, typename float_t >
386 linear_least_squares_xgesv< M, N, float_t >::
387 linear_least_squares_xgesv()
388 {
389  p.n = N;
390  p.nrhs = M;
391  p.lda = N;
392  p.ldb = N;
393  p.ipiv = new lapack_int[ N ];
394 
395 }
396 
397 
398 
399 template< size_t M, size_t N, typename float_t >
400 linear_least_squares_xgesv< M, N, float_t >::
401 ~linear_least_squares_xgesv()
402 {
403  delete[] p.ipiv;
404 }
405 
406 
407 } // namespace lapack
408 
409 } // namespace vmml
410 
411 #endif