diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts
index 4ea6886e2a..ae2d391616 100644
--- a/src/ops/linalg_ops.ts
+++ b/src/ops/linalg_ops.ts
@@ -19,10 +19,12 @@
  * Linear algebra ops.
  */
 
+import {concat, linalg, solve, sqrt, transpose} from '..';
 import {ENV} from '../environment';
 import {dispose} from '../globals';
 import {Tensor, Tensor1D, Tensor2D} from '../tensor';
 import {assert} from '../util';
+
 import {eye, split, squeeze, stack, unstack} from './array_ops';
 import {norm} from './norm';
 import {op} from './operation';
@@ -227,5 +229,86 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
   }) as [Tensor2D, Tensor2D];
 }
 
+/**
+ * Computes a singular value decomposition of a matrix.
+ *
+ * `u` and the singular values come from the eigen decomposition of
+ * `m * transpose(m)`; `v` is then obtained by solving `(u * s) * v = m`.
+ *
+ * @param m matrix whose svd is to be computed
+ * @returns
+ * `u`: orthogonal matrix
+ *
+ * `s`: diagonal matrix holding the singular values
+ *
+ * `v`: matrix whose rows are (approximately) orthonormal
+ *
+ * such that `m = u*s*v`
+ */
+function svd_(m: Tensor2D): {u: Tensor, s: Tensor, v: Tensor} {
+  const gram = m.dot(transpose(m));
+  // The eigen decomposition is iterative and costly: run it once and reuse
+  // both the values and the vectors.
+  const {values, vectors: u} = eingen(gram);
+  // Multiplying the identity by the singular values (broadcast over its rows)
+  // places them on the diagonal. The size comes from the number of
+  // eigenvalues rather than from `m.shape`, so `m` need not be square.
+  const s = eye(values.shape[0]).mul(sqrt(values));
+  const v = solve(u.dot(s) as Tensor2D, m) as Tensor2D;
+  return {u, s, v};
+}
+
+/**
+ * Approximates the eigenvalues and eigenvectors of a matrix with the QR
+ * algorithm.
+ *
+ * Implementation based on:
+ * [http://www.math.tamu.edu/~dallen/linear_algebra/chpt6.pdf]
+ * (http://www.math.tamu.edu/~dallen/linear_algebra/chpt6.pdf):
+ *
+ *  - `A0 = Q0 * R0`
+ *  - `A1 = R0 * Q0 = Q1 * R1`
+ *  - `A2 = R1 * Q1 = Q2 * R2`
+ *  - ...
+ *  - `An = R(n-1) * Q(n-1) = Qn * Rn`
+ *
+ * For a symmetric matrix `An` tends to a diagonal matrix where the diagonal
+ * values are the eigenvalues of A.
+ *
+ * Π(Q0, Q1, ..., Q(n-1)) gives the eigenvectors associated to the
+ * eigenvalues.
+ *
+ * @param m matrix whose eigenvalues and eigenvectors are to be computed
+ * @param iterations number of QR iterations to run (at least 1). More
+ *     iterations bring `An` closer to its limit.
+ * @returns
+ * {values: eigen values as Tensor1D, vectors: eigen vectors as Tensor}
+ */
+function eingen_(
+    m: Tensor, iterations = 5): {values: Tensor1D, vectors: Tensor} {
+  assert(iterations >= 1, 'At least one QR iteration is required');
+  let a = m;
+  let vectors: Tensor|undefined;
+  for (let i = 0; i < iterations; i++) {
+    const [q, r] = linalg.qr(a);
+    a = r.dot(q);
+    // Accumulate the product of the orthogonal factors.
+    vectors = vectors === undefined ? q : vectors.dot(q);
+    r.dispose();
+  }
+  return {values: diagonalElements(a), vectors: vectors as Tensor};
+}
+
+/** Extracts the main diagonal of a matrix as a 1D tensor. */
+function diagonalElements(m: Tensor): Tensor1D {
+  const diagonal: Tensor1D[] = [];
+  for (let i = 0; i < m.shape[0]; i++) {
+    diagonal.push(m.slice([i, i], [1, 1]).as1D());
+  }
+  return concat(diagonal);
+}
+
 export const gramSchmidt = op({gramSchmidt_});
 export const qr = op({qr_});
+export const eingen = op({eingen_});
+export const svd = op({svd_});
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index 7480d7fcee..a59a07343d 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -20,6 +20,7 @@ import {describeWithFlags} from '../jasmine_util';
 import {Tensor1D, Tensor2D} from '../tensor';
 import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util';
 
+import {svd} from './linalg_ops';
 import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops';
 
 describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => {
@@ -241,3 +242,32 @@ describeWithFlags('qr', ALL_ENVS, () => {
     expect(() => tf.linalg.qr(x2)).toThrowError(/rank >= 2.*got rank 1/);
   });
 });
