Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3d0e1ef

Browse files
author
Ryan Sepassi
committed
Switch int64 to int32 in a few places. TF 1.6 TPU support for int64 is spotty.
PiperOrigin-RevId: 197472099
1 parent 56643b1 commit 3d0e1ef

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tensor2tensor/utils/t2t_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,7 @@ def _shard_features(self, features): # pylint: disable=missing-docstring
882882
v = tf.expand_dims(v, axis=-1)
883883
v_shape = [1]
884884
if v_shape == [1]:
885-
v = tf.tile(v, [self._num_datashards])
885+
v = tf.tile(v, tf.to_int32([self._num_datashards]))
886886
sharded_features[k] = self._data_parallelism(
887887
tf.identity, tf.split(v, self._num_datashards, 0))
888888
return sharded_features
@@ -1288,17 +1288,17 @@ def _create_host_call(model_dir):
12881288
graph = tf.get_default_graph()
12891289
summaries = graph.get_collection(tf.GraphKeys.SUMMARIES)
12901290

1291-
gs_t = tf.reshape(tf.train.get_global_step(), [1])
1291+
gs_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
12921292
summary_kwargs = collections.OrderedDict()
12931293
for t in summaries:
12941294
if t.op.type != "ScalarSummary":
12951295
continue
12961296

12971297
name = t.op.name
12981298
tensor = t.op.inputs[1]
1299-
assert tensor.shape.is_compatible_with(
1300-
[]), ("ScalarSummary %s must have shape [], but is: %s." %
1301-
(name, tensor.shape))
1299+
assert tensor.shape.is_compatible_with([])
1300+
if tensor.dtype == tf.int64:
1301+
tensor = tf.to_int32(tensor)
13021302
summary_kwargs[name] = tf.reshape(tensor, [1])
13031303
summary_kwargs["global_step"] = gs_t
13041304

@@ -1312,7 +1312,7 @@ def host_call_fn(**kwargs):
13121312
Returns:
13131313
List of summary ops to run on the CPU host.
13141314
"""
1315-
gs = kwargs.pop("global_step")[0]
1315+
gs = tf.to_int64(kwargs.pop("global_step")[0])
13161316
with tf.contrib.summary.create_file_writer(model_dir).as_default():
13171317
with tf.contrib.summary.always_record_summaries():
13181318
for name, value in sorted(six.iteritems(kwargs)):

0 commit comments

Comments
 (0)