Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 2ae5a8a

Browse files
jarno-rdsmilkov
authored andcommitted
Fixed division by zero in QR decomposition. Issue #1058 (#1473)
tensorflow/tfjs#1058 The sign() function returns 0 on 0, which causes a division by zero in the QR decomposition function qr() if there is a zero on the diagonal. BUG
1 parent b484b28 commit 2ae5a8a

File tree

5 files changed

+49
-30
lines changed

5 files changed

+49
-30
lines changed

src/io/passthrough_test.ts

+4-5
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ describeWithFlags('Passthrough Saver', BROWSER_ENVS, () => {
115115

116116
describeWithFlags('Passthrough Loader', BROWSER_ENVS, () => {
117117
it('load topology and weights: legacy signature', async () => {
118-
const passthroughHandler = tf.io.fromMemory(
119-
modelTopology1, weightSpecs1, weightData1);
118+
const passthroughHandler =
119+
tf.io.fromMemory(modelTopology1, weightSpecs1, weightData1);
120120
const modelArtifacts = await passthroughHandler.load();
121121
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
122122
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
@@ -147,9 +147,8 @@ describeWithFlags('Passthrough Loader', BROWSER_ENVS, () => {
147147
});
148148

149149
it('load model topology only', async () => {
150-
const passthroughHandler = tf.io.fromMemory({
151-
modelTopology: modelTopology1
152-
});
150+
const passthroughHandler =
151+
tf.io.fromMemory({modelTopology: modelTopology1});
153152
const modelArtifacts = await passthroughHandler.load();
154153
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
155154
expect(modelArtifacts.weightSpecs).toEqual(undefined);

src/jasmine_util.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ export const SYNC_BACKEND_ENVS: Constraints = {
4040
};
4141

4242
export const HAS_WORKER = {
43-
predicate: () => typeof(Worker) !== 'undefined'
44-
&& typeof(Blob) !== 'undefined' && typeof(URL) !== 'undefined'
43+
predicate: () => typeof (Worker) !== 'undefined' &&
44+
typeof (Blob) !== 'undefined' && typeof (URL) !== 'undefined'
4545
};
4646

4747
export const HAS_NODE_WORKER = {
@@ -52,7 +52,7 @@ export const HAS_NODE_WORKER = {
5252
} catch {
5353
hasWorker = false;
5454
}
55-
return typeof(process) !== 'undefined' && hasWorker;
55+
return typeof (process) !== 'undefined' && hasWorker;
5656
}
5757
};
5858

src/ops/concat_test.ts

+27-21
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ describeWithFlags('concat1d', ALL_ENVS, () => {
8888
expectArraysClose(await result.data(), expected);
8989
});
9090

91-
it('concat complex input', async() => {
91+
it('concat complex input', async () => {
9292
// [1+1j, 2+2j]
9393
const c1 = tf.complex([1, 2], [1, 2]);
9494
// [3+3j, 4+4j]
@@ -234,7 +234,7 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
234234
expectArraysEqual(await res2.data(), []);
235235
});
236236

237-
it('concat complex input axis=0', async() => {
237+
it('concat complex input axis=0', async () => {
238238
// [[1+1j, 2+2j], [3+3j, 4+4j]]
239239
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
240240
// [[5+5j, 6+6j], [7+7j, 8+8j]]
@@ -247,7 +247,7 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
247247
expectArraysClose(await result.data(), expected);
248248
});
249249

250-
it('concat complex input axis=1', async() => {
250+
it('concat complex input axis=1', async () => {
251251
// [[1+1j, 2+2j], [3+3j, 4+4j]]
252252
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
253253
// [[5+5j, 6+6j], [7+7j, 8+8j]]
@@ -500,50 +500,56 @@ describeWithFlags('concat3d', ALL_ENVS, () => {
500500
expectArraysClose(await values.data(), [1, 2, 3, 4, 5, 6]);
501501
});
502502

503-
it('concat complex input axis=0', async() => {
503+
it('concat complex input axis=0', async () => {
504504
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
505-
const c1 = tf.complex(
506-
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
505+
const c1 =
506+
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
507507
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
508508
const c2 = tf.complex(
509-
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
509+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
510510

511511
const axis = 0;
512512
const result = tf.concat([c1, c2], axis);
513-
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
514-
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
513+
const expected = [
514+
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
515+
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12
516+
];
515517
expect(result.dtype).toEqual('complex64');
516518
expectArraysClose(await result.data(), expected);
517519
});
518520

519-
it('concat complex input axis=1', async() => {
521+
it('concat complex input axis=1', async () => {
520522
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
521-
const c1 = tf.complex(
522-
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
523+
const c1 =
524+
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
523525
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
524526
const c2 = tf.complex(
525-
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
527+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
526528

527529
const axis = 1;
528530
const result = tf.concat([c1, c2], axis);
529-
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
530-
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
531+
const expected = [
532+
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
533+
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12
534+
];
531535
expect(result.dtype).toEqual('complex64');
532536
expectArraysClose(await result.data(), expected);
533537
});
534538

535-
it('concat complex input axis=1', async() => {
539+
it('concat complex input axis=1', async () => {
536540
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
537-
const c1 = tf.complex(
538-
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
541+
const c1 =
542+
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
539543
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
540544
const c2 = tf.complex(
541-
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
545+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
542546

543547
const axis = 2;
544548
const result = tf.concat([c1, c2], axis);
545-
const expected = [1, 1, 2, 2, 7, 7, 8, 8, 3, 3, 4, 4,
546-
9, 9, 10, 10, 5, 5, 6, 6, 11, 11, 12, 12];
549+
const expected = [
550+
1, 1, 2, 2, 7, 7, 8, 8, 3, 3, 4, 4,
551+
9, 9, 10, 10, 5, 5, 6, 6, 11, 11, 12, 12
552+
];
547553
expect(result.dtype).toEqual('complex64');
548554
expectArraysClose(await result.data(), expected);
549555
});

src/ops/linalg_ops.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,10 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
215215
const rjEnd1 = r.slice([j, j], [m - j, 1]);
216216
const normX = rjEnd1.norm();
217217
const rjj = r.slice([j, j], [1, 1]);
218-
const s = rjj.sign().neg() as Tensor2D;
218+
219+
// The sign() function returns 0 on 0, which causes division by zero.
220+
const s = tensor2d([[-1]]).where(rjj.greater(0), tensor2d([[1]]));
221+
219222
const u1 = rjj.sub(s.mul(normX)) as Tensor2D;
220223
const wPre = rjEnd1.div(u1);
221224
if (wPre.shape[0] === 1) {

src/ops/linalg_ops_test.ts

+11
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,17 @@ describeWithFlags('qr', ALL_ENVS, () => {
140140
[[-8.3066, 8.3066, -2.4077], [0, 4.5826, -2.1822], [0, 0, 7.6447]]);
141141
});
142142

143+
it('3x3, zero on diagonal', async () => {
144+
const x = tensor2d([[0, 2, 2], [1, 1, 1], [0, 1, 2]], [3, 3]);
145+
const [q, r] = tf.linalg.qr(x);
146+
expectArraysClose(await q.data(), [
147+
[0., -0.89442719, 0.4472136], [1., 0., 0.], [0., -0.4472136, -0.89442719]
148+
]);
149+
expectArraysClose(
150+
await r.data(),
151+
[[1., 1., 1.], [0., -2.23606798, -2.68328157], [0., 0., -0.89442719]]);
152+
});
153+
143154
it('3x2, fullMatrices = default false', async () => {
144155
const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]);
145156
const [q, r] = tf.linalg.qr(x);

0 commit comments

Comments
 (0)