Skip to content

Commit 5dc20f1

Browse files
author
Jack Poulson
committed
Adding implementations of sparse-direct LeastSquares/Ridge/Tikhonov for the case where width(A) > height(A).
1 parent 2cd9eb3 commit 5dc20f1

File tree

3 files changed

+85
-28
lines changed

3 files changed

+85
-28
lines changed

src/sparse_direct/numeric/LeastSquares.cpp

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,59 @@ void LeastSquares
2424
if( orientation != NORMAL && A.Width() != Y.Height() )
2525
LogicError("Width of A and height of Y must match");
2626
)
27-
if( A.Width() > A.Height() )
28-
LogicError("LeastSquares currently assumes height(A) >= width(A)");
27+
const Int m = A.Height();
28+
const Int n = A.Width();
2929
DistSparseMatrix<F> C(A.Comm());
30+
X.SetComm( Y.Comm() );
3031
if( orientation == NORMAL )
3132
{
32-
const Int n = A.Width();
33-
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
34-
MakeHermitian( LOWER, C );
35-
X.SetComm( Y.Comm() );
3633
Zeros( X, n, Y.Width() );
37-
Multiply( ADJOINT, F(1), A, Y, F(0), X );
34+
if( m >= n )
35+
{
36+
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
37+
MakeHermitian( LOWER, C );
38+
39+
Multiply( ADJOINT, F(1), A, Y, F(0), X );
40+
HermitianSolve( C, X, ctrl );
41+
}
42+
else
43+
{
44+
Herk( LOWER, NORMAL, Base<F>(1), A, C );
45+
MakeHermitian( LOWER, C );
46+
47+
DistMultiVec<F> YCopy(Y.Comm());
48+
YCopy = Y;
49+
HermitianSolve( C, YCopy, ctrl );
50+
Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
51+
}
3852
}
3953
else if( orientation == ADJOINT || !IsComplex<F>::val )
4054
{
41-
const Int n = A.Height();
42-
Herk( LOWER, NORMAL, Base<F>(1), A, C );
43-
MakeHermitian( LOWER, C );
44-
X.SetComm( Y.Comm() );
45-
Zeros( X, n, Y.Width() );
46-
Multiply( NORMAL, F(1), A, Y, F(0), X );
55+
Zeros( X, m, Y.Width() );
56+
if( m >= n )
57+
{
58+
Herk( LOWER, NORMAL, Base<F>(1), A, C );
59+
MakeHermitian( LOWER, C );
60+
61+
Multiply( NORMAL, F(1), A, Y, F(0), X );
62+
HermitianSolve( C, X, ctrl );
63+
}
64+
else
65+
{
66+
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
67+
MakeHermitian( LOWER, C );
68+
69+
DistMultiVec<F> YCopy(Y.Comm());
70+
YCopy = Y;
71+
HermitianSolve( C, YCopy, ctrl );
72+
Multiply( NORMAL, F(1), A, YCopy, F(0), X );
73+
}
4774
}
4875
else
4976
{
5077
LogicError("Complex transposed option not yet supported");
5178
}
52-
HermitianSolve( C, X, ctrl );
79+
5380
}
5481

5582
#define PROTO(F) \

src/sparse_direct/numeric/Ridge.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,32 @@ void Ridge
2121
if( A.Height() != Y.Height() )
2222
LogicError("Heights of A and Y must match");
2323
)
24-
if( A.Width() > A.Height() )
25-
LogicError("Ridge currently assumes height(A) >= width(A)");
24+
const Int m = A.Height();
2625
const Int n = A.Width();
2726
DistSparseMatrix<F> C(A.Comm());
28-
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
29-
UpdateDiagonal( C, F(alpha*alpha) );
30-
MakeHermitian( LOWER, C );
27+
3128
X.SetComm( Y.Comm() );
3229
Zeros( X, n, Y.Width() );
33-
Multiply( ADJOINT, F(1), A, Y, F(0), X );
34-
HermitianSolve( C, X, ctrl );
30+
if( m >= n )
31+
{
32+
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
33+
UpdateDiagonal( C, F(alpha*alpha) );
34+
MakeHermitian( LOWER, C );
35+
36+
Multiply( ADJOINT, F(1), A, Y, F(0), X );
37+
HermitianSolve( C, X, ctrl );
38+
}
39+
else
40+
{
41+
Herk( LOWER, NORMAL, Base<F>(1), A, C );
42+
UpdateDiagonal( C, F(alpha*alpha) );
43+
MakeHermitian( LOWER, C );
44+
45+
DistMultiVec<F> YCopy(Y.Comm());
46+
YCopy = Y;
47+
HermitianSolve( C, YCopy, ctrl );
48+
Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
49+
}
3550
}
3651

3752
#define PROTO(F) \

src/sparse_direct/numeric/Tikhonov.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,32 @@ void Tikhonov
2121
if( A.Height() != Y.Height() )
2222
LogicError("Heights of A and Y must match");
2323
)
24-
if( A.Width() > A.Height() )
25-
LogicError("Tikhonov currently assumes height(A) >= width(A)");
24+
const Int m = A.Height();
2625
const Int n = A.Width();
2726
DistSparseMatrix<F> C(A.Comm());
28-
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
29-
Herk( LOWER, ADJOINT, Base<F>(1), Gamma, Base<F>(1), C );
30-
MakeHermitian( LOWER, C );
27+
3128
X.SetComm( Y.Comm() );
3229
Zeros( X, n, Y.Width() );
33-
Multiply( ADJOINT, F(1), A, Y, F(0), X );
34-
HermitianSolve( C, X, ctrl );
30+
if( m >= n )
31+
{
32+
Herk( LOWER, ADJOINT, Base<F>(1), A, C );
33+
Herk( LOWER, ADJOINT, Base<F>(1), Gamma, Base<F>(1), C );
34+
MakeHermitian( LOWER, C );
35+
36+
Multiply( ADJOINT, F(1), A, Y, F(0), X );
37+
HermitianSolve( C, X, ctrl );
38+
}
39+
else
40+
{
41+
Herk( LOWER, NORMAL, Base<F>(1), A, C );
42+
Herk( LOWER, NORMAL, Base<F>(1), Gamma, Base<F>(1), C );
43+
MakeHermitian( LOWER, C );
44+
45+
DistMultiVec<F> YCopy(Y.Comm());
46+
YCopy = Y;
47+
HermitianSolve( C, YCopy, ctrl );
48+
Multiply( ADJOINT, F(1), A, YCopy, F(0), X );
49+
}
3550
}
3651

3752
#define PROTO(F) \

0 commit comments

Comments
 (0)