Skip to content

Commit 5d51aa0

Browse files
committed
Implement differentiable version of svd_only_u and svd_only_vt along with a gauge-fixed version
1 parent 1d1fd29 commit 5d51aa0

File tree

3 files changed

+227
-94
lines changed

3 files changed

+227
-94
lines changed

varipeps/utils/extensions/svd_ffi.cpp

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ static ffi::Error SvdOnlyVtImpl(
105105

106106
MachineType* u_data;
107107
MachineType* vt_data;
108-
if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
108+
if ((mode == UVtMode::computeOnlyU || mode == UVtMode::computePartialUandVt) && x_rows < x_cols) {
109109
u_data = u_or_vt_data;
110110
vt_data = nullptr;
111111
} else {
@@ -122,7 +122,7 @@ static ffi::Error SvdOnlyVtImpl(
122122
const char jobz = 'O';
123123
lapack_int ldu;
124124
lapack_int ldvt;
125-
if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
125+
if ((mode == UVtMode::computeOnlyU || mode == UVtMode::computePartialUandVt) && x_rows < x_cols) {
126126
ldu = x_rows_lapack;
127127
ldvt = 1;
128128
} else {
@@ -193,6 +193,7 @@ static ffi::Error SvdOnlyVtQRImpl(
193193
ffi::Buffer<dtype> x,
194194
ffi::ResultBuffer<dtype> x_out,
195195
ffi::ResultBuffer<ffi::ToReal(dtype)> s,
196+
ffi::ResultBuffer<dtype> u_or_vt,
196197
ffi::ResultBuffer<ffi::DataType::S32> info,
197198
UVtMode mode) {
198199

@@ -275,9 +276,12 @@ static ffi::Error SvdOnlyVtQRImpl(
275276

276277
auto* x_out_data = x_out->typed_data();
277278
auto* s_data = s->typed_data();
278-
// auto* vt_data = vt->typed_data();
279+
auto* u_or_vt_data = u_or_vt->typed_data();
279280
auto* info_data = info->typed_data();
280281

282+
MachineType* u_data;
283+
MachineType* vt_data;
284+
281285
if (x.typed_data() != x_out_data) {
282286
std::copy_n(x.typed_data(), x.element_count(), x_out_data);
283287
}
@@ -287,18 +291,38 @@ static ffi::Error SvdOnlyVtQRImpl(
287291

288292
char jobu;
289293
char jobvt;
290-
const lapack_int ldu = 1;
291-
const lapack_int ldvt = 1;
294+
lapack_int ldu;
295+
lapack_int ldvt;
292296
if (mode == UVtMode::computeOnlyU) {
293297
jobu = 'O';
294298
jobvt = 'N';
295-
// ldu = 1;
296-
// ldvt = 1;
297-
} else {
299+
ldu = 1;
300+
ldvt = 1;
301+
u_data = nullptr;
302+
vt_data = nullptr;
303+
} else if (mode == UVtMode::computeOnlyVt) {
298304
jobu = 'N';
299305
jobvt = 'O';
300-
// ldu = 1;
301-
// ldvt = 1;
306+
ldu = 1;
307+
ldvt = 1;
308+
u_data = nullptr;
309+
vt_data = nullptr;
310+
} else {
311+
if (x_rows >= x_cols) {
312+
jobu = 'O';
313+
jobvt = 'S';
314+
ldu = 1;
315+
ldvt = x_cols_lapack;
316+
u_data = nullptr;
317+
vt_data = u_or_vt_data;
318+
} else {
319+
jobu = 'S';
320+
jobvt = 'O';
321+
ldu = x_rows_lapack;
322+
ldvt = 1;
323+
u_data = u_or_vt_data;
324+
vt_data = nullptr;
325+
}
302326
}
303327

304328
if constexpr (ffi::IsComplexType<dtype>()) {
@@ -337,14 +361,14 @@ static ffi::Error SvdOnlyVtQRImpl(
337361

338362
if constexpr (ffi::IsComplexType<dtype>()) {
339363
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
340-
&x_rows_lapack, s_data, nullptr,
341-
&ldu, nullptr, &ldvt, work.get(),
364+
&x_rows_lapack, s_data, u_data,
365+
&ldu, vt_data, &ldvt, work.get(),
342366
&lwork, rwork.get(), info_data
343367
);
344368
} else {
345369
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
346-
&x_rows_lapack, s_data, nullptr,
347-
&ldu, nullptr, &ldvt,
370+
&x_rows_lapack, s_data, u_data,
371+
&ldu, vt_data, &ldvt,
348372
work.get(), &lwork, info_data
349373
);
350374
}
@@ -363,7 +387,7 @@ static ffi::Error SvdOnlyVtQRImpl(
363387
.Arg<ffi::Buffer<dtype>>(/*x*/) \
364388
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
365389
.Ret<ffi::Buffer<dtype>>(/*s*/) \
366-
.Ret<ffi::Buffer<dtype>>(/*vt*/) \
390+
.Ret<ffi::Buffer<dtype>>(/*u_or_vt*/) \
367391
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
368392
.Attr<UVtMode>("mode"))
369393

@@ -374,7 +398,7 @@ static ffi::Error SvdOnlyVtQRImpl(
374398
.Arg<ffi::Buffer<dtype>>(/*x*/) \
375399
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
376400
.Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/*s*/) \
377-
.Ret<ffi::Buffer<dtype>>(/*vt*/) \
401+
.Ret<ffi::Buffer<dtype>>(/*u_or_vt*/) \
378402
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
379403
.Attr<UVtMode>("mode"))
380404

@@ -390,6 +414,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
390414
.Arg<ffi::Buffer<dtype>>(/*x*/) \
391415
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
392416
.Ret<ffi::Buffer<dtype>>(/*s*/) \
417+
.Ret<ffi::Buffer<dtype>>(/*u_or_vt*/) \
393418
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
394419
.Attr<UVtMode>("mode"))
395420

@@ -400,6 +425,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
400425
.Arg<ffi::Buffer<dtype>>(/*x*/) \
401426
.Ret<ffi::Buffer<dtype>>(/*x_out*/) \
402427
.Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/*s*/) \
428+
.Ret<ffi::Buffer<dtype>>(/*u_or_vt*/) \
403429
.Ret<ffi::Buffer<ffi::DataType::S32>>(/*info*/) \
404430
.Attr<UVtMode>("mode"))
405431

varipeps/utils/extensions/svd_ffi.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
#include "xla/ffi/api/ffi.h"
55

66
enum class UVtMode : int8_t {
7-
computeOnlyU = 0, // Compute only U
8-
computeOnlyVt = 1, // Compute only Vt
7+
computeOnlyU = 0, // Compute only U
8+
computeOnlyVt = 1, // Compute only Vt
9+
computePartialUandVt = 2, // Compute only Vt
910
};
1011

1112
XLA_FFI_DECLARE_HANDLER_SYMBOL(svd_only_vt_f32);

0 commit comments

Comments
 (0)