
Commit cafe496

yashk2810 authored and copybara-github committed
Adding nmt_with_attention
PiperOrigin-RevId: 240433141
1 parent 6f3c9b0 commit cafe496

File tree: 8 files changed (+802, -1 lines)

tensorflow_examples/models/densenet/train.py (+1, -1)

@@ -31,6 +31,7 @@ class Train(object):
   Args:
     epochs: Number of epochs
     enable_function: If True, wraps the train_step and test_step in tf.function
+    model: Densenet model.
   """

   def __init__(self, epochs, enable_function, model):
@@ -154,7 +155,6 @@ def run_main(argv):
   """
   del argv
   kwargs = utils.flags_dict()
-  print (kwargs)
   main(**kwargs)

@@ -0,0 +1,14 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
@@ -0,0 +1,59 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for distributed nmt_with_attention."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import tensorflow as tf  # TF2
from tensorflow_examples.models.nmt_with_attention import distributed_train
from tensorflow_examples.models.nmt_with_attention import utils
assert tf.__version__.startswith('2')


class NmtDistributedBenchmark(tf.test.Benchmark):

  def __init__(self, output_dir=None, **kwargs):
    self.output_dir = output_dir

  def benchmark_one_epoch_1_gpu(self):
    kwargs = utils.get_common_kwargs()
    kwargs.update({'enable_function': False})
    self._run_and_report_benchmark(**kwargs)

  def benchmark_one_epoch_1_gpu_function(self):
    kwargs = utils.get_common_kwargs()
    self._run_and_report_benchmark(**kwargs)

  def benchmark_ten_epochs_2_gpus(self):
    kwargs = utils.get_common_kwargs()
    kwargs.update({'epochs': 10, 'num_gpu': 2, 'batch_size': 128})
    self._run_and_report_benchmark(**kwargs)

  def _run_and_report_benchmark(self, **kwargs):
    start_time_sec = time.time()
    train_loss, test_loss = distributed_train.main(**kwargs)
    wall_time_sec = time.time() - start_time_sec

    extras = {'train_loss': train_loss,
              'test_loss': test_loss}

    self.report_benchmark(
        wall_time=wall_time_sec, extras=extras)


if __name__ == '__main__':
  tf.test.main()
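As a usage note (an illustration, not part of the commit): each benchmark_* method above builds kwargs and delegates to _run_and_report_benchmark, which times distributed_train.main and reports wall time plus the final losses, so a single benchmark can also be driven by hand in-process:

# Minimal sketch, reusing the class defined above; assumes the module's
# dependencies (utils, distributed_train, the dataset download) are available.
benchmark = NmtDistributedBenchmark()
benchmark.benchmark_one_epoch_1_gpu_function()  # runs distributed_train.main(...)
                                                # and reports wall time + losses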
@@ -0,0 +1,145 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distributed Train."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
import tensorflow as tf  # TF2
from tensorflow_examples.models.nmt_with_attention import nmt
from tensorflow_examples.models.nmt_with_attention import utils
from tensorflow_examples.models.nmt_with_attention.train import Train
assert tf.__version__.startswith('2')

FLAGS = flags.FLAGS

# If additional flags are needed, define them here.
flags.DEFINE_integer('num_gpu', 1, 'Number of GPUs to use')


class DistributedTrain(Train):
  """Distributed Train class.

  Args:
    epochs: Number of epochs.
    enable_function: Decorate function with tf.function.
    encoder: Encoder.
    decoder: Decoder.
    inp_lang: Input language tokenizer.
    targ_lang: Target language tokenizer.
    batch_size: Batch size.
  """

  def __init__(self, epochs, enable_function, encoder, decoder, inp_lang,
               targ_lang, batch_size):
    Train.__init__(
        self, epochs, enable_function, encoder, decoder, inp_lang, targ_lang,
        batch_size)

  def training_loop(self, train_iterator, test_iterator,
                    num_train_steps_per_epoch, num_test_steps_per_epoch,
                    strategy):
    """Custom training and testing loop.

    Args:
      train_iterator: Training iterator created using the strategy.
      test_iterator: Testing iterator created using the strategy.
      num_train_steps_per_epoch: Number of training steps in an epoch.
      num_test_steps_per_epoch: Number of test steps in an epoch.
      strategy: Distribution strategy.

    Returns:
      train_loss, test_loss
    """

    # This code is expected to change.
    def distributed_train():
      return strategy.experimental_run(self.train_step, train_iterator)

    def distributed_test():
      return strategy.experimental_run(self.test_step, test_iterator)

    if self.enable_function:
      distributed_train = tf.function(distributed_train)
      distributed_test = tf.function(distributed_test)

    template = 'Epoch: {}, Train Loss: {}, Test Loss: {}'

    for epoch in range(self.epochs):
      self.train_loss_metric.reset_states()
      self.test_loss_metric.reset_states()

      train_iterator.initialize()
      for _ in range(num_train_steps_per_epoch):
        distributed_train()

      test_iterator.initialize()
      for _ in range(num_test_steps_per_epoch):
        distributed_test()

      print(template.format(epoch,
                            self.train_loss_metric.result().numpy(),
                            self.test_loss_metric.result().numpy()))

    return (self.train_loss_metric.result().numpy(),
            self.test_loss_metric.result().numpy())


def run_main(argv):
  del argv
  kwargs = utils.flags_dict()
  kwargs.update({'num_gpu': FLAGS.num_gpu})
  main(**kwargs)


def main(epochs, enable_function, buffer_size, batch_size, download_path,
         num_examples=70000, embedding_dim=256, enc_units=1024, dec_units=1024,
         num_gpu=1):

  devices = ['/device:GPU:{}'.format(i) for i in range(num_gpu)]
  strategy = tf.distribute.MirroredStrategy(devices)

  with strategy.scope():
    file_path = utils.download(download_path)
    train_ds, test_ds, inp_lang, targ_lang = utils.create_dataset(
        file_path, num_examples, buffer_size, batch_size)
    vocab_inp_size = len(inp_lang.word_index) + 1
    vocab_tar_size = len(targ_lang.word_index) + 1

    num_train_steps_per_epoch = tf.data.experimental.cardinality(train_ds)
    num_test_steps_per_epoch = tf.data.experimental.cardinality(test_ds)

    train_iterator = strategy.make_dataset_iterator(train_ds)
    test_iterator = strategy.make_dataset_iterator(test_ds)

    encoder = nmt.Encoder(vocab_inp_size, embedding_dim, enc_units, batch_size)
    decoder = nmt.Decoder(vocab_tar_size, embedding_dim, dec_units, batch_size)

    train_obj = DistributedTrain(epochs, enable_function, encoder, decoder,
                                 inp_lang, targ_lang, batch_size)
    print('Training ...')
    return train_obj.training_loop(train_iterator,
                                   test_iterator,
                                   num_train_steps_per_epoch,
                                   num_test_steps_per_epoch,
                                   strategy)


if __name__ == '__main__':
  utils.nmt_flags()
  app.run(run_main)
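The script is normally launched through the absl flags that utils.nmt_flags() defines (that helper is elsewhere in this commit), but main() can also be called directly. A minimal sketch, where every argument value is an illustrative assumption (download_path in particular is a placeholder):

# Minimal sketch, not part of the commit; all values are assumptions.
from tensorflow_examples.models.nmt_with_attention import distributed_train

train_loss, test_loss = distributed_train.main(
    epochs=1,
    enable_function=True,
    buffer_size=5000,
    batch_size=64,
    download_path='/tmp/nmt',  # placeholder path for the dataset download
    num_examples=1000,
    num_gpu=1)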
@@ -0,0 +1,131 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Neural Machine Translation with Attention."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf  # TF2


class Encoder(tf.keras.Model):
  """Encoder.

  Args:
    vocab_size: Vocabulary size.
    embedding_dim: Embedding dimension.
    enc_units: Number of encoder units.
    batch_sz: Batch size.
  """

  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    super(Encoder, self).__init__()
    self.batch_sz = batch_sz
    self.enc_units = enc_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.enc_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')

  def call(self, x, hidden):
    x = self.embedding(x)
    output, state = self.gru(x, initial_state=hidden)
    return output, state

  def initialize_hidden_state(self):
    return tf.zeros((self.batch_sz, self.enc_units))


class BahdanauAttention(tf.keras.Model):
  """Bahdanau Attention.

  Args:
    units: Number of dense units.
  """

  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.w1 = tf.keras.layers.Dense(units)
    self.w2 = tf.keras.layers.Dense(units)
    self.v = tf.keras.layers.Dense(1)

  def call(self, query, values):
    # query (the decoder hidden state) has shape (batch_size, hidden_size);
    # add a time axis so it broadcasts against values when computing the score.
    hidden_with_time_axis = tf.expand_dims(query, 1)

    # score shape == (batch_size, max_length, 1); the last axis is 1 because
    # self.v has a single output unit.
    score = self.v(tf.nn.tanh(
        self.w1(values) + self.w2(hidden_with_time_axis)))

    # attention_weights shape == (batch_size, max_length, 1)
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = attention_weights * values
    context_vector = tf.reduce_sum(context_vector, axis=1)

    return context_vector, attention_weights
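As a reading aid (an added annotation, not part of the commit): for decoder state $s$ (query) and encoder outputs $h_j$ (values), BahdanauAttention.call computes the additive-attention score, weights, and context vector

$$ e_j = v^\top \tanh(W_1 h_j + W_2 s), \qquad \alpha_j = \mathrm{softmax}_j(e_j), \qquad c = \sum_j \alpha_j h_j, $$

where $W_1$, $W_2$, $v$ are the dense layers w1, w2, v, and the softmax runs over the source-time axis (axis=1).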
class Decoder(tf.keras.Model):
  """Decoder.

  Args:
    vocab_size: Vocabulary size.
    embedding_dim: Embedding dimension.
    dec_units: Number of decoder units.
    batch_sz: Batch size.
  """

  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    super(Decoder, self).__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.dec_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_size)

    # used for attention
    self.attention = BahdanauAttention(self.dec_units)

  def call(self, x, hidden, enc_output):
    # enc_output shape == (batch_size, max_length, hidden_size)
    context_vector, attention_weights = self.attention(hidden, enc_output)

    # x shape after passing through embedding == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # passing the concatenated vector to the GRU
    output, state = self.gru(x)

    # output shape == (batch_size * 1, hidden_size)
    output = tf.reshape(output, (-1, output.shape[2]))

    # x shape == (batch_size, vocab_size)
    x = self.fc(output)

    return x, state, attention_weights
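To make the documented shapes concrete, here is a minimal sketch (an illustration, not part of the commit) that runs one encoder pass and one decode step on random data; the nmt module path follows the imports used elsewhere in this commit, and all sizes are assumptions:

import tensorflow as tf
from tensorflow_examples.models.nmt_with_attention import nmt

batch_size, max_length = 16, 10
vocab_inp_size, vocab_tar_size = 1000, 1200  # illustrative vocabulary sizes
embedding_dim, units = 256, 1024

encoder = nmt.Encoder(vocab_inp_size, embedding_dim, units, batch_size)
decoder = nmt.Decoder(vocab_tar_size, embedding_dim, units, batch_size)

inp = tf.random.uniform((batch_size, max_length), maxval=vocab_inp_size,
                        dtype=tf.int32)
hidden = encoder.initialize_hidden_state()
enc_output, enc_hidden = encoder(inp, hidden)   # (16, 10, 1024), (16, 1024)

dec_input = tf.zeros((batch_size, 1), dtype=tf.int32)  # e.g. <start> token ids
logits, dec_hidden, attn = decoder(dec_input, enc_hidden, enc_output)
print(logits.shape)  # (16, 1200): next-token logits over the target vocabulary
print(attn.shape)    # (16, 10, 1): attention weights over source positions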
