Skip to content

Commit 21ee908

Browse files
authoredFeb 9, 2019
[date-conversion-attention] Initial check-in of date-conversion-attention (#212)
- This PR checks in only the training scripts. Model inference in the browser will be checked in in a later PR. - Unit tests are written for the data, model training and inference routines, although they are not hooked up with Travis right now. They are just run manually instead.
1 parent 651fd08 commit 21ee908

14 files changed

+4587
-0
lines changed
 

‎date-conversion-attention/.babelrc

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"presets": [
3+
[
4+
"env",
5+
{
6+
"esmodules": false,
7+
"targets": {
8+
"browsers": [
9+
"> 3%"
10+
]
11+
}
12+
}
13+
]
14+
],
15+
"plugins": [
16+
"transform-runtime"
17+
]
18+
}

‎date-conversion-attention/README.md

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# TensorFlow.js Example: Date Conversion Through an LSTM-Attention Model
2+
3+
## Overview
4+
5+
This example shows how to use TensorFlow.js to train a model based on
6+
long short-term memory (LSTM) and the attention mechanism to achieve
7+
a task of converting various commonly seen date formats (e.g., 01/18/2019,
8+
18JAN2019, 18-01-2019) to the ISO date format (i.e., 2019-01-18).
9+
10+
We demonstrate the full machine-learning workflow, consisting of
11+
data engineering, server-side model training, client-side inference,
12+
model visualization, and unit testing in this example.
13+
14+
The training data is synthesized programmatically.
15+
16+
## Model training in Node.js
17+
18+
For efficiency, the training of the model happens outside the browser
19+
in Node.js, using tfjs-node or tfjs-node-gpu.
20+
21+
To run the training job, do
22+
23+
```sh
24+
yarn
25+
yarn train
26+
```
27+
28+
By default, the training uses tfjs-node, which runs on the CPU.
29+
If you have a CUDA-enabled GPU and have the CUDA and CuDNN libraries
30+
set up properly on your system, you can run the training on the GPU
31+
by:
32+
33+
```sh
34+
yarn
35+
yarn train --gpu
36+
```
37+
38+
## Using the model in the browser
39+
40+
TODO(cais): Implement it.
41+
42+
### Visualization of the attention mechanism
43+
44+
TODO(cais): Implement it.
45+
46+
## Running unit tests
47+
48+
The data and model code in this example are covered by unit tests.
49+
To run the unit tests:
50+
51+
```sh
52+
cd ../
53+
yarn
54+
cd date-conversion-attention
55+
yarn
56+
yarn test
57+
```
+236
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/**
19+
* Date formats and conversion utility functions.
20+
*
21+
* This file is used for the training of the date-conversion model and
22+
* date conversions based on the trained model.
23+
*
24+
* It contains functions that generate random dates and represent them in
25+
* several different formats such as (2019-01-20 and 20JAN19).
26+
* It also contains functions that convert the text representation of
27+
* the dates into one-hot `tf.Tensor` representations.
28+
*/
29+
30+
const tf = require('@tensorflow/tfjs');
31+
32+
const MONTH_NAMES_FULL = [
33+
'January', 'February', 'March', 'April', 'May', 'June', 'July', 'August',
34+
'September', 'October', 'November', 'December'
35+
];
36+
const MONTH_NAMES_3LETTER =
37+
MONTH_NAMES_FULL.map(name => name.slice(0, 3).toUpperCase());
38+
39+
const MIN_DATE = new Date('1950-01-01').getTime();
40+
const MAX_DATE = new Date('2050-01-01').getTime();
41+
42+
export const INPUT_LENGTH = 12 // Maximum length of all input formats.
43+
export const OUTPUT_LENGTH = 10 // Length of 'YYYY-MM-DD'.
44+
45+
// Use "\n" for padding for both input and output. It has to be at the
46+
// beginning so that `mask_zero=True` can be used in the keras model.
47+
export const INPUT_VOCAB = '\n0123456789/-., ' +
48+
MONTH_NAMES_3LETTER.join('')
49+
.split('')
50+
.filter(function(item, i, ar) {
51+
return ar.indexOf(item) === i;
52+
})
53+
.join('');
54+
55+
// OUTPUT_VOCAB includes an start-of-sequence (SOS) token, represented as
56+
// '\t'. Note that the date strings are represented in terms of their
57+
// constituent characters, not words or anything else.
58+
export const OUTPUT_VOCAB = '\n\t0123456789-';
59+
60+
export const START_CODE = 1;
61+
62+
/**
63+
* Generate a random date.
64+
*
65+
* @return {[number, number, number]} Year as an integer, month as an
66+
* integer >= 1 and <= 12, day as an integer >= 1.
67+
*/
68+
export function generateRandomDateTuple() {
69+
const date = new Date(Math.random() * (MAX_DATE - MIN_DATE) + MIN_DATE);
70+
return [date.getFullYear(), date.getMonth() + 1, date.getDate()];
71+
}
72+
73+
function toTwoDigitString(num) {
74+
return num < 10 ? `0${num}` : `${num}`;
75+
}
76+
77+
/** Date format such as 01202019. */
78+
export function dateTupleToDDMMMYYYY(dateTuple) {
79+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
80+
const dayStr = toTwoDigitString(dateTuple[2]);
81+
return `${dayStr}${monthStr}${dateTuple[0]}`;
82+
}
83+
84+
/** Date format such as 01/20/2019. */
85+
export function dateTupleToMMSlashDDSlashYYYY(dateTuple) {
86+
const monthStr = toTwoDigitString(dateTuple[1]);
87+
const dayStr = toTwoDigitString(dateTuple[2]);
88+
return `${monthStr}/${dayStr}/${dateTuple[0]}`;
89+
}
90+
91+
/** Date format such as 01/20/19. */
92+
export function dateTupleToMMSlashDDSlashYY(dateTuple) {
93+
const monthStr = toTwoDigitString(dateTuple[1]);
94+
const dayStr = toTwoDigitString(dateTuple[2]);
95+
const yearStr = `${dateTuple[0]}`.slice(2);
96+
return `${monthStr}/${dayStr}/${yearStr}`;
97+
}
98+
99+
/** Date format such as 012019. */
100+
export function dateTupleToMMDDYY(dateTuple) {
101+
const monthStr = toTwoDigitString(dateTuple[1]);
102+
const dayStr = toTwoDigitString(dateTuple[2]);
103+
const yearStr = `${dateTuple[0]}`.slice(2);
104+
return `${monthStr}${dayStr}${yearStr}`;
105+
}
106+
107+
/** Date format such as JAN 20 19. */
108+
export function dateTupleToMMMSpaceDDSpaceYY(dateTuple) {
109+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
110+
const dayStr = toTwoDigitString(dateTuple[2]);
111+
const yearStr = `${dateTuple[0]}`.slice(2);
112+
return `${monthStr} ${dayStr} ${yearStr}`;
113+
}
114+
115+
/** Date format such as JAN 20 2019. */
116+
export function dateTupleToMMMSpaceDDSpaceYYYY(dateTuple) {
117+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
118+
const dayStr = toTwoDigitString(dateTuple[2]);
119+
return `${monthStr} ${dayStr} ${dateTuple[0]}`;
120+
}
121+
122+
/** Date format such as JAN 20, 19. */
123+
export function dateTupleToMMMSpaceDDCommaSpaceYY(dateTuple) {
124+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
125+
const dayStr = toTwoDigitString(dateTuple[2]);
126+
const yearStr = `${dateTuple[0]}`.slice(2);
127+
return `${monthStr} ${dayStr}, ${yearStr}`;
128+
}
129+
130+
/** Date format such as JAN 20, 2019. */
131+
export function dateTupleToMMMSpaceDDCommaSpaceYYYY(dateTuple) {
132+
const monthStr = MONTH_NAMES_3LETTER[dateTuple[1] - 1];
133+
const dayStr = toTwoDigitString(dateTuple[2]);
134+
return `${monthStr} ${dayStr}, ${dateTuple[0]}`;
135+
}
136+
137+
/** Date format such as 20-01-2019. */
138+
export function dateTupleToDDDashMMDashYYYY(dateTuple) {
139+
const monthStr = toTwoDigitString(dateTuple[1]);
140+
const dayStr = toTwoDigitString(dateTuple[2]);
141+
return `${dayStr}-${monthStr}-${dateTuple[0]}`;
142+
}
143+
144+
/** Date format such as 20.01.2019. */
145+
export function dateTupleToDDDotMMDotYYYY(dateTuple) {
146+
const monthStr = toTwoDigitString(dateTuple[1]);
147+
const dayStr = toTwoDigitString(dateTuple[2]);
148+
return `${dayStr}.${monthStr}.${dateTuple[0]}`;
149+
}
150+
151+
/** Date format such as 2019.01.20. */
152+
export function dateTupleToYYYYDotMMDotDD(dateTuple) {
153+
const monthStr = toTwoDigitString(dateTuple[1]);
154+
const dayStr = toTwoDigitString(dateTuple[2]);
155+
return `${dateTuple[0]}.${monthStr}.${dayStr}`;
156+
}
157+
158+
159+
/** Date format such as 20190120 */
160+
export function dateTupleToYYYYMMDD(dateTuple) {
161+
const monthStr = toTwoDigitString(dateTuple[1]);
162+
const dayStr = toTwoDigitString(dateTuple[2]);
163+
return `${dateTuple[0]}${monthStr}${dayStr}`;
164+
}
165+
166+
/**
167+
* Date format such as 2019-01-20
168+
* (i.e., the ISO format and the conversion target).
169+
* */
170+
export function dateTupleToYYYYDashMMDashDD(dateTuple) {
171+
const monthStr = toTwoDigitString(dateTuple[1]);
172+
const dayStr = toTwoDigitString(dateTuple[2]);
173+
return `${dateTuple[0]}-${monthStr}-${dayStr}`;
174+
}
175+
176+
/**
177+
* Encode a number of input date strings as a `tf.Tensor`.
178+
*
179+
* The encoding is a sequence of one-hot vectors. The sequence is
180+
* padded at the end to the maximum possible length of any valid
181+
* input date strings. The padding value is zero.
182+
*
183+
* @param {string[]} dateStrings Input date strings. Each element of the array
184+
* must be one of the formats listed above. It is okay to mix multiple formats
185+
* in the array.
186+
* @returns {tf.Tensor} One-hot encoded characters as a `tf.Tensor`, of dtype
187+
* `float32` and shape `[numExamples, maxInputLength]`, where `maxInputLength`
188+
* is the maximum possible input length of all valid input date-string formats.
189+
*/
190+
export function encodeInputDateStrings(dateStrings) {
191+
const n = dateStrings.length;
192+
const x = tf.buffer([n, INPUT_LENGTH], 'float32');
193+
for (let i = 0; i < n; ++i) {
194+
for (let j = 0; j < INPUT_LENGTH; ++j) {
195+
if (j < dateStrings[i].length) {
196+
const char = dateStrings[i][j];
197+
const index = INPUT_VOCAB.indexOf(char);
198+
if (index === -1) {
199+
throw new Error(`Unknown char: ${char}`);
200+
}
201+
x.set(index, i, j);
202+
}
203+
}
204+
}
205+
return x.toTensor();
206+
}
207+
208+
/**
209+
* Encode a number of output date strings as a `tf.Tensor`.
210+
*
211+
* The encoding is a sequence of integer indices.
212+
*
213+
* @param {string[]} dateStrings An array of output date strings, must be in the
214+
* ISO date format (YYYY-MM-DD).
215+
* @returns {tf.Tensor} Integer indices of the characters as a `tf.Tensor`, of
216+
* dtype `int32` and shape `[numExamples, outputLength]`, where `outputLength`
217+
* is the length of the standard output format (i.e., `10`).
218+
*/
219+
export function encodeOutputDateStrings(dateStrings, oneHot = false) {
220+
const n = dateStrings.length;
221+
const x = tf.buffer([n, OUTPUT_LENGTH], 'int32');
222+
for (let i = 0; i < n; ++i) {
223+
tf.util.assert(
224+
dateStrings[i].length === OUTPUT_LENGTH,
225+
`Date string is not in ISO format: "${dateStrings[i]}"`);
226+
for (let j = 0; j < OUTPUT_LENGTH; ++j) {
227+
const char = dateStrings[i][j];
228+
const index = OUTPUT_VOCAB.indexOf(char);
229+
if (index === -1) {
230+
throw new Error(`Unknown char: ${char}`);
231+
}
232+
x.set(index, i, j);
233+
}
234+
}
235+
return x.toTensor();
236+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
import * as dateFormat from './date_format';
20+
21+
describe('Date formats', () => {
22+
it('generateRandomDateTuple', () => {
23+
for (let i = 0; i < 100; ++i) {
24+
const [year, month, day] = dateFormat.generateRandomDateTuple();
25+
expect(Number.isInteger(year)).toEqual(true);
26+
expect(year).toBeGreaterThanOrEqual(1950);
27+
expect(year).toBeLessThan(2050);
28+
expect(Number.isInteger(month)).toEqual(true);
29+
expect(month).toBeGreaterThanOrEqual(1);
30+
expect(month).toBeLessThan(13);
31+
expect(Number.isInteger(day)).toEqual(true);
32+
expect(day).toBeGreaterThanOrEqual(1);
33+
expect(day).toBeLessThan(32);
34+
}
35+
});
36+
37+
it('DDMMMYYYY', () => {
38+
for (let i = 0; i < 10; ++i) {
39+
const str = dateFormat.dateTupleToDDMMMYYYY(
40+
dateFormat.generateRandomDateTuple());
41+
expect(str).toMatch(/^[0-3]\d[A-Z][A-Z][A-Z][1-2]\d\d\d$/);
42+
}
43+
});
44+
45+
it('MM/DD/YYYY', () => {
46+
for (let i = 0; i < 10; ++i) {
47+
const str = dateFormat.dateTupleToMMSlashDDSlashYYYY(
48+
dateFormat.generateRandomDateTuple());
49+
expect(str).toMatch(/^[0-1]\d\/[0-3]\d\/[1-2]\d\d\d$/);
50+
}
51+
});
52+
53+
it('MM/DD/YY', () => {
54+
for (let i = 0; i < 10; ++i) {
55+
const str = dateFormat.dateTupleToMMSlashDDSlashYY(
56+
dateFormat.generateRandomDateTuple());
57+
expect(str).toMatch(/^[0-1]\d\/[0-3]\d\/\d\d$/);
58+
}
59+
});
60+
61+
it('MMDDYY', () => {
62+
for (let i = 0; i < 10; ++i) {
63+
const str = dateFormat.dateTupleToMMDDYY(
64+
dateFormat.generateRandomDateTuple());
65+
expect(str).toMatch(/^[0-1]\d[0-3]\d\d\d$/);
66+
}
67+
});
68+
69+
it('MMMSpaceDDSpaceYY', () => {
70+
for (let i = 0; i < 10; ++i) {
71+
const str = dateFormat.dateTupleToMMMSpaceDDSpaceYY(
72+
dateFormat.generateRandomDateTuple());
73+
expect(str).toMatch(
74+
/^[A-Z][A-Z][A-Z] [0-3][0-9] [0-9][0-9]$/);
75+
}
76+
});
77+
78+
it('MMMSpaceDDSpaceYYYY', () => {
79+
for (let i = 0; i < 10; ++i) {
80+
const str = dateFormat.dateTupleToMMMSpaceDDSpaceYYYY(
81+
dateFormat.generateRandomDateTuple());
82+
expect(str).toMatch(
83+
/^[A-Z][A-Z][A-Z] [0-3][0-9] [0-9][0-9][0-9][0-9]$/);
84+
}
85+
});
86+
87+
it('MMMSpaceDDCommaSpaceYY', () => {
88+
for (let i = 0; i < 10; ++i) {
89+
const str = dateFormat.dateTupleToMMMSpaceDDCommaSpaceYY(
90+
dateFormat.generateRandomDateTuple());
91+
expect(str).toMatch(
92+
/^[A-Z][A-Z][A-Z] [0-3][0-9], [0-9][0-9]$/);
93+
}
94+
});
95+
96+
it('MMMSpaceDDCommaSpaceYYYY', () => {
97+
for (let i = 0; i < 10; ++i) {
98+
const str = dateFormat.dateTupleToMMMSpaceDDCommaSpaceYYYY(
99+
dateFormat.generateRandomDateTuple());
100+
expect(str).toMatch(
101+
/^[A-Z][A-Z][A-Z] [0-3][0-9], [0-9][0-9][0-9][0-9]$/);
102+
}
103+
});
104+
105+
it('MM-DD-YYYY', () => {
106+
for (let i = 0; i < 10; ++i) {
107+
const str = dateFormat.dateTupleToDDDashMMDashYYYY(
108+
dateFormat.generateRandomDateTuple());
109+
expect(str).toMatch(/^[0-3]\d-[0-1]\d-[1-2]\d\d\d$/);
110+
}
111+
});
112+
113+
it('YYYY.MM.DD', () => {
114+
for (let i = 0; i < 10; ++i) {
115+
const str = dateFormat.dateTupleToYYYYDotMMDotDD(
116+
dateFormat.generateRandomDateTuple());
117+
expect(str).toMatch(/^[1-2]\d\d\d\.[0-1]\d\.[0-3]\d$/);
118+
}
119+
});
120+
121+
it('DD.MM.YYYY', () => {
122+
for (let i = 0; i < 10; ++i) {
123+
const str = dateFormat.dateTupleToDDDotMMDotYYYY(
124+
dateFormat.generateRandomDateTuple());
125+
expect(str).toMatch(/^[0-3]\d\.[0-1]\d\.[1-2]\d\d\d$/);
126+
}
127+
});
128+
129+
it('YYYYMMDD', () => {
130+
for (let i = 0; i < 10; ++i) {
131+
const str = dateFormat.dateTupleToYYYYMMDD(
132+
dateFormat.generateRandomDateTuple());
133+
expect(str).toMatch(/^[1-2]\d\d\d[0-1]\d[0-3]\d$/);
134+
}
135+
});
136+
137+
it('YYYY-MM-DD', () => {
138+
for (let i = 0; i < 10; ++i) {
139+
const str = dateFormat.dateTupleToYYYYDashMMDashDD(
140+
dateFormat.generateRandomDateTuple());
141+
expect(str).toMatch(/^[1-2]\d\d\d-[0-1]\d-[0-3]\d$/);
142+
}
143+
});
144+
145+
it('Encode input string', () => {
146+
const str1 = dateFormat.dateTupleToDDMMMYYYY(
147+
dateFormat.generateRandomDateTuple());
148+
const str2 = dateFormat.dateTupleToMMSlashDDSlashYYYY(
149+
dateFormat.generateRandomDateTuple());
150+
const str3 = dateFormat.dateTupleToMMSlashDDSlashYY(
151+
dateFormat.generateRandomDateTuple());
152+
const encoded = dateFormat.encodeInputDateStrings([str1, str2, str3]);
153+
expect(encoded.min().dataSync()[0]).toEqual(0);
154+
expect(encoded.max().dataSync()[0]).toBeLessThan(
155+
dateFormat.INPUT_VOCAB.length);
156+
157+
const values = encoded.dataSync();
158+
let decodedStr = '';
159+
for (let i = 0; i < dateFormat.INPUT_LENGTH; ++i) {
160+
decodedStr += dateFormat.INPUT_VOCAB[values[i]];
161+
}
162+
expect(decodedStr.trim()).toEqual(str1);
163+
164+
decodedStr = '';
165+
for (let i = 0; i < dateFormat.INPUT_LENGTH; ++i) {
166+
decodedStr += dateFormat.INPUT_VOCAB[values[i + dateFormat.INPUT_LENGTH]];
167+
}
168+
expect(decodedStr.trim()).toEqual(str2);
169+
170+
decodedStr = '';
171+
for (let i = 0; i < dateFormat.INPUT_LENGTH; ++i) {
172+
decodedStr +=
173+
dateFormat.INPUT_VOCAB[values[i + dateFormat.INPUT_LENGTH * 2]];
174+
}
175+
expect(decodedStr.trim()).toEqual(str3);
176+
});
177+
178+
it('Encode output string', () => {
179+
const str1 = '2000-01-02';
180+
const str2 = '1983-08-30';
181+
const encoded = dateFormat.encodeOutputDateStrings([str1, str2]);
182+
expect(encoded.shape).toEqual([2, dateFormat.OUTPUT_LENGTH]);
183+
184+
const values = encoded.dataSync();
185+
let decodedStr = '';
186+
for (let i = 0; i < dateFormat.OUTPUT_LENGTH; ++i) {
187+
decodedStr += dateFormat.OUTPUT_VOCAB[values[i]];
188+
}
189+
expect(decodedStr.trim()).toEqual(str1);
190+
191+
decodedStr = '';
192+
for (let i = 0; i < dateFormat.OUTPUT_LENGTH; ++i) {
193+
decodedStr +=
194+
dateFormat.OUTPUT_VOCAB[values[i + dateFormat.OUTPUT_LENGTH]];
195+
}
196+
expect(decodedStr.trim()).toEqual(str2);
197+
});
198+
199+
it('Encode output string: oneHot', () => {
200+
const str1 = '2000-01-02';
201+
const str2 = '1983-08-30';
202+
const encoded = tf.oneHot(
203+
dateFormat.encodeOutputDateStrings([str1, str2]),
204+
dateFormat.OUTPUT_VOCAB.length);
205+
expect(encoded.shape).toEqual(
206+
[2, dateFormat.OUTPUT_LENGTH, dateFormat.OUTPUT_VOCAB.length]);
207+
expect(encoded.min().dataSync()[0]).toEqual(0);
208+
expect(encoded.max().dataSync()[0]).toEqual(1);
209+
210+
const values = encoded.argMax(-1).dataSync();
211+
212+
let decodedStr = '';
213+
for (let i = 0; i < dateFormat.OUTPUT_LENGTH; ++i) {
214+
decodedStr += dateFormat.OUTPUT_VOCAB[values[i]];
215+
}
216+
expect(decodedStr.trim()).toEqual(str1);
217+
218+
decodedStr = '';
219+
for (let i = 0; i < dateFormat.OUTPUT_LENGTH; ++i) {
220+
decodedStr +=
221+
dateFormat.OUTPUT_VOCAB[values[i + dateFormat.OUTPUT_LENGTH]];
222+
}
223+
expect(decodedStr.trim()).toEqual(str2);
224+
});
225+
});

‎date-conversion-attention/model.js

+176
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
import * as dateFormat from './date_format';
20+
21+
/**
22+
* A custom layer used to obtain the last time step of an RNN sequential
23+
* output.
24+
*/
25+
class GetLastTimestepLayer extends tf.layers.Layer {
26+
constructor(config) {
27+
super(config || {});
28+
this.supportMasking = true;
29+
}
30+
31+
computeOutputShape(inputShape) {
32+
const outputShape = inputShape.slice();
33+
outputShape.splice(outputShape.length - 2, 1);
34+
return outputShape;
35+
}
36+
37+
call(input) {
38+
if (Array.isArray(input)) {
39+
input = input[0];
40+
}
41+
const inputRank = input.shape.length;
42+
tf.util.assert(inputRank === 3, `Invalid input rank: ${inputRank}`);
43+
return input.gather([input.shape[1] - 1], 1).squeeze([1]);
44+
}
45+
46+
static get className() {
47+
return 'GetLastTimestepLayer';
48+
}
49+
}
50+
tf.serialization.registerClass(GetLastTimestepLayer);
51+
52+
/**
53+
* Create an LSTM-based attention model for date conversion.
54+
*
55+
* @param {number} inputVocabSize Input vocabulary size. This includes
56+
* the padding symbol. In the context of this model, "vocabulary" means
57+
* the set of all unique characters that might appear in the input date
58+
* string.
59+
* @param {number} outputVocabSize Output vocabulary size. This includes
60+
* the padding and starting symbols. In the context of this model,
61+
* "vocabulary" means the set of all unique characters that might appear in
62+
* the output date string.
63+
* @param {number} inputLength Maximum input length (# of characters). Input
64+
* sequences shorter than the length must be padded at the end.
65+
* @param {number} outputLength Output length (# of characters).
66+
* @return {tf.Model} A compiled model instance.
67+
*/
68+
export function createModel(
69+
inputVocabSize, outputVocabSize, inputLength, outputLength) {
70+
const embeddingDims = 64;
71+
const lstmUnits = 64;
72+
73+
const encoderInput = tf.input({shape: [inputLength]});
74+
const decoderInput = tf.input({shape: [outputLength]});
75+
76+
let encoder = tf.layers.embedding({
77+
inputDim: inputVocabSize,
78+
outputDim: embeddingDims,
79+
inputLength,
80+
maskZero: true
81+
}).apply(encoderInput);
82+
encoder = tf.layers.lstm({
83+
units: lstmUnits,
84+
returnSequences: true
85+
}).apply(encoder);
86+
87+
const encoderLast = new GetLastTimestepLayer({
88+
name: 'encoderLast'
89+
}).apply(encoder);
90+
91+
let decoder = tf.layers.embedding({
92+
inputDim: outputVocabSize,
93+
outputDim: embeddingDims,
94+
inputLength: outputLength,
95+
maskZero: true
96+
}).apply(decoderInput);
97+
decoder = tf.layers.lstm({
98+
units: lstmUnits,
99+
returnSequences: true
100+
}).apply(decoder, {initialState: [encoderLast, encoderLast]});
101+
102+
let attention = tf.layers.dot({axes: [2, 2]}).apply([decoder, encoder]);
103+
attention = tf.layers.activation({
104+
activation: 'softmax',
105+
name: 'attention'
106+
}).apply(attention);
107+
108+
const context = tf.layers.dot({
109+
axes: [2, 1],
110+
name: 'context'
111+
}).apply([attention, encoder]);
112+
const deocderCombinedContext =
113+
tf.layers.concatenate().apply([context, decoder]);
114+
let output = tf.layers.timeDistributed({
115+
layer: tf.layers.dense({
116+
units: lstmUnits,
117+
activation: 'tanh'
118+
})
119+
}).apply(deocderCombinedContext);
120+
output = tf.layers.timeDistributed({
121+
layer: tf.layers.dense({
122+
units: outputVocabSize,
123+
activation: 'softmax'
124+
})
125+
}).apply(output);
126+
127+
const model = tf.model({
128+
inputs: [encoderInput, decoderInput],
129+
outputs: output
130+
});
131+
model.compile({
132+
loss: 'categoricalCrossentropy',
133+
optimizer: 'adam'
134+
});
135+
return model;
136+
}
137+
138+
/**
139+
* Perform sequence-to-sequence decoding for date conversion.
140+
*
141+
* @param {tf.Model} model The model to be used for the sequence-to-sequence
142+
* decoding, with two inputs:
143+
* 1. Encoder input of shape `[numExamples, inputLength]`
144+
* 2. Decoder input of shape `[numExamples, outputLength]`
145+
* and one output:
146+
* 1. Decoder softmax probability output of shape
147+
* `[numExamples, outputLength, outputVocabularySize]`
148+
* @param {string} inputStr Input date string to be converted.
149+
* @return {string} The converted date string.
150+
*/
151+
export async function runSeq2SeqInference(model, inputStr) {
152+
return tf.tidy(() => {
153+
const encoderInput = dateFormat.encodeInputDateStrings([inputStr]);
154+
const decoderInput = tf.buffer([1, dateFormat.OUTPUT_LENGTH]);
155+
decoderInput.set(dateFormat.START_CODE, 0, 0);
156+
157+
for (let i = 1; i < dateFormat.OUTPUT_LENGTH; ++i) {
158+
const predictOut = model.predict(
159+
[encoderInput, decoderInput.toTensor()]);
160+
const output = predictOut.argMax(2).dataSync()[i - 1];
161+
predictOut.dispose();
162+
decoderInput.set(output, 0, i);
163+
}
164+
const predictOut = model.predict(
165+
[encoderInput, decoderInput.toTensor()]);
166+
const finalOutput =
167+
predictOut.argMax(2).dataSync()[dateFormat.OUTPUT_LENGTH - 1];
168+
169+
let outputStr = '';
170+
for (let i = 1; i < decoderInput.shape[1]; ++i) {
171+
outputStr += dateFormat.OUTPUT_VOCAB[decoderInput.get(0, i)];
172+
}
173+
outputStr += dateFormat.OUTPUT_VOCAB[finalOutput];
174+
return outputStr;
175+
});
176+
}
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tmp from 'tmp';
19+
import * as tf from '@tensorflow/tfjs';
20+
import {expectArraysClose} from '@tensorflow/tfjs-core/dist/test_util';
21+
import * as dateFormat from './date_format';
22+
import {createModel, runSeq2SeqInference} from './model';
23+
require('@tensorflow/tfjs-node');
24+
25+
describe('Model', () => {
26+
it('Created model can train', async () => {
27+
const inputVocabSize = 16;
28+
const outputVocabSize = 8;
29+
const inputLength = 6;
30+
const outputLength = 5;
31+
const model = createModel(
32+
inputVocabSize, outputVocabSize, inputLength, outputLength);
33+
34+
expect(model.inputs.length).toEqual(2);
35+
expect(model.inputs[0].shape).toEqual([null, inputLength]);
36+
expect(model.inputs[1].shape).toEqual([null, outputLength]);
37+
expect(model.outputs.length).toEqual(1);
38+
expect(model.outputs[0].shape).toEqual(
39+
[null, outputLength, outputVocabSize]);
40+
41+
const numExamples = 3;
42+
const encoderInputs = tf.ones([numExamples, inputLength]);
43+
const decoderInputs = tf.ones([numExamples, outputLength]);
44+
const decoderOutputs =
45+
tf.randomUniform([numExamples, outputLength, outputVocabSize]);
46+
const history = await model.fit(
47+
[encoderInputs, decoderInputs], decoderOutputs, {
48+
epochs: 2
49+
});
50+
expect(history.history.loss.length).toEqual(2);
51+
});
52+
53+
it('Model save-load roundtrip', async () => {
54+
const inputVocabSize = 16;
55+
const outputVocabSize = 8;
56+
const inputLength = 6;
57+
const outputLength = 5;
58+
const model = createModel(
59+
inputVocabSize, outputVocabSize, inputLength, outputLength);
60+
61+
const numExamples = 3;
62+
const encoderInputs = tf.ones([numExamples, inputLength]);
63+
const decoderInputs = tf.ones([numExamples, outputLength]);
64+
const y = model.predict([encoderInputs, decoderInputs]);
65+
66+
const saveDir = tmp.dirSync();
67+
await model.save(`file://${saveDir.name}`);
68+
const modelPrime = await tf.loadModel(`file://${saveDir.name}/model.json`);
69+
const yPrime = modelPrime.predict([encoderInputs, decoderInputs]);
70+
expectArraysClose(yPrime, y);
71+
});
72+
73+
it('seq2seq inference', async () => {
74+
const model = createModel(
75+
dateFormat.INPUT_VOCAB.length, dateFormat.OUTPUT_VOCAB.length,
76+
dateFormat.INPUT_LENGTH, dateFormat.OUTPUT_LENGTH);
77+
78+
const numTensors0 = tf.memory().numTensors;
79+
const output = await runSeq2SeqInference(model, '2019/01/18');
80+
// Assert no memory leak.
81+
expect(tf.memory().numTensors).toEqual(numTensors0);
82+
expect(output.length).toEqual(dateFormat.OUTPUT_LENGTH);
83+
});
84+
});
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"name": "tfjs-examples-date-conversion-attention",
3+
"version": "0.1.0",
4+
"description": "",
5+
"main": "index.js",
6+
"license": "Apache-2.0",
7+
"private": true,
8+
"engines": {
9+
"node": ">=8.11.0"
10+
},
11+
"scripts": {
12+
"postinstall": "yarn upgrade --pattern @tensorflow",
13+
"test": "babel-node run_tests.js",
14+
"train": "babel-node train.js"
15+
},
16+
"dependencies": {
17+
"@tensorflow/tfjs": "^0.15.1"
18+
},
19+
"devDependencies": {
20+
"@tensorflow/tfjs-node": "^0.3.0",
21+
"@tensorflow/tfjs-node-gpu": "^0.3.0",
22+
"argparse": "^1.0.10",
23+
"babel-cli": "^6.26.0",
24+
"babel-core": "^6.26.3",
25+
"babel-plugin-transform-runtime": "^6.23.0",
26+
"babel-polyfill": "^6.26.0",
27+
"babel-preset-env": "~1.7.0",
28+
"clang-format": "~1.2.2",
29+
"jasmine": "^3.2.0",
30+
"jasmine-core": "^3.2.1",
31+
"shelljs": "^0.8.3",
32+
"tmp": "^0.0.33",
33+
"yalc": "~1.0.0-pre.21"
34+
}
35+
}
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const jasmine_util = require('@tensorflow/tfjs-core/dist/jasmine_util');
19+
const runTests = require('../test_util').runTests;
20+
21+
runTests(jasmine_util, ['./*test.js']);

