Skip to content

Commit 8e774d2

Browse files
committed
matrix exp, row reduce
1 parent 4f3e6c7 commit 8e774d2

17 files changed

+210
-78
lines changed

src/main/java/ch/ethz/idsc/tensor/Unprotect.java

+11-4
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,19 @@ public static Tensor byRef(Tensor... tensors) {
3838
* @return
3939
* @throws Exception if tensor is a scalar, or first level entries do not have regular length */
4040
public static int dimension1(Tensor tensor) {
41+
int length = dimension1Hint(tensor);
42+
if (tensor.stream().skip(1).allMatch(entry -> entry.length() == length))
43+
return length;
44+
throw TensorRuntimeException.of(tensor);
45+
}
46+
47+
/** @param tensor
48+
* @return
49+
* @throws Exception if tensor is a scalar */
50+
public static int dimension1Hint(Tensor tensor) {
4151
TensorImpl impl = (TensorImpl) tensor;
4252
List<Tensor> list = impl.list;
43-
int length = list.get(0).length();
44-
if (list.stream().skip(1).anyMatch(entry -> entry.length() != length))
45-
throw TensorRuntimeException.of(tensor);
46-
return length;
53+
return list.get(0).length();
4754
}
4855

4956
/** THE USE OF THIS FUNCTION IN THE APPLICATION LAYER IS NOT RECOMMENDED !

src/main/java/ch/ethz/idsc/tensor/lie/MatrixExp.java

+7-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import ch.ethz.idsc.tensor.Tensor;
88
import ch.ethz.idsc.tensor.TensorRuntimeException;
99
import ch.ethz.idsc.tensor.api.ScalarUnaryOperator;
10-
import ch.ethz.idsc.tensor.sca.Abs;
1110
import ch.ethz.idsc.tensor.sca.Ceiling;
1211
import ch.ethz.idsc.tensor.sca.Chop;
1312
import ch.ethz.idsc.tensor.sca.Exp;
@@ -30,27 +29,23 @@ public enum MatrixExp {
3029
* @return exponential of given matrix exp(m) = I + m + m^2/2 + m^3/6 + ...
3130
* @throws Exception if given matrix is not a square matrix */
3231
public static Tensor of(Tensor matrix) {
33-
// LONGTERM the infinity norm is recommended
34-
Scalar max = RealScalar.of(matrix.flatten(1) //
35-
.map(Scalar.class::cast) //
36-
.map(Abs.FUNCTION) //
37-
.map(Scalar::number) //
38-
.mapToDouble(Number::doubleValue) //
39-
.reduce(Math::max) //
40-
.getAsDouble() + 1);
41-
long exponent = 1 << Ceiling.FUNCTION.apply(LOG2.apply(max)).number().longValue();
32+
long exponent = exponent(Norm2Bound.ofMatrix(matrix));
4233
return MatrixPower.of(series(matrix.multiply(RationalScalar.of(1, exponent))), exponent);
4334
}
4435

36+
/** @param norm
37+
* @return power of 2 */
38+
/* package */ static long exponent(Scalar norm) {
39+
return 1 << Ceiling.FUNCTION.apply(LOG2.apply(norm.add(RealScalar.ONE))).number().longValue();
40+
}
41+
4542
/** @param matrix square
4643
* @return
4744
* @throws Exception if given matrix is non-square */
4845
/* package */ static Tensor series(Tensor matrix) {
4946
int n = matrix.length();
5047
Tensor nxt = matrix;
5148
Tensor sum = StaticHelper.IDENTITY_MATRIX.apply(n).add(nxt);
52-
if (Chop.NONE.allZero(nxt))
53-
return sum;
5449
for (int k = 2; k <= n; ++k) {
5550
nxt = nxt.dot(matrix).divide(RealScalar.of(k));
5651
sum = sum.add(nxt);

src/main/java/ch/ethz/idsc/tensor/lie/MatrixLog.java

+3-5
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,10 @@ public enum MatrixLog {
4545
* @return
4646
* @throws Exception if computation is not supported for given matrix */
4747
public static Tensor of(Tensor matrix) {
48-
int dim1 = Unprotect.dimension1(matrix);
49-
if (matrix.length() == 2)
50-
if (dim1 == 2)
51-
return MatrixLog2.of(matrix);
52-
// ---
5348
int n = matrix.length();
49+
if (n == 2 && Unprotect.dimension1(matrix) == 2)
50+
return MatrixLog2.of(matrix);
51+
// ---
5452
Tensor id = StaticHelper.IDENTITY_MATRIX.apply(n);
5553
Tensor rem = matrix.subtract(id);
5654
Deque<DenmanBeaversDet> deque = new ArrayDeque<>();

src/main/java/ch/ethz/idsc/tensor/mat/InfluenceMatrix.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import ch.ethz.idsc.tensor.RealScalar;
99
import ch.ethz.idsc.tensor.Scalar;
1010
import ch.ethz.idsc.tensor.Tensor;
11+
import ch.ethz.idsc.tensor.alg.MatrixQ;
1112
import ch.ethz.idsc.tensor.red.Diagonal;
1213
import ch.ethz.idsc.tensor.sca.Clips;
1314
import ch.ethz.idsc.tensor.sca.Sqrt;
@@ -27,7 +28,7 @@ public class InfluenceMatrix implements Serializable {
2728
/** @param design matrix
2829
* @return */
2930
public static InfluenceMatrix of(Tensor design) {
30-
return new InfluenceMatrix(Objects.requireNonNull(design));
31+
return new InfluenceMatrix(MatrixQ.require(design));
3132
}
3233

3334
/***************************************************/

src/main/java/ch/ethz/idsc/tensor/mat/PseudoInverse.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public static Tensor of(Tensor matrix) {
4343
*
4444
* @param matrix
4545
* @return */
46-
public static Tensor usingCholesky(Tensor matrix) {
46+
/* package */ static Tensor usingCholesky(Tensor matrix) {
4747
int n = matrix.length();
4848
Tensor mt = ConjugateTranspose.of(matrix);
4949
int m = mt.length();
@@ -53,13 +53,14 @@ public static Tensor usingCholesky(Tensor matrix) {
5353
}
5454

5555
/***************************************************/
56-
/** Hint: computing the pseudo-inverse using the QR decomposition is generally faster
57-
* than when using the singular value decomposition.
56+
/** Hint: computing the pseudo-inverse using the QR decomposition is
57+
* possible for matrices of maximal rank, and is generally faster than
58+
* when using the singular value decomposition.
5859
*
5960
* @param matrix with maximal rank
6061
* @return pseudoinverse of given matrix
6162
* @throws Exception if matrix does not have maximal rank */
62-
public static Tensor usingQR(Tensor matrix) {
63+
/* package */ static Tensor usingQR(Tensor matrix) {
6364
return usingQR(matrix, matrix.length(), Unprotect.dimension1(matrix));
6465
}
6566

@@ -74,14 +75,14 @@ private static Tensor usingQR(Tensor matrix, int n, int m) {
7475
*
7576
* @param matrix of arbitrary dimension and rank
7677
* @return pseudoinverse of given matrix */
77-
public static Tensor usingSvd(Tensor matrix) {
78+
/* package */ static Tensor usingSvd(Tensor matrix) {
7879
return usingSvd(matrix, Tolerance.CHOP);
7980
}
8081

8182
/** @param matrix
8283
* @param chop
8384
* @return */
84-
public static Tensor usingSvd(Tensor matrix, Chop chop) {
85+
/* package */ static Tensor usingSvd(Tensor matrix, Chop chop) {
8586
return usingSvd(matrix, chop, matrix.length(), Unprotect.dimension1(matrix));
8687
}
8788

src/main/java/ch/ethz/idsc/tensor/mat/QRDecomposition.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ public interface QRDecomposition {
1717
/** householder reflections with highest numerical stability
1818
*
1919
* @param matrix of dimensions n x m
20-
* @return qr-decomposition of given matrix */
20+
* @return qr-decomposition of given matrix
21+
* @throws Exception if input is not a non-empty rectangular matrix */
2122
static QRDecomposition of(Tensor matrix) {
2223
return of(matrix, QRSignOperators.STABILITY);
2324
}

src/main/java/ch/ethz/idsc/tensor/mat/QRDecompositionImpl.java

+18-15
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
package ch.ethz.idsc.tensor.mat;
33

44
import java.io.Serializable;
5+
import java.util.concurrent.atomic.AtomicInteger;
56

67
import ch.ethz.idsc.tensor.RealScalar;
78
import ch.ethz.idsc.tensor.Scalar;
89
import ch.ethz.idsc.tensor.Scalars;
910
import ch.ethz.idsc.tensor.Tensor;
10-
import ch.ethz.idsc.tensor.Tensors;
1111
import ch.ethz.idsc.tensor.Unprotect;
12+
import ch.ethz.idsc.tensor.ext.Integers;
1213
import ch.ethz.idsc.tensor.red.Diagonal;
1314
import ch.ethz.idsc.tensor.red.Norm;
1415
import ch.ethz.idsc.tensor.red.Times;
@@ -18,29 +19,29 @@
1819
* householder with even number of reflections
1920
* reproduces example on wikipedia */
2021
/* package */ class QRDecompositionImpl implements QRDecomposition, Serializable {
21-
private static final long serialVersionUID = 3564186473851271309L;
22+
private static final long serialVersionUID = -4880290968594939778L;
2223
// ---
23-
private final int n;
2424
private final int m;
25-
private Tensor R;
26-
private Tensor Qinv;
25+
private final Tensor R;
26+
private final Tensor Qinv;
2727

2828
/** @param matrix n x m
2929
* @param b is rhs, for instance IdentityMatrix[n]
3030
* @param qrSignOperator
3131
* @throws Exception if input is not a matrix */
3232
public QRDecompositionImpl(Tensor matrix, Tensor b, QRSignOperator qrSignOperator) {
33-
n = matrix.length();
34-
m = Unprotect.dimension1(matrix);
35-
R = matrix;
36-
Qinv = b;
37-
// the m-th reflection is necessary in the case where A is non-square
38-
for (int k = 0; k < m; ++k) {
39-
final int fk = k;
40-
Tensor x = Tensors.vector(i -> i < fk ? R.Get(i, fk).zero() : R.get(i, fk), n);
33+
int n = matrix.length();
34+
m = Integers.requirePositive(Unprotect.dimension1Hint(matrix));
35+
Tensor R = matrix;
36+
Tensor Qinv = b;
37+
for (int k = 0; k < m; ++k) { // m reflections
38+
AtomicInteger atomicInteger = new AtomicInteger(-k);
39+
Tensor x = Tensor.of(R.get(Tensor.ALL, k).stream() // k-th column of R
40+
.map(Scalar.class::cast) //
41+
.map(scalar -> atomicInteger.getAndIncrement() < 0 ? scalar.zero() : scalar));
4142
Scalar xn = Norm._2.ofVector(x);
4243
if (Scalars.nonZero(xn)) { // else reflection reduces to identity, hopefully => det == 0
43-
Tensor signed = qrSignOperator.sign(R.Get(k, k)).multiply(xn);
44+
Tensor signed = qrSignOperator.sign(x.Get(k)).multiply(xn);
4445
x.set(signed::add, k);
4546
QRReflection qrReflection = new QRReflection(k, x);
4647
Qinv = qrReflection.forward(Qinv);
@@ -51,6 +52,8 @@ public QRDecompositionImpl(Tensor matrix, Tensor b, QRSignOperator qrSignOperato
5152
for (int k = 0; k < m; ++k)
5253
for (int i = k + 1; i < n; ++i)
5354
R.set(Tolerance.CHOP, i, k);
55+
this.R = R;
56+
this.Qinv = Qinv;
5457
}
5558

5659
@Override // from QRDecomposition
@@ -70,7 +73,7 @@ public Tensor getQ() {
7073

7174
@Override // from QRDecomposition
7275
public Scalar det() {
73-
return n == m //
76+
return R.length() == m // check if R is square
7477
? Times.pmul(Diagonal.of(R)).Get()
7578
: RealScalar.ZERO;
7679
}

src/main/java/ch/ethz/idsc/tensor/mat/RowReduce.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
import ch.ethz.idsc.tensor.Scalar;
88
import ch.ethz.idsc.tensor.Scalars;
99
import ch.ethz.idsc.tensor.Tensor;
10+
import ch.ethz.idsc.tensor.ext.Integers;
1011

1112
/** inspired by
1213
* <a href="https://reference.wolfram.com/language/ref/RowReduce.html">RowReduce</a>
1314
*
1415
* @see LinearSolve */
1516
public class RowReduce extends AbstractReduce {
1617
/** @param matrix
17-
* @return reduced row echelon form (also called row canonical form) of matrix */
18+
* @return reduced row echelon form (also called row canonical form) of matrix
19+
* @throws Exception if input is not a non-empty rectangular matrix */
1820
public static Tensor of(Tensor matrix) {
1921
return of(matrix, Pivots.ARGMAX_ABS);
2022
}
@@ -31,7 +33,9 @@ private RowReduce(Tensor matrix, Pivot pivot) {
3133
}
3234

3335
private Tensor solve() {
34-
int m = Stream.of(lhs).mapToInt(Tensor::length).max().getAsInt();
36+
int m = Integers.requirePositiveOrZero(Stream.of(lhs) //
37+
.mapToInt(Tensor::length) //
38+
.max().getAsInt());
3539
int j = 0;
3640
for (int c0 = 0; c0 < lhs.length && j < m; ++j) {
3741
swap(pivot.get(c0, j, ind, lhs), c0);

src/test/java/ch/ethz/idsc/tensor/UnprotectTest.java

+7
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ public void testDimension1() {
6767
assertTrue(Unprotect.dimension1(Array.zeros(2, 3, 4)) == 3);
6868
}
6969

70+
public void testDimension1Hint() {
71+
Tensor tensor = Tensors.fromString("{{0, 2, 3}, {0, 2, 3, 5}, {{}}}");
72+
assertEquals(Unprotect.dimension1Hint(tensor), 3);
73+
AssertFail.of(() -> Unprotect.dimension1(tensor));
74+
}
75+
7076
public void testFailEmpty() {
7177
AssertFail.of(() -> Unprotect.dimension1(Tensors.empty()));
7278
}
@@ -79,6 +85,7 @@ public void testFail1() {
7985

8086
public void testFail2() {
8187
AssertFail.of(() -> Unprotect.dimension1(RealScalar.ONE));
88+
AssertFail.of(() -> Unprotect.dimension1Hint(RealScalar.ONE));
8289
}
8390

8491
public void testReferencesScalar() {

src/test/java/ch/ethz/idsc/tensor/alg/TransposeFailTest.java

+1
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,6 @@ public void testFail() {
3838

3939
public void testFail2() {
4040
AssertFail.of(() -> Transpose.of(Tensors.fromString("{{1, 2}, {3, 4, 5}}")));
41+
AssertFail.of(() -> Transpose.of(Tensors.fromString("{{1, 2, 3}, {4, 5}}")));
4142
}
4243
}

src/test/java/ch/ethz/idsc/tensor/lie/MatrixExpTest.java

+31-27
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import ch.ethz.idsc.tensor.DoubleScalar;
99
import ch.ethz.idsc.tensor.ExactScalarQ;
1010
import ch.ethz.idsc.tensor.ExactTensorQ;
11-
import ch.ethz.idsc.tensor.RationalScalar;
1211
import ch.ethz.idsc.tensor.RealScalar;
1312
import ch.ethz.idsc.tensor.Scalar;
1413
import ch.ethz.idsc.tensor.Scalars;
@@ -24,7 +23,6 @@
2423
import ch.ethz.idsc.tensor.pdf.Distribution;
2524
import ch.ethz.idsc.tensor.pdf.NormalDistribution;
2625
import ch.ethz.idsc.tensor.pdf.RandomVariate;
27-
import ch.ethz.idsc.tensor.qty.Quantity;
2826
import ch.ethz.idsc.tensor.red.Entrywise;
2927
import ch.ethz.idsc.tensor.red.Trace;
3028
import ch.ethz.idsc.tensor.sca.Chop;
@@ -34,6 +32,13 @@
3432
public class MatrixExpTest extends TestCase {
3533
private static final Random RANDOM = new Random();
3634

35+
public void testExponents() {
36+
assertEquals(MatrixExp.exponent(RealScalar.of(0)), 1);
37+
assertEquals(MatrixExp.exponent(RealScalar.of(0.99)), 2);
38+
assertEquals(MatrixExp.exponent(RealScalar.of(1)), 2);
39+
assertEquals(MatrixExp.exponent(RealScalar.of(1.01)), 4);
40+
}
41+
3742
public void testZeros() {
3843
Tensor zeros = Array.zeros(7, 7);
3944
Tensor eye = MatrixExp.of(zeros);
@@ -104,31 +109,30 @@ public void testChallenge() {
104109
Tensor altexp = A.dot(diaexp).dot(Inverse.of(A));
105110
Chop._08.requireClose(altexp, result);
106111
}
107-
108-
public void testQuantity1() {
109-
// Mathematica can't do this :-)
110-
Scalar qs1 = Quantity.of(3, "m");
111-
Tensor mat = Tensors.of( //
112-
Tensors.of(RealScalar.ZERO, qs1), //
113-
Tensors.vector(0, 0));
114-
Tensor sol = MatrixExp.of(mat);
115-
Chop.NONE.requireClose(sol, mat.add(IdentityMatrix.of(2)));
116-
}
117-
118-
public void testQuantity2() {
119-
Scalar qs1 = Quantity.of(2, "m");
120-
Scalar qs2 = Quantity.of(3, "s");
121-
Scalar qs3 = Quantity.of(4, "m");
122-
Scalar qs4 = Quantity.of(5, "s");
123-
Tensor mat = Tensors.of( //
124-
Tensors.of(RealScalar.ZERO, qs1, qs3.multiply(qs4)), //
125-
Tensors.of(RealScalar.ZERO, RealScalar.ZERO, qs2), //
126-
Tensors.of(RealScalar.ZERO, RealScalar.ZERO, RealScalar.ZERO) //
127-
);
128-
Tensor actual = IdentityMatrix.of(3).add(mat).add(mat.dot(mat).multiply(RationalScalar.of(1, 2)));
129-
// assertEquals(MatrixExp.of(mat), actual);
130-
Chop.NONE.requireClose(MatrixExp.of(mat), actual);
131-
}
112+
// public void testQuantity1() {
113+
// // Mathematica can't do this :-)
114+
// Scalar qs1 = Quantity.of(3, "m");
115+
// Tensor mat = Tensors.of( //
116+
// Tensors.of(RealScalar.ZERO, qs1), //
117+
// Tensors.vector(0, 0));
118+
// Tensor sol = MatrixExp.of(mat);
119+
// Chop.NONE.requireClose(sol, mat.add(IdentityMatrix.of(2)));
120+
// }
121+
//
122+
// public void testQuantity2() {
123+
// Scalar qs1 = Quantity.of(2, "m");
124+
// Scalar qs2 = Quantity.of(3, "s");
125+
// Scalar qs3 = Quantity.of(4, "m");
126+
// Scalar qs4 = Quantity.of(5, "s");
127+
// Tensor mat = Tensors.of( //
128+
// Tensors.of(RealScalar.ZERO, qs1, qs3.multiply(qs4)), //
129+
// Tensors.of(RealScalar.ZERO, RealScalar.ZERO, qs2), //
130+
// Tensors.of(RealScalar.ZERO, RealScalar.ZERO, RealScalar.ZERO) //
131+
// );
132+
// Tensor actual = IdentityMatrix.of(3).add(mat).add(mat.dot(mat).multiply(RationalScalar.of(1, 2)));
133+
// // assertEquals(MatrixExp.of(mat), actual);
134+
// Chop.NONE.requireClose(MatrixExp.of(mat), actual);
135+
// }
132136

133137
public void testLarge() {
134138
// without scaling, the loop of the series requires ~300 steps

0 commit comments

Comments
 (0)