Skip to content

Commit 06eff81

Browse files
committed
Update debug_util to update index when calling sync
1 parent b59160d commit 06eff81

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

test/debug_tool/test_pt_xla_debug.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,18 @@ def toy_program(t1):
177177

178178
if self.debug_level > 1:
179179
self.assertEqual(len(executation_causes), 2)
180+
self.assertIn(
181+
'torch_xla.compile clear the pending graph prior calling the target function',
182+
executation_causes[0])
183+
self.assertIn('torch_xla.compile\n', executation_causes[1])
180184
else:
181185
self.assertEqual(len(executation_causes), 0)
182186

183187
self.assertEqual(len(compilation_causes), 2)
188+
self.assertIn(
189+
'torch_xla.compile clear the pending graph prior calling the target function',
190+
compilation_causes[0])
191+
self.assertIn('torch_xla.compile\n', compilation_causes[1])
184192

185193
if self.debug_level > 1:
186194
# one graph info from compilation, one from execution, hash should match

torch_xla/csrc/debug_util.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -327,24 +327,27 @@ void DebugUtil::analyze_graph_execution_python_frame(
327327
} else if (frames[0].function == "mark_step" ||
328328
(frames[0].function == "sync" &&
329329
endsWith(frames[0].file, "torch_xla.py"))) {
330-
if (frames[1].function == "next" &&
331-
endsWith(frames[1].file, "parallel_loader.py")) {
330+
bool called_by_sync =
331+
frames[0].function == "mark_step" && frames[1].function == "sync";
332+
int i = called_by_sync ? 2 : 1;
333+
if (frames[i].function == "next" &&
334+
endsWith(frames[i].file, "parallel_loader.py")) {
332335
ss << debug_output_prefix
333336
<< " mark_step in parallel loader at step end\n";
334-
} else if (frames[1].function == "__exit__" &&
335-
endsWith(frames[1].file, "profiler.py")) {
337+
} else if (frames[i].function == "__exit__" &&
338+
endsWith(frames[i].file, "profiler.py")) {
336339
ss << debug_output_prefix
337340
<< " mark_step when exiting a profiler StepTrace region\n";
338-
} else if ((frames[1].function == "extract_compiled_graph_helper" ||
339-
frames[1].function == "extract_internal") &&
340-
endsWith(frames[1].file, "dynamo_bridge.py")) {
341+
} else if ((frames[i].function == "extract_compiled_graph_helper" ||
342+
frames[i].function == "extract_internal") &&
343+
endsWith(frames[i].file, "dynamo_bridge.py")) {
341344
ss << debug_output_prefix
342345
<< " mark_step when dynamo processing input graphs\n";
343-
} else if (frames[1].function == "_compile" &&
344-
endsWith(frames[1].file, "torch_xla.py")) {
346+
} else if (frames[i].function == "_compile" &&
347+
endsWith(frames[i].file, "torch_xla.py")) {
345348
ss << debug_output_prefix << " torch_xla.compile\n";
346-
} else if (frames[1].function == "_clear_pending_ops_before_compile" &&
347-
endsWith(frames[1].file, "torch_xla.py")) {
349+
} else if (frames[i].function == "_clear_pending_ops_before_compile" &&
350+
endsWith(frames[i].file, "torch_xla.py")) {
348351
ss << debug_output_prefix
349352
<< " torch_xla.compile clear the pending graph prior calling the "
350353
"target function\n";

0 commit comments

Comments
 (0)