vmmlib  1.7.0
 All Classes Namespaces Functions Pages
lapack_svd.hpp
1 /*
2  * Copyright (c) 2006-2014, Visualization and Multimedia Lab,
3  * University of Zurich <http://vmml.ifi.uzh.ch>,
4  * Eyescale Software GmbH,
5  * Blue Brain Project, EPFL
6  *
7  * This file is part of VMMLib <https://github.com/VMML/vmmlib/>
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions are met:
11  *
12  * Redistributions of source code must retain the above copyright notice, this
13  * list of conditions and the following disclaimer. Redistributions in binary
14  * form must reproduce the above copyright notice, this list of conditions and
15  * the following disclaimer in the documentation and/or other materials provided
16  * with the distribution. Neither the name of the Visualization and Multimedia
17  * Lab, University of Zurich nor the names of its contributors may be used to
18  * endorse or promote products derived from this software without specific prior
19  * written permission.
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
24  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30  * POSSIBILITY OF SUCH DAMAGE.
31  */
32 #ifndef __VMML__VMMLIB_LAPACK_SVD__HPP__
33 #define __VMML__VMMLIB_LAPACK_SVD__HPP__
34 
35 #include <vmmlib/matrix.hpp>
36 #include <vmmlib/vector.hpp>
37 #include <vmmlib/exception.hpp>
38 
39 #include <vmmlib/lapack_types.hpp>
40 #include <vmmlib/lapack_includes.hpp>
41 
42 #include <string>
43 
66 namespace vmml
67 {
68 
69 namespace lapack
70 {
71 
// LAPACK routine names follow the pattern XYYZZZ:
// X   = data type: S - float, D - double
// YY  = matrix type: GE - general, TR - triangular
// ZZZ = operation, e.g. SVD (so sgesvd = single-precision SVD of a general matrix)
76 
77 
78 template< typename float_t >
79 struct svd_params
80 {
81  char jobu;
82  char jobvt;
83  lapack_int m;
84  lapack_int n;
85  float_t* a;
86  lapack_int lda;
87  float_t* s;
88  float_t* u;
89  lapack_int ldu;
90  float_t* vt;
91  lapack_int ldvt;
92  float_t* work;
93  lapack_int lwork;
94  lapack_int info;
95 
96  friend std::ostream& operator << ( std::ostream& os,
97  const svd_params< float_t >& p )
98  {
99  os
100  << "jobu " << p.jobu
101  << " jobvt " << p.jobvt
102  << " m " << p.m
103  << " n " << p.n
104  << " lda " << p.lda
105  << " ldu " << p.ldu
106  << " ldvt " << p.ldvt
107  << " lwork " << p.lwork
108  << " info " << p.info
109  << std::endl;
110  return os;
111  }
112 
113 };
114 
115 
116 #if 0
117 /* Subroutine */ int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n,
118  doublereal *a, integer *lda, doublereal *s, doublereal *u, integer *
119  ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork,
120  integer *info);
121 #endif
122 
123 
124 template< typename float_t >
125 inline void svd_call( svd_params< float_t >& )
126 {
127  VMMLIB_ERROR( "not implemented for this type.", VMMLIB_HERE );
128 }
129 
130 
131 template<>
132 inline void
133 svd_call( svd_params< float >& p )
134 {
135  //std::cout << "calling lapack svd (single precision) " << std::endl;
136  sgesvd_(
137  &p.jobu,
138  &p.jobvt,
139  &p.m,
140  &p.n,
141  p.a,
142  &p.lda,
143  p.s,
144  p.u,
145  &p.ldu,
146  p.vt,
147  &p.ldvt,
148  p.work,
149  &p.lwork,
150  &p.info
151  );
152 
153 }
154 
155 
156 template<>
157 inline void
158 svd_call( svd_params< double >& p )
159 {
160  //std::cout << "calling lapack svd (double precision) " << std::endl;
161  dgesvd_(
162  &p.jobu,
163  &p.jobvt,
164  &p.m,
165  &p.n,
166  p.a,
167  &p.lda,
168  p.s,
169  p.u,
170  &p.ldu,
171  p.vt,
172  &p.ldvt,
173  p.work,
174  &p.lwork,
175  &p.info
176  );
177 }
178 
179 } // namespace lapack
180 
181 
182 
183 template< size_t M, size_t N, typename float_t >
185 {
186  lapack_svd();
187  ~lapack_svd();
188 
189  // slow version, full SVD, use if all values of U(MXM) and Vt(NXN) are needed
190  bool compute_full(
191  const matrix< M, N, float_t >& A,
193  vector< N, float_t >& sigma,
195  );
196 
197  // version of reduced SVD, computes only most significant left and right singular vectors,
198  // i.e., use if U(MXN) and Vt(NXN) are needed
199  bool compute(
200  const matrix< M, N, float_t >& A,
202  vector< N, float_t >& sigma,
204  );
205 
206  // overwrites A with the result U,
207  bool compute_and_overwrite_input(
209  vector< N, float_t >& sigma
210  );
211 
212  // fast version, use if only sigma is needed.
213  bool compute(
214  const matrix< M, N, float_t >& A,
215  vector< N, float_t >& sigma
216  );
217 
218  inline bool test_success( lapack::lapack_int info );
219 
221 
222  const lapack::svd_params< float_t >& get_params(){ return p; };
223 
224 }; // struct lapack_svd
225 
226 
227 template< size_t M, size_t N, typename float_t >
229 {
230  p.jobu = 'N';
231  p.jobvt = 'N';
232  p.m = M;
233  p.n = N;
234  p.a = 0;
235  p.lda = M;
236  p.s = 0;
237  p.u = 0;
238  p.ldu = M;
239  p.vt = 0;
240  p.ldvt = 1;
241  p.work = new float_t;
242  p.lwork = -1;
243 
244  // workspace query
245  lapack::svd_call( p );
246 
247  p.lwork = static_cast< lapack::lapack_int >( p.work[0] );
248  delete p.work;
249 
250  p.work = new float_t[ p.lwork ];
251 
252 }
253 
254 
255 
256 template< size_t M, size_t N, typename float_t >
257 lapack_svd< M, N, float_t >::~lapack_svd()
258 {
259  delete[] p.work;
260 }
261 
262 
263 
264 template< size_t M, size_t N, typename float_t >
265 bool
266 lapack_svd< M, N, float_t >::compute_full(
267  const matrix< M, N, float_t >& A,
268  matrix< M, M, float_t >& U,
269  vector< N, float_t >& S,
270  matrix< N, N, float_t >& Vt
271  )
272 {
273  // lapack destroys the contents of the input matrix
274  typedef matrix< M, N, float_t > m_type;
275  m_type* AA = new m_type( A );
276 
277  p.jobu = 'A';
278  p.jobvt = 'A';
279  p.a = AA->array;
280  p.u = U.array;
281  p.s = S.array;
282  p.vt = Vt.array;
283  p.ldvt = N;
284 
285  lapack::svd_call< float_t >( p );
286 
287  delete AA;
288 
289  return p.info == 0;
290 }
291 
292 template< size_t M, size_t N, typename float_t >
293 bool
294 lapack_svd< M, N, float_t >::compute(
295  const matrix< M, N, float_t >& A,
297  vector< N, float_t >& S,
298  matrix< N, N, float_t >& Vt
299  )
300 {
301  // lapack destroys the contents of the input matrix
302  typedef matrix< M, N, float_t > m_type;
303  m_type* AA = new m_type( A );
304 
305  p.jobu = 'S';
306  p.jobvt = 'S';
307  p.a = AA->array;
308  p.u = U.array;
309  p.s = S.array;
310  p.vt = Vt.array;
311  p.ldvt = N;
312 
313  lapack::svd_call< float_t >( p );
314  delete AA;
315  return p.info == 0;
316 }
317 
318 
319 template< size_t M, size_t N, typename float_t >
320 bool
321 lapack_svd< M, N, float_t >::compute_and_overwrite_input(
323  vector< N, float_t >& S
324  )
325 {
326  p.jobu = 'O';
327  p.jobvt = 'N';
328  p.a = A_U.array;
329  p.s = S.array;
330  p.ldvt = N;
331 
332  lapack::svd_call< float_t >( p );
333 
334  return p.info == 0;
335 }
336 
337 template< size_t M, size_t N, typename float_t >
338 
339 bool lapack_svd< M, N, float_t >::compute( const matrix< M, N, float_t >& A,
340 
341  vector< N, float_t >& S )
342 
343 {
344 
345  // lapack destroys the contents of the input matrix
346 
347  typedef matrix< M, N, float_t > m_type;
348 
349 m_type* AA = new m_type( A );
350 
351  p.jobu = 'N';
352  p.jobvt = 'N';
353  p.a = AA->array;
354  p.u = 0;
355  p.s = S.array;
356  p.vt = 0;
357 
358  lapack::svd_call< float_t >( p );
359 
360  delete AA;
361 
362  return p.info == 0;
363 
364 }
365 
366 
367 
368 } // namespace vmml
369 
370 #endif