Skip to content

Commit cd5528e

Browse files
ycui1984facebook-github-bot
authored andcommitted
Use device time to measure latency instead of cpu
Summary: As titled Reviewed By: cenzhaometa Differential Revision: D71828892 fbshipit-source-id: 3f02962f3aa8c035b5d4c38c33954d6bbebfe362
1 parent 9c8907f commit cd5528e

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

train/comms/pt/comms.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,9 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
368368
self.backendFuncs.sync_barrier(
369369
self.collectiveArgs, desc="run_coll_cuda_graph_begin"
370370
)
371-
elapsedTimeNS = 0.0
371+
elapsedCPUTimeNS = 0.0
372+
start_event = self.backendFuncs.create_event(self.collectiveArgs)
373+
end_event = self.backendFuncs.create_event(self.collectiveArgs)
372374

373375
# 1. Warmup phase
374376
# launch collective on a separate stream and sync with current_stream
@@ -393,21 +395,34 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
393395

394396
# 3. Replay
395397
start = time.monotonic() # available only in py3
398+
self.backendFuncs.record_event(start_event, self.collectiveArgs)
396399
for _ in range(self.collectiveArgs.graph_launches):
397400
if self.collectiveArgs.enable_profiler:
398401
comms_utils.sampleProfiler()
399402

400403
# [optional] we can feed new input data to ipTensor for each replay
401404
g.replay()
402405

406+
self.backendFuncs.record_event(end_event, self.collectiveArgs)
403407
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
408+
404409
end = time.monotonic() # available only in py3
405410

406411
ensureTensorFlush(self.collectiveArgs.opTensor)
407412

408-
elapsedTimeNS += (
413+
elapsedCPUTimeNS += (
409414
end - start
410415
) * 1e9 # keeping time in NS, helps in divising data by nanoseconds
416+
elapsedDeviceTimeMs = self.backendFuncs.elapsed_time(start_event, end_event)
417+
elapsedDeviceTimeNS = elapsedDeviceTimeMs * 1e6
418+
elapsedTimeNS = (
419+
elapsedDeviceTimeNS
420+
if self.collectiveArgs.use_device_time
421+
else elapsedCPUTimeNS
422+
)
423+
logger.debug(
424+
f"elapsedCPUTimeNS={elapsedCPUTimeNS}, elapsedDeviceTimeNS={elapsedDeviceTimeNS}."
425+
)
411426

412427
memSize = self.backendFuncs.get_mem_size(self.collectiveArgs)
413428

0 commit comments

Comments
 (0)