@@ -105,7 +105,7 @@ static ffi::Error SvdOnlyVtImpl(
105
105
106
106
MachineType* u_data;
107
107
MachineType* vt_data;
108
- if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
108
+ if (( mode == UVtMode::computeOnlyU || mode == UVtMode::computePartialUandVt) && x_rows < x_cols) {
109
109
u_data = u_or_vt_data;
110
110
vt_data = nullptr ;
111
111
} else {
@@ -122,7 +122,7 @@ static ffi::Error SvdOnlyVtImpl(
122
122
const char jobz = ' O' ;
123
123
lapack_int ldu;
124
124
lapack_int ldvt;
125
- if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
125
+ if (( mode == UVtMode::computeOnlyU || mode == UVtMode::computePartialUandVt) && x_rows < x_cols) {
126
126
ldu = x_rows_lapack;
127
127
ldvt = 1 ;
128
128
} else {
@@ -193,6 +193,7 @@ static ffi::Error SvdOnlyVtQRImpl(
193
193
ffi::Buffer<dtype> x,
194
194
ffi::ResultBuffer<dtype> x_out,
195
195
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
196
+ ffi::ResultBuffer<dtype> u_or_vt,
196
197
ffi::ResultBuffer<ffi::DataType::S32> info,
197
198
UVtMode mode) {
198
199
@@ -275,9 +276,12 @@ static ffi::Error SvdOnlyVtQRImpl(
275
276
276
277
auto * x_out_data = x_out->typed_data ();
277
278
auto * s_data = s->typed_data ();
278
- // auto* vt_data = vt ->typed_data();
279
+ auto * u_or_vt_data = u_or_vt ->typed_data ();
279
280
auto * info_data = info->typed_data ();
280
281
282
+ MachineType* u_data;
283
+ MachineType* vt_data;
284
+
281
285
if (x.typed_data () != x_out_data) {
282
286
std::copy_n (x.typed_data (), x.element_count (), x_out_data);
283
287
}
@@ -287,18 +291,38 @@ static ffi::Error SvdOnlyVtQRImpl(
287
291
288
292
char jobu;
289
293
char jobvt;
290
- const lapack_int ldu = 1 ;
291
- const lapack_int ldvt = 1 ;
294
+ lapack_int ldu;
295
+ lapack_int ldvt;
292
296
if (mode == UVtMode::computeOnlyU) {
293
297
jobu = ' O' ;
294
298
jobvt = ' N' ;
295
- // ldu = 1;
296
- // ldvt = 1;
297
- } else {
299
+ ldu = 1 ;
300
+ ldvt = 1 ;
301
+ u_data = nullptr ;
302
+ vt_data = nullptr ;
303
+ } else if (mode == UVtMode::computeOnlyVt) {
298
304
jobu = ' N' ;
299
305
jobvt = ' O' ;
300
- // ldu = 1;
301
- // ldvt = 1;
306
+ ldu = 1 ;
307
+ ldvt = 1 ;
308
+ u_data = nullptr ;
309
+ vt_data = nullptr ;
310
+ } else {
311
+ if (x_rows >= x_cols) {
312
+ jobu = ' O' ;
313
+ jobvt = ' S' ;
314
+ ldu = 1 ;
315
+ ldvt = x_cols_lapack;
316
+ u_data = nullptr ;
317
+ vt_data = u_or_vt_data;
318
+ } else {
319
+ jobu = ' S' ;
320
+ jobvt = ' O' ;
321
+ ldu = x_rows_lapack;
322
+ ldvt = 1 ;
323
+ u_data = u_or_vt_data;
324
+ vt_data = nullptr ;
325
+ }
302
326
}
303
327
304
328
if constexpr (ffi::IsComplexType<dtype>()) {
@@ -337,14 +361,14 @@ static ffi::Error SvdOnlyVtQRImpl(
337
361
338
362
if constexpr (ffi::IsComplexType<dtype>()) {
339
363
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
340
- &x_rows_lapack, s_data, nullptr ,
341
- &ldu, nullptr , &ldvt, work.get (),
364
+ &x_rows_lapack, s_data, u_data ,
365
+ &ldu, vt_data , &ldvt, work.get (),
342
366
&lwork, rwork.get (), info_data
343
367
);
344
368
} else {
345
369
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
346
- &x_rows_lapack, s_data, nullptr ,
347
- &ldu, nullptr , &ldvt,
370
+ &x_rows_lapack, s_data, u_data ,
371
+ &ldu, vt_data , &ldvt,
348
372
work.get (), &lwork, info_data
349
373
);
350
374
}
@@ -363,7 +387,7 @@ static ffi::Error SvdOnlyVtQRImpl(
363
387
.Arg<ffi::Buffer<dtype>>(/* x*/ ) \
364
388
.Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
365
389
.Ret<ffi::Buffer<dtype>>(/* s*/ ) \
366
- .Ret<ffi::Buffer<dtype>>(/* vt */ ) \
390
+ .Ret<ffi::Buffer<dtype>>(/* u_or_vt */ ) \
367
391
.Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
368
392
.Attr<UVtMode>(" mode" ))
369
393
@@ -374,7 +398,7 @@ static ffi::Error SvdOnlyVtQRImpl(
374
398
.Arg<ffi::Buffer<dtype>>(/* x*/ ) \
375
399
.Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
376
400
.Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
377
- .Ret<ffi::Buffer<dtype>>(/* vt */ ) \
401
+ .Ret<ffi::Buffer<dtype>>(/* u_or_vt */ ) \
378
402
.Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
379
403
.Attr<UVtMode>(" mode" ))
380
404
@@ -390,6 +414,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
390
414
.Arg<ffi::Buffer<dtype>>(/* x*/ ) \
391
415
.Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
392
416
.Ret<ffi::Buffer<dtype>>(/* s*/ ) \
417
+ .Ret<ffi::Buffer<dtype>>(/* u_or_vt*/ ) \
393
418
.Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
394
419
.Attr<UVtMode>(" mode" ))
395
420
@@ -400,6 +425,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
400
425
.Arg<ffi::Buffer<dtype>>(/* x*/ ) \
401
426
.Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
402
427
.Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
428
+ .Ret<ffi::Buffer<dtype>>(/* u_or_vt*/ ) \
403
429
.Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
404
430
.Attr<UVtMode>(" mode" ))
405
431
0 commit comments