‎date-conversion-attention/train.js

+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/**
19+
* Training an attention LSTM sequence-to-sequence decoder to translate
20+
* various date formats into the ISO date format.
21+
*
22+
* Inspired by and loosely based on
23+
* https://github.com/wanasit/katakana/blob/master/notebooks/Attention-based%20Sequence-to-Sequence%20in%20Keras.ipynb
24+
*/
25+
26+
import * as fs from 'fs';
27+
import * as shelljs from 'shelljs';
28+
import * as argparse from 'argparse';
29+
import * as tf from '@tensorflow/tfjs';
30+
import * as dateFormat from './date_format';
31+
import {createModel, runSeq2SeqInference} from './model';
32+
33+
const INPUT_FNS = [
34+
dateFormat.dateTupleToDDMMMYYYY,
35+
dateFormat.dateTupleToMMDDYY,
36+
dateFormat.dateTupleToMMSlashDDSlashYY,
37+
dateFormat.dateTupleToMMSlashDDSlashYYYY,
38+
dateFormat.dateTupleToDDDashMMDashYYYY,
39+
dateFormat.dateTupleToMMMSpaceDDSpaceYY,
40+
dateFormat.dateTupleToMMMSpaceDDSpaceYYYY,
41+
dateFormat.dateTupleToMMMSpaceDDCommaSpaceYY,
42+
dateFormat.dateTupleToMMMSpaceDDCommaSpaceYYYY,
43+
dateFormat.dateTupleToDDDotMMDotYYYY,
44+
dateFormat.dateTupleToYYYYDotMMDotDD,
45+
dateFormat.dateTupleToYYYYMMDD,
46+
dateFormat.dateTupleToYYYYDashMMDashDD
47+
]; // TODO(cais): Add more formats if necessary.
48+
49+
/**
50+
* Generate sets of data for training.
51+
*
52+
* @param {number} trainSplit Trainining split. Must be >0 and <1.
53+
* @param {number} valSplit Validatoin split. Must be >0 and <1.
54+
* @return An `Object` consisting of
55+
* - trainEncoderInput, as a `tf.Tensor` of shape
56+
* `[numTrainExapmles, inputLength]`
57+
* - trainDecoderInput, as a `tf.Tensor` of shape
58+
* `[numTrainExapmles, outputLength]`. The first element of every
59+
* example has been set as START_CODE (the sequence-start symbol).
60+
* - trainDecoderOuptut, as a one-hot encoded `tf.Tensor` of shape
61+
* `[numTrainExamples, outputLength, outputVocabSize]`.
62+
* - valEncoderInput, same as trainEncoderInput, but for the validation set.
63+
* - valDecoderInput, same as trainDecoderInput, but for the validation set.
64+
* - valDecoderOutput, same as trainDecoderOuptut, but for the validation
65+
* set.
66+
* - testDateTuples, date tuples ([year, month, day]) for the test set.
67+
*/
68+
export function generateDataForTraining(trainSplit = 0.8, valSplit = 0.15) {
69+
tf.util.assert(
70+
trainSplit > 0 && valSplit > 0 && trainSplit + valSplit <= 1,
71+
`Invalid trainSplit (${trainSplit}) and valSplit (${valSplit})`);
72+
73+
const dateTuples = [];
74+
const MIN_YEAR = 1950;
75+
const MAX_YEAR = 2050;
76+
for (let year = MIN_YEAR; year < MAX_YEAR; ++year) {
77+
for (let month = 1; month <= 12; ++month) {
78+
for (let day = 1; day <= 28; ++day) {
79+
dateTuples.push([year, month, day]);
80+
}
81+
}
82+
}
83+
tf.util.shuffle(dateTuples);
84+
85+
const numTrain = Math.floor(dateTuples.length * trainSplit);
86+
const numVal = Math.floor(dateTuples.length * valSplit);
87+
88+
function dateTuplesToTensor(dateTuples) {
89+
return tf.tidy(() => {
90+
const inputs = INPUT_FNS.map(fn => dateTuples.map(tuple => fn(tuple)));
91+
const inputStrings = [];
92+
inputs.forEach(inputs => inputStrings.push(...inputs));
93+
const encoderInput =
94+
dateFormat.encodeInputDateStrings(inputStrings);
95+
const trainTargetStrings = dateTuples.map(
96+
tuple => dateFormat.dateTupleToYYYYDashMMDashDD(tuple));
97+
let decoderInput =
98+
dateFormat.encodeOutputDateStrings(trainTargetStrings)
99+
.asType('float32');
100+
// One-step time shift: The decoder input is shifted to the left by
101+
// one time step with respect to the encoder input. This accounts for
102+
// the step-by-step decoding that happens during inference time.
103+
decoderInput = tf.concat([
104+
tf.ones([decoderInput.shape[0], 1]).mul(dateFormat.START_CODE),
105+
decoderInput.slice(
106+
[0, 0], [decoderInput.shape[0], decoderInput.shape[1] - 1])
107+
], 1).tile([INPUT_FNS.length, 1]);
108+
const decoderOutput = tf.oneHot(
109+
dateFormat.encodeOutputDateStrings(trainTargetStrings),
110+
dateFormat.OUTPUT_VOCAB.length).tile([INPUT_FNS.length, 1, 1]);
111+
return {encoderInput, decoderInput, decoderOutput};
112+
});
113+
}
114+
115+
const {
116+
encoderInput: trainEncoderInput,
117+
decoderInput: trainDecoderInput,
118+
decoderOutput: trainDecoderOutput
119+
} = dateTuplesToTensor(dateTuples.slice(0, numTrain));
120+
const {
121+
encoderInput: valEncoderInput,
122+
decoderInput: valDecoderInput,
123+
decoderOutput: valDecoderOutput
124+
} = dateTuplesToTensor(dateTuples.slice(numTrain, numTrain + numVal));
125+
const testDateTuples =
126+
dateTuples.slice(numTrain + numVal, dateTuples.length);
127+
return {
128+
trainEncoderInput,
129+
trainDecoderInput,
130+
trainDecoderOutput,
131+
valEncoderInput,
132+
valDecoderInput,
133+
valDecoderOutput,
134+
testDateTuples
135+
};
136+
}
137+
138+
function parseArguments() {
139+
const argParser = new argparse.ArgumentParser({
140+
description:
141+
'Train an attention-based date-conversion model in TensorFlow.js'
142+
});
143+
argParser.addArgument('--gpu', {
144+
action: 'storeTrue',
145+
help: 'Use tfjs-node-gpu to train the model. Requires CUDA/CuDNN.'
146+
});
147+
argParser.addArgument('--epochs', {
148+
type: 'int',
149+
defaultValue: 2,
150+
help: 'Number of epochs to train the model for'
151+
});
152+
argParser.addArgument('--batchSize', {
153+
type: 'int',
154+
defaultValue: 128,
155+
help: 'Batch size to be used during model training'
156+
});
157+
argParser.addArgument('--savePath', {
158+
type: 'string',
159+
defaultValue: './dist/model',
160+
});
161+
return argParser.parseArgs();
162+
}
163+
164+
async function run() {
165+
const args = parseArguments();
166+
if (args.gpu) {
167+
console.log('Using GPU');
168+
require('@tensorflow/tfjs-node-gpu');
169+
} else {
170+
console.log('Using CPU');
171+
require('@tensorflow/tfjs-node');
172+
}
173+
174+
const model = createModel(
175+
dateFormat.INPUT_VOCAB.length, dateFormat.OUTPUT_VOCAB.length,
176+
dateFormat.INPUT_LENGTH, dateFormat.OUTPUT_LENGTH);
177+
model.summary();
178+
179+
const {
180+
trainEncoderInput,
181+
trainDecoderInput,
182+
trainDecoderOutput,
183+
valEncoderInput,
184+
valDecoderInput,
185+
valDecoderOutput,
186+
testDateTuples
187+
} = generateDataForTraining();
188+
189+
await model.fit(
190+
[trainEncoderInput, trainDecoderInput], trainDecoderOutput, {
191+
epochs: args.epochs,
192+
batchSize: args.batchSize,
193+
shuffle: true,
194+
validationData: [[valEncoderInput, valDecoderInput], valDecoderOutput]
195+
});
196+
197+
// Save the model.
198+
if (args.savePath != null && args.savePath.length) {
199+
if (!fs.existsSync(args.savePath)) {
200+
shelljs.mkdir('-p', args.savePath);
201+
}
202+
const saveURL = `file://${args.savePath}`
203+
await model.save(saveURL);
204+
console.log(`Saved model to ${saveURL}`);
205+
}
206+
207+
// Run seq2seq inference tests and print the results to console.
208+
const numTests = 10;
209+
for (let n = 0; n < numTests; ++n) {
210+
for (const testInputFn of INPUT_FNS) {
211+
const inputStr = testInputFn(testDateTuples[n]);
212+
console.log('\n-----------------------');
213+
console.log(`Input string: ${inputStr}`);
214+
const correctAnswer =
215+
dateFormat.dateTupleToYYYYDashMMDashDD(testDateTuples[n]);
216+
console.log(`Correct answer: ${correctAnswer}`);
217+
218+
const outputStr = await runSeq2SeqInference(model, inputStr);
219+
const isCorrect = outputStr === correctAnswer;
220+
console.log(
221+
`Model output: ${outputStr} (${isCorrect ? 'OK' : 'WRONG'})` );
222+
}
223+
}
224+
}
225+
226+
if (require.main === module) {
227+
run();
228+
}
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as dateFormat from './date_format';
19+
import {generateDataForTraining} from './train';
20+
21+
describe('generateBatchesForTraining', () => {
22+
it('generateDataForTraining', () => {
23+
const {
24+
trainEncoderInput,
25+
trainDecoderInput,
26+
trainDecoderOutput,
27+
valEncoderInput,
28+
valDecoderInput,
29+
valDecoderOutput,
30+
testDateTuples
31+
} = generateDataForTraining(0.5, 0.25);
32+
const numTrain = trainEncoderInput.shape[0];
33+
const numVal = valEncoderInput.shape[0];
34+
expect(numTrain / numVal).toBeCloseTo(2);
35+
expect(trainEncoderInput.shape).toEqual(
36+
[numTrain, dateFormat.INPUT_LENGTH]);
37+
expect(trainDecoderInput.shape).toEqual(
38+
[numTrain, dateFormat.OUTPUT_LENGTH]);
39+
expect(trainDecoderOutput.shape).toEqual(
40+
[numTrain, dateFormat.OUTPUT_LENGTH, dateFormat.OUTPUT_VOCAB.length]);
41+
expect(valEncoderInput.shape).toEqual(
42+
[numVal, dateFormat.INPUT_LENGTH]);
43+
expect(valDecoderInput.shape).toEqual(
44+
[numVal, dateFormat.OUTPUT_LENGTH]);
45+
expect(valDecoderOutput.shape).toEqual(
46+
[numVal, dateFormat.OUTPUT_LENGTH, dateFormat.OUTPUT_VOCAB.length]);
47+
expect(testDateTuples[0].length).toEqual(3);
48+
expect(testDateTuples[testDateTuples.length - 1].length).toEqual(3);
49+
});
50+
});

