Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
ab09b6c
Initial check-in of date-conversion-attention
caisq Jan 16, 2019
40b5ad7
Add encodeInputDateStrings
caisq Jan 16, 2019
d28c8d5
Merge branch 'master' into attention-date-translation
caisq Jan 17, 2019
fa61c9f
Adding model construction code
caisq Jan 17, 2019
b268902
WIP: Toward model training
caisq Jan 17, 2019
78ffca5
Add oneHot option to encodeOutputDateStrings();
caisq Jan 17, 2019
5863f20
Add more assertions to test
caisq Jan 17, 2019
a35e440
Towards training
caisq Jan 17, 2019
ce53fc7
Fix bugs in training and inference
caisq Jan 18, 2019
6d67a7b
train.js: arg parsing, seq2seq refactoring, doc strings
caisq Jan 18, 2019
a5b23d8
Refactoring: use babel-node for browser compatibility
caisq Jan 18, 2019
d30b447
Refactoring and more tests
caisq Jan 18, 2019
a915d23
Add more tests
caisq Jan 18, 2019
cb43a7f
Add dateTupleToDDDashMMDashYYYY() and tests
caisq Jan 18, 2019
2b37de7
Add README.md
caisq Jan 18, 2019
132afd5
Beef up README.md a little
caisq Jan 18, 2019
ab861ec
Add testing-related comments to README.md
caisq Jan 18, 2019
029243a
README.md: Fix a typo
caisq Jan 18, 2019
8927b07
README.md: Typo fixes
caisq Jan 18, 2019
51e7411
Register custom layer properly; unit test
caisq Jan 18, 2019
64f5c48
Add four more input date formats
caisq Jan 19, 2019
e1bab13
Merge branch 'master' into attention-date-translation
caisq Jan 25, 2019
d5a0d73
Merge branch 'master' into attention-date-translation
caisq Jan 27, 2019
7cf51b1
Address review comments 1/2
caisq Jan 27, 2019
d08ae19
Merge branch 'attention-date-translation' of github.com:caisq/tfjs-ex…
caisq Jan 27, 2019
d71a38a
save
caisq Jan 27, 2019
ad7ea77
Add unit tests for new date formats
caisq Jan 27, 2019
8972918
Respond to reviewer comments
caisq Jan 27, 2019
c62c9ee
Respond to further review comments
caisq Jan 28, 2019
b9e48f0
All instance of strPrime --> decodedStr
caisq Jan 28, 2019
8961a68
Remove cruft
caisq Jan 28, 2019
3c762c4
Merge branch 'master' into attention-date-translation
caisq Feb 5, 2019
b4870b1
Update to tfjs-node 0.3.0 and tfjs-layers 0.15.1
caisq Feb 9, 2019
0cfbd7a
Merge branch 'attention-date-translation' of github.com:caisq/tfjs-ex…
caisq Feb 9, 2019
6be30c7
Merge branch 'master' into attention-date-translation
caisq Feb 9, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions date-conversion-attention/.babelrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"presets": [
[
"env",
{
"esmodules": false,
"targets": {
"browsers": [
"> 3%"
]
}
}
]
],
"plugins": [
"transform-runtime"
]
}
57 changes: 57 additions & 0 deletions date-conversion-attention/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# TensorFlow.js Example: Date Conversion Through an LSTM-Attention Model

## Overview

This example shows how to use TensorFlow.js to train a model based on
long short-term memory (LSTM) and the attention mechanism to perform
the task of converting various commonly seen date formats (e.g., 01/18/2019,
18JAN2019, 18-01-2019) to the ISO date format (i.e., 2019-01-18).

We demonstrate the full machine-learning workflow, consisting of
data engineering, server-side model training, client-side inference,
model visualization, and unit testing in this example.

The training data is synthesized programmatically.

## Model training in Node.js

For efficiency, the training of the model happens outside the browser
in Node.js, using tfjs-node or tfjs-node-gpu.

To run the training job, run:

```sh
yarn
yarn train
```

By default, the training uses tfjs-node, which runs on the CPU.
If you have a CUDA-enabled GPU and have the CUDA and cuDNN libraries
set up properly on your system, you can run the training on the GPU
by running:

