Skip to content

Commit c3fd25b

Browse files
authored
Improve repr for execution trace (#270)
Also handle wait_tensor nodes whose inputs are placeholder nodes
1 parent aa4d15e commit c3fd25b

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

autoparallel/debug_helpers.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ def _get_tid(node):
183183
return 0
184184

185185

186-
def get_repr(arg):
186+
def get_repr(arg, mode="full"):
187187
def get_dtype_repr(dtype):
188188
return dtype_abbrs[dtype]
189189

@@ -200,20 +200,20 @@ def get_dtype_repr(dtype):
200200
return get_dtype_repr(arg)
201201

202202
if isinstance(arg, torch.fx.Node):
203-
if "val" not in arg.meta:
204-
return f"fx node {arg}"
205-
206-
return get_repr(arg.meta["val"])
203+
if mode == "name_only" or "val" not in arg.meta:
204+
return f"fx node {arg.name}"
205+
elif mode == "full":
206+
return {"name": arg.name, "data": get_repr(arg.meta["val"])}
207+
elif mode == "content_only":
208+
return get_repr(arg.meta["val"])
209+
else:
210+
raise ValueError(f"Unknown mode {mode}")
207211

208212
if isinstance(arg, (list, tuple)):
209-
# TODO: make better repr that don't blow up
210-
# for long lists
211-
return [get_repr(x) for x in arg]
213+
return [get_repr(x, mode="name_only") for x in arg]
212214

213215
if isinstance(arg, dict):
214-
# TODO: make better repr that don't blow up
215-
# for long lists
216-
return {k: get_repr(v) for k, v in arg.items()}
216+
return {k: get_repr(v, mode="name_only") for k, v in arg.items()}
217217

218218
return f"arg {type(arg)}"
219219

@@ -239,7 +239,7 @@ def create_execution_trace(
239239
curr_time[tid] = curr_time[0]
240240
event = {"ph": "X", "cat": "kernel", "name": str(node), "pid": 0, "tid": tid}
241241
if _is_communication_node(node):
242-
if tid == 0 and is_wait_tensor(node):
242+
if tid == 0 and is_wait_tensor(node) and node.args[0].op != "placeholder":
243243
# if it's wait tensor, let's sync with compute stream
244244
comm_end_time = global_time.pop(node.args[0])
245245
curr_time[tid] = max(curr_time[tid], comm_end_time)
@@ -258,7 +258,7 @@ def create_execution_trace(
258258
args: dict[str, Any] = {}
259259
args["order"] = node_idx
260260

261-
args["output"] = get_repr(node)
261+
args["output"] = get_repr(node, mode="content_only")
262262
node_args = []
263263
for arg in node.args:
264264
node_args.append(get_repr(arg))

0 commit comments

Comments
 (0)