‎date-conversion-attention/yarn.lock

+2,864
Large diffs are not rendered by default.

‎package.json

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"name": "tfjs-examples",
3+
"version": "0.0.1",
4+
"devDependencies": {
5+
"jasmine": "~3.1.0",
6+
"yalc": "~1.0.0-pre.21"
7+
},
8+
"license": "Apache-2.0"
9+
}

‎test_util.js

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/**
19+
* Execute all unit tests in the current directory. Takes a jasmine_util from
20+
* tfjs-core so that we use the tfjs-core module from the right test directory.
21+
*/
22+
function runTests(jasmineUtil, specFiles) {
23+
// tslint:disable-next-line:no-require-imports
24+
const jasmineConstructor = require('jasmine');
25+
26+
Error.stackTraceLimit = Infinity;
27+
28+
process.on('unhandledRejection', e => {
29+
throw e;
30+
});
31+
32+
jasmineUtil.setTestEnvs(
33+
[{name: 'node', factory: jasmineUtil.CPU_FACTORY, features: {}}]);
34+
35+
const runner = new jasmineConstructor();
36+
runner.loadConfig({spec_files: specFiles, random: false});
37+
runner.execute();
38+
}
39+
40+
module.exports = {runTests};

‎yarn.lock

+544
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)
Please sign in to comment.