```sh
yarn
yarn train --gpu
```

## Using the model in the browser

TODO(cais): Implement it.

### Visualization of the attention mechanism

TODO(cais): Implement it.

## Running unit tests

The data and model code in this example are covered by unit tests.
To run the unit tests:

```sh
cd ../
yarn
cd date-conversion-attention
yarn
yarn test
```
169 changes: 169 additions & 0 deletions date-conversion-attention/date_format.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

const tf = require('@tensorflow/tfjs');

// Full English month names, in calendar order.
const MONTH_NAMES_FULL = [
  'January',
  'February',
  'March',
  'April',
  'May',
  'June',
  'July',
  'August',
  'September',
  'October',
  'November',
  'December',
];

// Upper-case three-letter abbreviations of the month names: 'JAN', 'FEB', ...
const MONTH_NAMES_3LETTER = MONTH_NAMES_FULL.map(
    fullName => fullName.slice(0, 3).toUpperCase());

// Bounds, in milliseconds since the epoch, for randomly generated dates.
const MIN_DATE = new Date('1950-01-01').getTime();
const MAX_DATE = new Date('2050-01-01').getTime();

// Maximum length of all input formats.
export const INPUT_LENGTH = 12;
// Length of 'YYYY-MM-DD'.
export const OUTPUT_LENGTH = 10;

// "\n" is the padding character for both the input and the output. It has to
// be the first character of each vocabulary so that `mask_zero=True` can be
// used in the keras model.
//
// After the padding character, digits and separators, the input vocabulary
// holds every distinct letter that occurs in the three-letter month names,
// in order of first appearance.
export const INPUT_VOCAB = '\n0123456789/-, ' +
    Array.from(new Set(MONTH_NAMES_3LETTER.join(''))).join('');

// OUTPUT_VOCAB includes a start-of-sequence (SOS) token, represented as
// '\t'.
export const OUTPUT_VOCAB = '\n\t0123456789-';

// Equals the index of '\t' in OUTPUT_VOCAB.
export const START_CODE = 1;

/**
 * Generate a random date.
 *
 * A timestamp is drawn uniformly from [MIN_DATE, MAX_DATE). Both bounds are
 * parsed from date-only ISO strings, which JavaScript interprets as UTC
 * midnight, so the year, month and day are read back with the UTC getters.
 * Using the local-time getters would make the range depend on the time zone
 * of the machine (e.g., west of UTC a timestamp just after MIN_DATE would be
 * reported as a day in the year before the intended lower bound).
 *
 * @return {[number, number, number]} Year as an integer, month as an
 *   integer >= 1 and <= 12, day as an integer >= 1.
 */
export function generateRandomDateTuple() {
  const timestamp = Math.random() * (MAX_DATE - MIN_DATE) + MIN_DATE;
  const date = new Date(timestamp);
  // getUTCMonth() is zero-based; the tuple uses one-based months.
  return [date.getUTCFullYear(), date.getUTCMonth() + 1, date.getUTCDate()];
}

/**
 * Convert a number to a string, prepending a '0' if the number is less
 * than 10.
 *
 * @param {number} num The number to format.
 * @return {string} The formatted number, e.g., '07' for 7 and '12' for 12.
 */
function toTwoDigitString(num) {
  if (num < 10) {
    return `0${num}`;
  }
  return `${num}`;
}

/**
 * Format a date tuple as DDMMMYYYY, e.g., [2019, 1, 18] --> '18JAN2019'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Two-digit day, upper-case three-letter month name and
 *   the year, with no separators.
 */
export function dateTupleToDDMMMYYYY(dateTuple) {
  const [year, month, day] = dateTuple;
  // Months in the tuple are one-based; the name table is zero-based.
  const monthName = MONTH_NAMES_3LETTER[month - 1];
  return `${toTwoDigitString(day)}${monthName}${year}`;
}

/**
 * Format a date tuple as MM/DD/YYYY, e.g., [2019, 1, 18] --> '01/18/2019'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Two-digit month, two-digit day and the year, separated
 *   by slashes.
 */
