@@ -882,7 +882,7 @@ def _shard_features(self, features):  # pylint: disable=missing-docstring
         v = tf.expand_dims(v, axis=-1)
         v_shape = [1]
       if v_shape == [1]:
-        v = tf.tile(v, [self._num_datashards])
+        v = tf.tile(v, tf.to_int32([self._num_datashards]))
       sharded_features[k] = self._data_parallelism(
           tf.identity, tf.split(v, self._num_datashards, 0))
     return sharded_features
@@ -1288,17 +1288,17 @@ def _create_host_call(model_dir):
   graph = tf.get_default_graph()
   summaries = graph.get_collection(tf.GraphKeys.SUMMARIES)
 
-  gs_t = tf.reshape(tf.train.get_global_step(), [1])
+  gs_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
   summary_kwargs = collections.OrderedDict()
   for t in summaries:
     if t.op.type != "ScalarSummary":
       continue
 
     name = t.op.name
     tensor = t.op.inputs[1]
-    assert tensor.shape.is_compatible_with(
-        []), ("ScalarSummary %s must have shape [], but is: %s." %
-              (name, tensor.shape))
+    assert tensor.shape.is_compatible_with([])
+    if tensor.dtype == tf.int64:
+      tensor = tf.to_int32(tensor)
     summary_kwargs[name] = tf.reshape(tensor, [1])
   summary_kwargs["global_step"] = gs_t
 
@@ -1312,7 +1312,7 @@ def host_call_fn(**kwargs):
     Returns:
       List of summary ops to run on the CPU host.
     """
-    gs = kwargs.pop("global_step")[0]
+    gs = tf.to_int64(kwargs.pop("global_step")[0])
     with tf.contrib.summary.create_file_writer(model_dir).as_default():
       with tf.contrib.summary.always_record_summaries():
         for name, value in sorted(six.iteritems(kwargs)):
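For context, the casts in this change follow one pattern: TPUEstimator's host_call moves tensors from the TPU to the CPU through the outfeed, which supports only a limited set of dtypes, so int64 scalars such as the global step are narrowed to int32 on the device side and widened back to int64 on the host, where tf.contrib.summary expects an int64 step. Below is a minimal sketch of that pattern, assuming TF 1.x and the TPUEstimator host_call interface; the helper name build_host_call and its loss argument are illustrative and not part of this change.

# Sketch (TF 1.x) of the int32/int64 casting pattern used around host_call.
import tensorflow as tf


def build_host_call(model_dir, loss):
  """Returns a (fn, args) pair usable as the host_call of a TPUEstimatorSpec."""
  # Cast the int64 global step to int32 so it can cross the TPU outfeed;
  # host_call tensors also need a leading dimension, hence the reshape to [1].
  gs_t = tf.reshape(tf.to_int32(tf.train.get_or_create_global_step()), [1])
  loss_t = tf.reshape(loss, [1])

  def host_call_fn(global_step, loss):
    # Widen back to int64 on the CPU host for tf.contrib.summary's step.
    gs = tf.to_int64(global_step[0])
    with tf.contrib.summary.create_file_writer(model_dir).as_default():
      with tf.contrib.summary.always_record_summaries():
        tf.contrib.summary.scalar("loss", loss[0], step=gs)
        return tf.contrib.summary.all_summary_ops()

  return host_call_fn, [gs_t, loss_t]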