@@ -183,7 +183,7 @@ def _get_tid(node):
183183 return 0
184184
185185
186- def get_repr (arg ):
186+ def get_repr (arg , mode = "full" ):
187187 def get_dtype_repr (dtype ):
188188 return dtype_abbrs [dtype ]
189189
@@ -200,20 +200,20 @@ def get_dtype_repr(dtype):
200200 return get_dtype_repr (arg )
201201
202202 if isinstance (arg , torch .fx .Node ):
203- if "val" not in arg .meta :
204- return f"fx node { arg } "
205-
206- return get_repr (arg .meta ["val" ])
203+ if mode == "name_only" or "val" not in arg .meta :
204+ return f"fx node { arg .name } "
205+ elif mode == "full" :
206+ return {"name" : arg .name , "data" : get_repr (arg .meta ["val" ])}
207+ elif mode == "content_only" :
208+ return get_repr (arg .meta ["val" ])
209+ else :
210+ raise ValueError (f"Unknown mode { mode } " )
207211
208212 if isinstance (arg , (list , tuple )):
209- # TODO: make better repr that don't blow up
210- # for long lists
211- return [get_repr (x ) for x in arg ]
213+ return [get_repr (x , mode = "name_only" ) for x in arg ]
212214
213215 if isinstance (arg , dict ):
214- # TODO: make better repr that don't blow up
215- # for long lists
216- return {k : get_repr (v ) for k , v in arg .items ()}
216+ return {k : get_repr (v , mode = "name_only" ) for k , v in arg .items ()}
217217
218218 return f"arg { type (arg )} "
219219
@@ -239,7 +239,7 @@ def create_execution_trace(
239239 curr_time [tid ] = curr_time [0 ]
240240 event = {"ph" : "X" , "cat" : "kernel" , "name" : str (node ), "pid" : 0 , "tid" : tid }
241241 if _is_communication_node (node ):
242- if tid == 0 and is_wait_tensor (node ):
242+ if tid == 0 and is_wait_tensor (node ) and node . args [ 0 ]. op != "placeholder" :
243243 # if it's wait tensor, let's sync with compute stream
244244 comm_end_time = global_time .pop (node .args [0 ])
245245 curr_time [tid ] = max (curr_time [tid ], comm_end_time )
@@ -258,7 +258,7 @@ def create_execution_trace(
258258 args : dict [str , Any ] = {}
259259 args ["order" ] = node_idx
260260
261- args ["output" ] = get_repr (node )
261+ args ["output" ] = get_repr (node , mode = "content_only" )
262262 node_args = []
263263 for arg in node .args :
264264 node_args .append (get_repr (arg ))
0 commit comments