export function dateTupleToMMSlashDDSlashYYYY(dateTuple) {
  const [year, month, day] = dateTuple;
  return `${toTwoDigitString(month)}/${toTwoDigitString(day)}/${year}`;
}

/**
 * Format a date tuple as MM/DD/YY, e.g., [2019, 1, 18] --> '01/18/19'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Two-digit month, two-digit day and the year without its
 *   first two characters, separated by slashes.
 */
export function dateTupleToMMSlashDDSlashYY(dateTuple) {
  const [year, month, day] = dateTuple;
  // Drop the first two characters of the year, e.g., '2019' --> '19'.
  const shortYear = String(year).slice(2);
  return `${toTwoDigitString(month)}/${toTwoDigitString(day)}/${shortYear}`;
}

/**
 * Format a date tuple as MMDDYY, e.g., [2019, 1, 18] --> '011819'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Two-digit month, two-digit day and the year without its
 *   first two characters, with no separators.
 */
export function dateTupleToMMDDYY(dateTuple) {
  const [year, month, day] = dateTuple;
  // Drop the first two characters of the year, e.g., '2019' --> '19'.
  const shortYear = String(year).slice(2);
  return `${toTwoDigitString(month)}${toTwoDigitString(day)}${shortYear}`;
}

/**
 * Format a date tuple as 'MMM DD YY', e.g., [2019, 1, 18] --> 'JAN 18 19'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Upper-case three-letter month name, two-digit day and
 *   the year without its first two characters, separated by spaces.
 */
export function dateTupleToMMMSpaceDDSpaceYY(dateTuple) {
  const [year, month, day] = dateTuple;
  // Months in the tuple are one-based; the name table is zero-based.
  const monthName = MONTH_NAMES_3LETTER[month - 1];
  // Drop the first two characters of the year, e.g., '2019' --> '19'.
  const shortYear = String(year).slice(2);
  return `${monthName} ${toTwoDigitString(day)} ${shortYear}`;
}

/**
 * Format a date tuple as 'MMM DD YYYY', e.g.,
 * [2019, 1, 18] --> 'JAN 18 2019'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Upper-case three-letter month name, two-digit day and
 *   the year, separated by spaces.
 */
export function dateTupleToMMMSpaceDDSpaceYYYY(dateTuple) {
  const [year, month, day] = dateTuple;
  // Months in the tuple are one-based; the name table is zero-based.
  const monthName = MONTH_NAMES_3LETTER[month - 1];
  return `${monthName} ${toTwoDigitString(day)} ${year}`;
}

/**
 * Format a date tuple as 'MMM DD, YY', e.g., [2019, 1, 18] --> 'JAN 18, 19'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Upper-case three-letter month name, a space, the
 *   two-digit day, a comma and a space, then the year without its first two
 *   characters.
 */
export function dateTupleToMMMSpaceDDCommaSpaceYY(dateTuple) {
  const [year, month, day] = dateTuple;
  // Months in the tuple are one-based; the name table is zero-based.
  const monthName = MONTH_NAMES_3LETTER[month - 1];
  // Drop the first two characters of the year, e.g., '2019' --> '19'.
  const shortYear = String(year).slice(2);
  return `${monthName} ${toTwoDigitString(day)}, ${shortYear}`;
}

/**
 * Format a date tuple as 'MMM DD, YYYY', e.g.,
 * [2019, 1, 18] --> 'JAN 18, 2019'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Upper-case three-letter month name, a space, the
 *   two-digit day, a comma and a space, then the year.
 */
export function dateTupleToMMMSpaceDDCommaSpaceYYYY(dateTuple) {
  const [year, month, day] = dateTuple;
  // Months in the tuple are one-based; the name table is zero-based.
  const monthName = MONTH_NAMES_3LETTER[month - 1];
  return `${monthName} ${toTwoDigitString(day)}, ${year}`;
}

/**
 * Format a date tuple as DD-MM-YYYY, e.g., [2019, 1, 18] --> '18-01-2019'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} Two-digit day, two-digit month and the year, separated
 *   by dashes.
 */
export function dateTupleToDDDashMMDashYYYY(dateTuple) {
  const [year, month, day] = dateTuple;
  return `${toTwoDigitString(day)}-${toTwoDigitString(month)}-${year}`;
}

