
Commit ffc5800

Merge pull request #325 from rsepassi/push
v1.2.4

2 parents feb752c + 583356d, commit ffc5800

38 files changed: +2146 / -473 lines

.travis.yml

Lines changed: 15 additions & 2 deletions

@@ -8,9 +8,22 @@ before_install:
 install:
   - pip install tensorflow
   - pip install .[tests]
+env:
+  global:
+    - T2T_PROBLEM=algorithmic_reverse_binary40_test
+    - T2T_DATA_DIR=/tmp/t2t-data
+    - T2T_TRAIN_DIR=/tmp/t2t-train
 script:
-  - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/utils/trainer_utils_test.py --ignore=tensor2tensor/problems_test.py
+  - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/utils/trainer_utils_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py
   - pytest tensor2tensor/utils/registry_test.py
   - pytest tensor2tensor/utils/trainer_utils_test.py
+  - t2t-datagen 2>&1 | grep translate && echo passed
+  - python -c "from tensor2tensor.models import transformer; print(transformer.Transformer.__name__)"
+  - t2t-trainer --registry_help
+  - mkdir $T2T_DATA_DIR
+  - mkdir $T2T_TRAIN_DIR
+  - t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR
+  - t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR
+  - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR
 git:
-  depth: 3
+  depth: 3
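The new CI steps run a small end-to-end smoke test: generate data for the tiny registered problem, train transformer_tiny for 5 steps, and decode. Below is a rough local reproduction of that pipeline, sketched with subprocess; it assumes the t2t-datagen, t2t-trainer, and t2t-decoder console scripts are on PATH, and the paths, flags, and step counts are copied from the CI config above.

# Hedged sketch, not part of the commit: replays the CI smoke test locally.
import os
import subprocess

PROBLEM = "algorithmic_reverse_binary40_test"
DATA_DIR = "/tmp/t2t-data"
TRAIN_DIR = "/tmp/t2t-train"

for d in (DATA_DIR, TRAIN_DIR):
  if not os.path.exists(d):
    os.makedirs(d)

# Generate the tiny dataset, train transformer_tiny for 5 steps, then decode.
subprocess.check_call(
    ["t2t-datagen", "--problem=" + PROBLEM, "--data_dir=" + DATA_DIR])
subprocess.check_call([
    "t2t-trainer", "--problems=" + PROBLEM, "--data_dir=" + DATA_DIR,
    "--model=transformer", "--hparams_set=transformer_tiny",
    "--train_steps=5", "--eval_steps=5", "--output_dir=" + TRAIN_DIR])
subprocess.check_call([
    "t2t-decoder", "--problems=" + PROBLEM, "--data_dir=" + DATA_DIR,
    "--model=transformer", "--hparams_set=transformer_tiny",
    "--output_dir=" + TRAIN_DIR])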

setup.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.2.3',
+    version='1.2.4',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',

tensor2tensor/bin/t2t-decoder

Lines changed: 7 additions & 4 deletions

@@ -75,7 +75,7 @@ def main(_):

   hparams = trainer_utils.create_hparams(
       FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams)
-  hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
+  trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
   estimator, _ = trainer_utils.create_experiment_components(
       data_dir=data_dir,
       model_name=FLAGS.model,
@@ -90,9 +90,12 @@ def main(_):
     decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
                               FLAGS.decode_to_file)
   else:
-    decoding.decode_from_dataset(estimator,
-                                 FLAGS.problems.split("-"), decode_hp,
-                                 FLAGS.decode_to_file)
+    decoding.decode_from_dataset(
+        estimator,
+        FLAGS.problems.split("-"),
+        decode_hp,
+        decode_to_file=FLAGS.decode_to_file,
+        dataset_split="test" if FLAGS.eval_use_test_set else None)


 if __name__ == "__main__":

tensor2tensor/bin/t2t-trainer

Lines changed: 2 additions & 0 deletions

@@ -68,6 +68,8 @@ def main(_):
   trainer_utils.validate_flags()
   output_dir = os.path.expanduser(FLAGS.output_dir)
   tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
+  if not FLAGS.data_dir:
+    raise ValueError("You must specify a --data_dir")
   data_dir = os.path.expanduser(FLAGS.data_dir)
   tf.gfile.MakeDirs(output_dir)

tensor2tensor/data_generators/algorithmic.py

Lines changed: 36 additions & 18 deletions

@@ -62,13 +62,15 @@ def num_shards(self):
     return 10

   def generate_data(self, data_dir, _, task_id=-1):
+
     def generator_eos(nbr_symbols, max_length, nbr_cases):
       """Shift by NUM_RESERVED_IDS and append EOS token."""
       for case in self.generator(nbr_symbols, max_length, nbr_cases):
         new_case = {}
         for feature in case:
-          new_case[feature] = [i + text_encoder.NUM_RESERVED_TOKENS
-                               for i in case[feature]] + [text_encoder.EOS_ID]
+          new_case[feature] = [
+              i + text_encoder.NUM_RESERVED_TOKENS for i in case[feature]
+          ] + [text_encoder.EOS_ID]
         yield new_case

     utils.generate_dataset_and_shuffle(
@@ -154,10 +156,7 @@ def generator(self, nbr_symbols, max_length, nbr_cases):
     for _ in xrange(nbr_cases):
       l = np.random.randint(max_length) + 1
       inputs = [np.random.randint(nbr_symbols - shift) for _ in xrange(l)]
-      yield {
-          "inputs": inputs,
-          "targets": [i + shift for i in inputs]
-      }
+      yield {"inputs": inputs, "targets": [i + shift for i in inputs]}

   @property
   def dev_length(self):
@@ -191,10 +190,7 @@ def generator(self, nbr_symbols, max_length, nbr_cases):
     for _ in xrange(nbr_cases):
       l = np.random.randint(max_length) + 1
       inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)]
-      yield {
-          "inputs": inputs,
-          "targets": list(reversed(inputs))
-      }
+      yield {"inputs": inputs, "targets": list(reversed(inputs))}


 @registry.register_problem
@@ -272,10 +268,7 @@ def reverse_generator_nlplike(nbr_symbols,
   for _ in xrange(nbr_cases):
     l = int(abs(np.random.normal(loc=max_length / 2, scale=std_dev)) + 1)
     inputs = zipf_random_sample(distr_map, l)
-    yield {
-        "inputs": inputs,
-        "targets": list(reversed(inputs))
-    }
+    yield {"inputs": inputs, "targets": list(reversed(inputs))}


 @registry.register_problem
@@ -287,8 +280,8 @@ def num_symbols(self):
     return 8000

   def generator(self, nbr_symbols, max_length, nbr_cases):
-    return reverse_generator_nlplike(
-        nbr_symbols, max_length, nbr_cases, 10, 1.300)
+    return reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases, 10,
+                                     1.300)

   @property
   def train_length(self):
@@ -308,8 +301,8 @@ def num_symbols(self):
     return 32000

   def generator(self, nbr_symbols, max_length, nbr_cases):
-    return reverse_generator_nlplike(
-        nbr_symbols, max_length, nbr_cases, 10, 1.050)
+    return reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases, 10,
+                                     1.050)


 def lower_endian_to_number(l, base):
@@ -431,3 +424,28 @@ class AlgorithmicMultiplicationDecimal40(AlgorithmicMultiplicationBinary40):
   @property
   def num_symbols(self):
     return 10
+
+
+@registry.register_problem
+class AlgorithmicReverseBinary40Test(AlgorithmicReverseBinary40):
+  """Test Problem with tiny dataset."""
+
+  @property
+  def train_length(self):
+    return 10
+
+  @property
+  def dev_length(self):
+    return 10
+
+  @property
+  def train_size(self):
+    return 1000
+
+  @property
+  def dev_size(self):
+    return 100
+
+  @property
+  def num_shards(self):
+    return 1
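The new AlgorithmicReverseBinary40Test problem is what the CI's T2T_PROBLEM variable points at: it inherits the reverse-binary task but shrinks lengths, sizes, and shard count so datagen and training finish in seconds. A minimal sketch of pulling cases directly from its generator, assuming the generator(nbr_symbols, max_length, nbr_cases) signature and the "inputs"/"targets" keys shown in the hunks above, and that num_symbols is inherited from AlgorithmicReverseBinary40:

# Hedged sketch, not part of the commit: draw a few cases from the tiny problem.
from tensor2tensor.data_generators import algorithmic

problem = algorithmic.AlgorithmicReverseBinary40Test()
for case in problem.generator(problem.num_symbols, problem.train_length, 5):
  # Each case is a dict of symbol ids; targets are the reversed inputs.
  print(case["inputs"], case["targets"])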
tensor2tensor/data_generators/all_problems_test.py (new file)

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Tensor2Tensor's all_problems.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+from tensor2tensor.data_generators import all_problems
+
+import tensorflow as tf
+
+
+class AllProblemsTest(tf.test.TestCase):
+
+  def testImport(self):
+    """Make sure that importing all_problems doesn't break."""
+    self.assertIsNotNone(all_problems)
+
+
+if __name__ == '__main__':
+  tf.test.main()

tensor2tensor/data_generators/image.py

Lines changed: 1 addition & 0 deletions

@@ -650,6 +650,7 @@ def generator(self, data_dir, tmp_dir, is_training):
 class ImageCifar10Plain(ImageCifar10):

   def preprocess_example(self, example, mode, unused_hparams):
+    example["inputs"] = tf.to_int64(example["inputs"])
     return example

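The single added line casts the plain (unresized) CIFAR-10 pixels to int64. A plausible reading, which is an assumption rather than something stated in the diff, is that these raw pixel values are consumed as symbol ids downstream, and integer indices rather than decoded uint8 values are what an embedding lookup expects; roughly:

# Hedged illustration, not part of the commit; uses TF 1.x ops as in the diff.
import tensorflow as tf

pixels = tf.constant([[255, 0, 128]], dtype=tf.uint8)  # decoded image values
ids = tf.to_int64(pixels)  # same values, now usable as integer symbol ids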

tensor2tensor/data_generators/lm1b.py

Lines changed: 30 additions & 28 deletions

@@ -36,7 +36,6 @@

 import tensorflow as tf

-
 # End-of-sentence marker (should correspond to the position of EOS in the
 # RESERVED_TOKENS list in text_encoder.py)
 EOS = 1
@@ -59,9 +58,10 @@ def _original_vocab(tmp_dir):
   vocab_filepath = os.path.join(tmp_dir, vocab_filename)
   if not os.path.exists(vocab_filepath):
     generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url)
-  return set(
-      [text_encoder.native_to_unicode(l.strip()) for l in
-       tf.gfile.Open(vocab_filepath)])
+  return set([
+      text_encoder.native_to_unicode(l.strip())
+      for l in tf.gfile.Open(vocab_filepath)
+  ])


 def _replace_oov(original_vocab, line):
@@ -81,19 +81,19 @@ def _replace_oov(original_vocab, line):


 def _train_data_filenames(tmp_dir):
-  return [os.path.join(
-      tmp_dir,
-      "1-billion-word-language-modeling-benchmark-r13output",
-      "training-monolingual.tokenized.shuffled",
-      "news.en-%05d-of-00100" % i) for i in xrange(1, 100)]
+  return [
+      os.path.join(tmp_dir,
+                   "1-billion-word-language-modeling-benchmark-r13output",
+                   "training-monolingual.tokenized.shuffled",
+                   "news.en-%05d-of-00100" % i) for i in xrange(1, 100)
+  ]


 def _dev_data_filename(tmp_dir):
-  return os.path.join(
-      tmp_dir,
-      "1-billion-word-language-modeling-benchmark-r13output",
-      "heldout-monolingual.tokenized.shuffled",
-      "news.en.heldout-00000-of-00050")
+  return os.path.join(tmp_dir,
+                      "1-billion-word-language-modeling-benchmark-r13output",
+                      "heldout-monolingual.tokenized.shuffled",
+                      "news.en.heldout-00000-of-00050")


 def _maybe_download_corpus(tmp_dir):
@@ -112,17 +112,18 @@ def _maybe_download_corpus(tmp_dir):
     corpus_tar.extractall(tmp_dir)


-def _get_or_build_subword_text_encoder(tmp_dir):
+def _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath):
   """Builds a SubwordTextEncoder based on the corpus.

   Args:
     tmp_dir: directory containing dataset.
+    vocab_filepath: path to store (or load) vocab.
+
   Returns:
     a SubwordTextEncoder.
   """
-  filepath = os.path.join(tmp_dir, "lm1b_32k.subword_text_encoder")
-  if tf.gfile.Exists(filepath):
-    return text_encoder.SubwordTextEncoder(filepath)
+  if tf.gfile.Exists(vocab_filepath):
+    return text_encoder.SubwordTextEncoder(vocab_filepath)
   _maybe_download_corpus(tmp_dir)
   original_vocab = _original_vocab(tmp_dir)
   token_counts = defaultdict(int)
@@ -138,7 +139,7 @@ def _get_or_build_subword_text_encoder(tmp_dir):
       break
   ret = text_encoder.SubwordTextEncoder()
   ret.build_from_token_counts(token_counts, min_count=5)
-  ret.store_to_file(filepath)
+  ret.store_to_file(vocab_filepath)
   return ret


@@ -152,7 +153,7 @@ def is_character_level(self):

   @property
   def has_inputs(self):
-    return True
+    return False

   @property
   def input_space_id(self):
@@ -184,25 +185,26 @@ def targeted_vocab_size(self):
   def use_train_shards_for_dev(self):
     return True

-  def generator(self, tmp_dir, train, characters=False):
+  def generator(self, data_dir, tmp_dir, is_training):
     """Generator for lm1b sentences.

     Args:
-      tmp_dir: a string.
-      train: a boolean.
-      characters: a boolean
+      data_dir: data dir.
+      tmp_dir: tmp dir.
+      is_training: a boolean.

     Yields:
       A dictionary {"inputs": [0], "targets": [<subword ids>]}
     """
     _maybe_download_corpus(tmp_dir)
     original_vocab = _original_vocab(tmp_dir)
-    files = (_train_data_filenames(tmp_dir) if train
-             else [_dev_data_filename(tmp_dir)])
-    if characters:
+    files = (_train_data_filenames(tmp_dir)
+             if is_training else [_dev_data_filename(tmp_dir)])
+    if self.is_character_level:
       encoder = text_encoder.ByteTextEncoder()
     else:
-      encoder = _get_or_build_subword_text_encoder(tmp_dir)
+      vocab_filepath = os.path.join(data_dir, self.vocab_file)
+      encoder = _get_or_build_subword_text_encoder(tmp_dir, vocab_filepath)
     for filepath in files:
       tf.logging.info("filepath = %s", filepath)
       for line in tf.gfile.Open(filepath):
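The lm1b changes move the subword vocab from a hard-coded file under tmp_dir to a caller-supplied vocab_filepath (built from data_dir and the problem's vocab_file property), and switch the generator to the (data_dir, tmp_dir, is_training) signature. A minimal sketch of the new get-or-build call, with illustrative paths and the old hard-coded filename reused purely as an example name:

# Hedged sketch, not part of the commit: load the subword vocab from
# vocab_filepath if it exists, otherwise build it from the corpus and store it
# there. Building it downloads and scans the full LM1B corpus, so the first
# run is slow.
import os
from tensor2tensor.data_generators import lm1b

tmp_dir = "/tmp/t2t-tmp"    # where the raw corpus is downloaded
data_dir = "/tmp/t2t-data"  # where generated data and the vocab live
vocab_filepath = os.path.join(data_dir, "lm1b_32k.subword_text_encoder")

encoder = lm1b._get_or_build_subword_text_encoder(tmp_dir, vocab_filepath)
print(encoder.vocab_size)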
