Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 9625f34

Browse files
author
Mesh TensorFlow Team
committed
This quick change folds the additional pf dimension of the topology into the 1st Mesh TensorFlow dimension.
Folding the additional pf dimension into either the 1st or the 2nd Mesh TensorFlow dimension is equivalent w.r.t. examples and global_step per sec. PiperOrigin-RevId: 355756187
1 parent 0e7ce98 commit 9625f34

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

mesh_tensorflow/simd_mesh_impl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,9 +826,9 @@ def physical_shape_3d_from_topology_proto_4d(mesh_shape):
826826
Returns:
827827
a list of length 3
828828
"""
829-
if len(mesh_shape) != 4 or mesh_shape[2] != 1:
830-
raise ValueError("Expected a 4d shape [x, y, 1, core]")
831-
return [mesh_shape[1], mesh_shape[0], mesh_shape[3]]
829+
if len(mesh_shape) != 4:
830+
raise ValueError("Expected a 4d shape [x, y, z, core]")
831+
return [mesh_shape[1]*mesh_shape[2], mesh_shape[0], mesh_shape[3]]
832832

833833

834834
def auto_logical_to_physical_tpu(logical_shape,

mesh_tensorflow/transformer/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ def tpu_mesh_shape(tpu_topology=gin.REQUIRED,
255255
if tpu_topology.startswith("v"):
256256
num_cores = int(tpu_topology.split("-")[-1])
257257
else:
258-
x, y = tpu_topology.split("x")
259-
num_cores = int(x) * int(y) * 2
258+
tpu_dim = [int(x) for x in tpu_topology.split("x")]
259+
num_cores = functools.reduce(lambda x, y: x*y, tpu_dim) * 2
260260
if isinstance(model_parallelism, list):
261261
# model_parallelism is actually a spec used to
262262
# construct a simd_mesh_impl.HierarchicalTiling object

0 commit comments

Comments
 (0)