/**
 * Format a date tuple as YYYY-MM-DD (the ISO date format), e.g.,
 * [2019, 1, 18] --> '2019-01-18'.
 *
 * @param {[number, number, number]} dateTuple Year, month (counted from 1)
 *   and day.
 * @return {string} The year, two-digit month and two-digit day, separated
 *   by dashes.
 */
export function dateTupleToYYYYDashMMDashDD(dateTuple) {
  const [year, month, day] = dateTuple;
  return `${year}-${toTwoDigitString(month)}-${toTwoDigitString(day)}`;
}

/**
 * Encode a number of input date strings as a tensor of INPUT_VOCAB indices.
 *
 * Each date string occupies one row. Only the first INPUT_LENGTH characters
 * of a string are encoded; positions past the end of a shorter string are
 * never written and keep the buffer's initial value (presumably 0, which is
 * the index of the '\n' padding character -- confirm against tf.buffer).
 *
 * @param {string[]} dateStrings Input date strings.
 * @return {tf.Tensor} A float32 tensor of shape
 *   `[dateStrings.length, INPUT_LENGTH]`.
 * @throws {Error} If a character is not found in INPUT_VOCAB.
 */
export function encodeInputDateStrings(dateStrings) {
  const numStrings = dateStrings.length;
  const buffer = tf.buffer([numStrings, INPUT_LENGTH], 'float32');
  for (let row = 0; row < numStrings; ++row) {
    const dateString = dateStrings[row];
    // Characters beyond INPUT_LENGTH are ignored.
    const numChars = Math.min(dateString.length, INPUT_LENGTH);
    for (let col = 0; col < numChars; ++col) {
      const char = dateString[col];
      const vocabIndex = INPUT_VOCAB.indexOf(char);
      if (vocabIndex === -1) {
        throw new Error(`Unknown char: ${char}`);
      }
      buffer.set(vocabIndex, row, col);
    }
  }
  return buffer.toTensor();
}

/**
 * Encode a number of output date strings as a tensor, using OUTPUT_VOCAB.
 *
 * Every string must be exactly OUTPUT_LENGTH characters long.
 *
 * @param {string[]} dateStrings Output date strings, e.g., '2019-01-18'.
 * @param {boolean} oneHot Whether to use one-hot encoding. Default: `false`.
 * @return {tf.Tensor} A float32 tensor. If `oneHot` is falsy, the shape is
 *   `[dateStrings.length, OUTPUT_LENGTH]` and each element is the
 *   OUTPUT_VOCAB index of a character. Otherwise, the shape is
 *   `[dateStrings.length, OUTPUT_LENGTH, OUTPUT_VOCAB.length]` and a 1 is
 *   written at the OUTPUT_VOCAB index of each character.
 * @throws {Error} If a character is not found in OUTPUT_VOCAB. A string of
 *   the wrong length is reported through `tf.util.assert`.
 */
export function encodeOutputDateStrings(dateStrings, oneHot = false) {
  const numStrings = dateStrings.length;
  const shape = oneHot ?
      [numStrings, OUTPUT_LENGTH, OUTPUT_VOCAB.length] :
      [numStrings, OUTPUT_LENGTH];
  const buffer = tf.buffer(shape, 'float32');
  for (let row = 0; row < numStrings; ++row) {
    const dateString = dateStrings[row];
    tf.util.assert(
        dateString.length === OUTPUT_LENGTH,
        `Date string is not in ISO format: "${dateString}"`);
    for (let col = 0; col < OUTPUT_LENGTH; ++col) {
      const char = dateString[col];
      const vocabIndex = OUTPUT_VOCAB.indexOf(char);
      if (vocabIndex === -1) {
        throw new Error(`Unknown char: ${char}`);
      }
      if (oneHot) {
        // One-hot: mark the slot for this character along the last axis.
        buffer.set(1, row, col, vocabIndex);
      } else {
        // Sparse: store the vocabulary index itself.
        buffer.set(vocabIndex, row, col);
      }
    }
  }
  return buffer.toTensor();
}
Loading