+
+describeWithFlags('svd', ALL_ENVS, () => {
+  it('m = u*s*v ', () => {
+    const m = tf.tensor2d([2, -1, 0, -1, 2, -1, 0, -1, 2], [3, 3]);
+    const {u, s, v} = svd(m);
+    expectArraysClose(u.dot(s).dot(v), m);
+  });
+  it('m = u*s*v for a non-square matrix', () => {
+    const m = tf.tensor2d([1, 2, 0, 0, 3, 1], [2, 3]);
+    const {u, s, v} = svd(m);
+    expectArraysClose(u.dot(s).dot(v), m);
+  });
+  it('orthogonal u', () => {
+    const m = tf.tensor2d([1, 2, 0, 0, 3, 0, 2, -4, 2], [3, 3]);
+    const {u} = svd(m);
+    const [a, b, c] = tf.unstack(u);
+    expectArraysClose(a.dot(b), tf.scalar(0));
+    expectArraysClose(a.dot(c), tf.scalar(0));
+    expectArraysClose(b.dot(c), tf.scalar(0));
+  });
+  it('orthogonal v', () => {
+    const m = tf.tensor2d([1, 2, 0, 0, 3, 0, 2, -4, 2], [3, 3]);
+    const {v} = svd(m);
+    const [a, b, c] = tf.unstack(v);
+    expectArraysClose(a.dot(b), tf.scalar(0));
+    expectArraysClose(a.dot(c), tf.scalar(0));
+    expectArraysClose(b.dot(c), tf.scalar(0));
+  });
+});
diff --git a/src/ops/linalg_solve.ts b/src/ops/linalg_solve.ts
new file mode 100644
index 0000000000..3a8e413193
--- /dev/null
+++ b/src/ops/linalg_solve.ts
@@ -0,0 +1,204 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * Linear algebra resolution.
+ */
+
+import {scalar, split, transpose} from '..';
+import {ENV} from '../environment';
+import {Scalar, Tensor, Tensor2D} from '../tensor';
+import {assert} from '../util';
+
+import {eye, stack, unstack} from './array_ops';
+import {op} from './operation';
+
+/**
+ * Reduces the augmented matrix `[a, b]` to an upper triangular form with
+ * [Gaussian elimination](https://en.wikipedia.org/wiki/Gaussian_elimination)
+ * using partial pivoting.
+ *
+ * @param a is a square matrix M (Tensor2d with shape `[r, c]` such that `r ===
+ * c`)
+ * @param b is a Tensor2d with shape `[r2, c2]`
+ * @desc `r === r2`
+ * @returns `upperM`, a matrix of shape `[r, c+c2]` whose first `c` columns
+ * form an upper triangular matrix (up to rounding errors below the diagonal),
+ * and `det`, the determinant of `a`.
+ */
+function gaussJordanTriangular(
+    a: Tensor2D, b: Tensor2D): {upperM: Tensor2D, det: Scalar} {
+  const [r, c] = a.shape;
+  const [r2, c2] = b.shape;
+  assert(r === c, 'Input is not a square matrix');
+  assert(r === r2, 'First dimension size does not match');
+  return ENV.engine.tidy(() => {
+    const rows: Tensor[] = unstack(a.concat(b, 1));
+    // Every row swap flips the sign of the determinant.
+    let sign = 1;
+    for (let i = 0; i < r; i++) {
+      // Partial pivoting: use the entry of column i with the largest magnitude
+      // (on or below the diagonal) as pivot so that a zero on the diagonal
+      // does not break the elimination.
+      let pivotRow = i;
+      let pivotValue = rows[i].slice(i, 1).dataSync()[0];
+      for (let k = i + 1; k < r; k++) {
+        const candidate = rows[k].slice(i, 1).dataSync()[0];
+        if (Math.abs(candidate) > Math.abs(pivotValue)) {
+          pivotRow = k;
+          pivotValue = candidate;
+        }
+      }
+      if (pivotValue === 0) {
+        // Nothing to eliminate in this column: `a` is singular and the zero
+        // left on the diagonal makes the determinant 0.
+        continue;
+      }
+      if (pivotRow !== i) {
+        [rows[i], rows[pivotRow]] = [rows[pivotRow], rows[i]];
+        sign = -sign;
+      }
+      const pivot = rows[i].slice(i, 1).asScalar();
+      for (let j = i + 1; j < r; j++) {
+        const elt = rows[j].slice(i, 1).asScalar();
+        if (elt.dataSync()[0] !== 0) {
+          rows[j] = rows[j].sub(rows[i].mul(elt.div(pivot)));
+        }
+      }
+    }
+    const upperM = stack(rows) as Tensor2D;
+    // Row additions leave the determinant unchanged, so it is the product of
+    // the diagonal corrected by the sign of the row permutation.
+    const det = diagonalMul(split(upperM, [c, c2], 1)[0] as Tensor2D)
+                    .mul(scalar(sign))
+                    .asScalar();
+    return {upperM, det};
+  });
+}
+
+/**
+ * Computes the product of the diagonal elements of a square matrix.
+ *
+ * @param m Tensor2d or matrix
+ * @returns the product of the diagonal elements of `m` as a `tf.scalar`
+ */
+function diagonalMul(m: Tensor2D): Scalar {
+  const [r, c] = m.shape;
+  assert(r === c, 'Input is not a square matrix');
+  let mul = m.slice([0, 0], [1, 1]).as1D().asScalar();
+  // Start at 1: the first diagonal element already seeds the product.
+  for (let i = 1; i < r; i++) {
+    mul = mul.mul(m.slice([i, i], [1, 1]).as1D().asScalar());
+  }
+  return mul;
+}
+
+/**
+ * Solves one or several systems of linear equations.
+ *
+ * @param a is a unique square matrix M or an array of square matrices
+ * @param b is a unique matrix or an array of matrices; each `b` must have as
+ * many rows as the matching `a`
+ * @param adjoint is a boolean.
+ * If adjoint is false then each output matrix satisfies `a * output = b`
+ * (respectively `a[i] * output[i] = b[i]` if the inputs are arrays of
+ * matrices). If adjoint is true then each output matrix satisfies
+ * `adjoint(a) * output = b` (respectively `adjoint(a[i]) * output[i] = b[i]`).
+ * The adjoint is computed as the transpose, i.e. for real-valued matrices.
+ */
+function solve_(
+    a: Tensor2D[]|Tensor2D, b: Tensor2D[]|Tensor2D,
+    adjoint = false): Tensor2D[]|Tensor2D {
+  if (Array.isArray(a) || Array.isArray(b)) {
+    assert(
+        Array.isArray(a) && Array.isArray(b),
+        'a and b must both be matrices or both be arrays of matrices');
+    const matrices = a as Tensor2D[];
+    const rhs = b as Tensor2D[];
+    assert(
+        matrices.length === rhs.length,
+        'a and b must contain the same number of matrices');
+    return matrices.map((m, i) => solveUniqueEquation(m, rhs[i], adjoint));
+  }
+  return solveUniqueEquation(a, b, adjoint);
+}
+
+/**
+ * Solves `a * x = b` (or `transpose(a) * x = b` when `adjoint` is true) for a
+ * single pair of matrices.
+ */
+function solveUniqueEquation(
+    a: Tensor2D, b: Tensor2D, adjoint = false): Tensor2D {
+  return ENV.engine.tidy(() => {
+    const [r, c] = a.shape;
+    const [r2, c2] = b.shape;
+    assert(r === r2, 'First dimension size does not match');
+    // For real-valued matrices the adjoint is the transpose.
+    const lhs = (adjoint ? transpose(a) : a) as Tensor2D;
+    const {upperM, det} = gaussJordanTriangular(lhs, b);
+    assert(det.dataSync()[0] !== 0, 'Input matrix is not inversible');
+    // Back substitution, from the last row up: clear the entries above the
+    // diagonal with the rows already reduced, then scale the diagonal to 1.
+    const rows = unstack(upperM);
+    for (let i = r - 1; i > -1; i--) {
+      for (let j = r - 1; j > i; j--) {
+        rows[i] = rows[i].sub(rows[j].mul(rows[i].slice(j, 1).asScalar()));
+      }
+      rows[i] = rows[i].div(rows[i].slice(i, 1).asScalar());
+    }
+    // The right-hand block of the reduced matrix now holds the solution.
+    return split(stack(rows), [c, c2], 1)[1] as Tensor2D;
+  });
+}
+
+/**
+ * Inverts a square matrix.
+ *
+ * @param m square matrix to invert
+ * @returns the inverse of `m`; throws if `m` is not invertible
+ */
+function invertMatrix_(m: Tensor2D): Tensor2D {
+  const [r, c] = m.shape;
+  assert(r === c, 'Input is not a square matrix');
+  return solve(m, eye(r)) as Tensor2D;
+}
+
+/**
+ * Computes the determinant of a square matrix.
+ *
+ * @param m Tensor2d or matrix
+ * @returns the determinant of `m` as a `tf.scalar`
+ */
+function det_(m: Tensor2D): Scalar {
+  return gaussJordanTriangular(m, eye(m.shape[0]) as Tensor2D).det;
+}
+
+/**
+ * Computes the adjugate (classical adjoint) of a matrix as
+ * `inverse(m) * det(m)`.
+ *
+ * @param m Tensor2d or matrix
+ * @returns the adjugate of `m`; throws if `m` is not invertible
+ */
+function adjointM_(m: Tensor2D): Tensor2D {
+  return invertMatrix(m).mul(det(m));
+}
+
+export const solve = op({solve_});
+export const invertMatrix = op({invertMatrix_});
+export const adjM = op({adjointM_});
+export const det = op({det_});
diff --git a/src/ops/linalg_solve_test.ts b/src/ops/linalg_solve_test.ts
new file mode 100644
index 0000000000..08875ce583
--- /dev/null
+++ b/src/ops/linalg_solve_test.ts
@@ -0,0 +1,94 @@
+/**
+ * @license
+ * Copyright 2017 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '../index';
+import {concat} from '../index';
+import {describeWithFlags} from '../jasmine_util';
+import {ALL_ENVS, expectArraysClose} from '../test_util';
+
+import {adjM, det, invertMatrix} from './linalg_solve';
+
+describeWithFlags('solve_linear', ALL_ENVS, () => {
+  it('solve equation with a: the matrix eye', () => {
+    const a = tf.eye(3);
+    const b = tf.tensor2d([1, 2, 3], [3, 1]);
+    const x = tf.solve(a, b) as tf.Tensor2D;
+    expectArraysClose(x, [1, 2, 3]);
+  });
+
+  it('solve equation with a: an array of matrix eye', () => {
+    const a = [tf.eye(3), tf.eye(3), tf.eye(3)];
+    const b = Array.from({length: 3}, () => tf.tensor2d([1, 2, 3], [3, 1]));
+    const x = concat(tf.solve(a, b) as tf.Tensor2D[]);
+    expectArraysClose(x, [1, 2, 3, 1, 2, 3, 1, 2, 3]);
+  });
+
+  it('solve equation with a the matrix eye times 2', () => {
+    const a = tf.eye(3).mul(tf.scalar(2)) as tf.Tensor2D;
+    const b = tf.tensor2d([1, 2, 3], [3, 1]);
+    const x = tf.solve(a, b) as tf.Tensor2D;
+    expectArraysClose(x, [0.5, 1, 1.5]);
+  });
+
+  it('solve equation that needs a row swap', () => {
+    const a = tf.tensor2d([0, 1, 1, 0], [2, 2]);
+    const b = tf.tensor2d([2, 3], [2, 1]);
+    const x = tf.solve(a, b) as tf.Tensor2D;
+    expectArraysClose(x, [3, 2]);
+  });
+
+  it('solve equation with the adjoint of a', () => {
+    const a = tf.tensor2d([1, 2, 0, 1], [2, 2]);
+    const b = tf.tensor2d([1, 1], [2, 1]);
+    const x = tf.solve(a, b, true) as tf.Tensor2D;
+    expectArraysClose(x, [1, -1]);
+  });
+
+  it('should throw error if a is not inversible', () => {
+    const a = tf.ones([3, 3]) as tf.Tensor2D;
+    const b = tf.tensor2d([1, 2, 3], [3, 1]);
+    expect(() => tf.solve(a, b)).toThrowError('Input matrix is not inversible');
+  });
+});
+
+describeWithFlags('invert_matrix', ALL_ENVS, () => {
+  it('invert a matrix', () => {
+    const m = tf.tensor2d([1, 3, 3, 1, 4, 3, 1, 3, 4], [3, 3]);
+    const inv = invertMatrix(m);
+    expectArraysClose(inv, [7, -3, -3, -1, 1, 0, -1, 0, 1]);
+  });
+});
+
+describeWithFlags('determinant', ALL_ENVS, () => {
+  it('multiplies all the diagonal elements', () => {
+    const m = tf.eye(3).mul(tf.scalar(2)) as tf.Tensor2D;
+    expectArraysClose(det(m), [8]);
+  });
+
+  it('changes sign on a row swap', () => {
+    const m = tf.tensor2d([0, 1, 1, 0], [2, 2]);
+    expectArraysClose(det(m), [-1]);
+  });
+});
+
+describeWithFlags('adjoint_matrix', ALL_ENVS, () => {
+  it('computes the adjoint of a matrix', () => {
+    const m = tf.tensor2d([1, 3, 3, 1, 4, 3, 1, 3, 4], [3, 3]);
+    const a = adjM(m);
+    expectArraysClose(a, [7, -3, -3, -1, 1, 0, -1, 0, 1]);
+  });
+});
diff --git a/src/ops/ops.ts b/src/ops/ops.ts
index 2c9be6f47f..3f88d4f812 100644
--- a/src/ops/ops.ts
+++ b/src/ops/ops.ts
@@ -39,6 +39,7 @@
 export * from './lstm';
 export * from './moving_average';
 export * from './strided_slice';
 export * from './topk';
+export * from './linalg_solve';
 
 export {op} from './operation';