@@ -183,12 +183,6 @@ def readArgs(self, parser):
183183 default = False ,
184184 help = "Select some ranks to send/receive 0B messages" ,
185185 )
186- parser .add_argument (
187- "--use-device-time" ,
188- action = "store_true" ,
189- default = False ,
190- help = "use device time measurement" ,
191- )
192186 parser .add_argument (
193187 "--graph-launches" ,
194188 type = int ,
@@ -368,7 +362,9 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
368362 self .backendFuncs .sync_barrier (
369363 self .collectiveArgs , desc = "run_coll_cuda_graph_begin"
370364 )
371- elapsedTimeNS = 0.0
365+ elapsedCPUTimeNS = 0.0
366+ start_event = self .backendFuncs .create_event (self .collectiveArgs )
367+ end_event = self .backendFuncs .create_event (self .collectiveArgs )
372368
373369 # 1. Warmup phase
374370 # launch collective on a separate stream and sync with current_stream
@@ -393,21 +389,34 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
393389
394390 # 3. Replay
395391 start = time .monotonic () # available only in py3
392+ self .backendFuncs .record_event (start_event , self .collectiveArgs )
396393 for _ in range (self .collectiveArgs .graph_launches ):
397394 if self .collectiveArgs .enable_profiler :
398395 comms_utils .sampleProfiler ()
399396
400397 # [optional] we can feed new input data to ipTensor for each replay
401398 g .replay ()
402399
400+ self .backendFuncs .record_event (end_event , self .collectiveArgs )
403401 self .backendFuncs .complete_accel_ops (self .collectiveArgs )
402+
404403 end = time .monotonic () # available only in py3
405404
406405 ensureTensorFlush (self .collectiveArgs .opTensor )
407406
408- elapsedTimeNS += (
407+ elapsedCPUTimeNS += (
409408 end - start
410409 ) * 1e9 # keeping time in NS, helps in divising data by nanoseconds
410+ elapsedDeviceTimeMs = self .backendFuncs .elapsed_time (start_event , end_event )
411+ elapsedDeviceTimeNS = elapsedDeviceTimeMs * 1e6
412+ elapsedTimeNS = (
413+ elapsedDeviceTimeNS
414+ if self .collectiveArgs .use_device_time
415+ else elapsedCPUTimeNS
416+ )
417+ logger .debug (
418+ f"elapsedCPUTimeNS={ elapsedCPUTimeNS } , elapsedDeviceTimeNS={ elapsedDeviceTimeNS } ."
419+ )
411420
412421 memSize = self .backendFuncs .get_mem_size (self .collectiveArgs )
413422
@@ -436,17 +445,11 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
436445 }
437446 return results
438447
439- def runColl (self , comm_fn = None , dcheck = False ):
440- if self .collectiveArgs .graph_launches > 0 :
441- return self .run_coll_cuda_graph (comm_fn , dcheck )
448+ def run_coll_non_graph (self , comm_fn = None , dcheck = False ):
442449 self .backendFuncs .sync_barrier (self .collectiveArgs , desc = "runColl_begin" )
443450
444- elapsedCPUTimeNS = 0.0
445- elapsedDeviceTimeNS = 0.0
451+ elapsedTimeNS = 0.0
446452 is_blocking = not self .collectiveArgs .asyncOp
447- # Initialize CUDA events for device timing
448- start_event = self .backendFuncs .create_event (self .collectiveArgs )
449- end_event = self .backendFuncs .create_event (self .collectiveArgs )
450453
451454 for nIter in range (
452455 self .collectiveArgs .numWarmupIters + self .collectiveArgs .numIters
@@ -458,22 +461,16 @@ def runColl(self, comm_fn=None, dcheck=False):
458461 self .backendFuncs .complete_accel_ops (self .collectiveArgs )
459462 ensureTensorFlush (self .collectiveArgs .opTensor )
460463 # Start measuring time after warmup iterations
461- elapsedCPUTimeNS = 0.0
462- elapsedDeviceTimeNS = 0.0
464+ elapsedTimeNS = 0.0
463465 self .collectiveArgs .quant_time .reset ()
464466 self .collectiveArgs .dequant_time .reset ()
465- self .backendFuncs .record_event (
466- start_event , self .collectiveArgs
467- ) # record start event for non-blocking operation
468467 # reset tensor values for data validation check
469468 if dcheck :
470469 self .setTensorVal (self .collectiveArgs .opTensor )
471470 # for blocking mode, do barrier before starting collective
472471 if is_blocking :
473472 self .backendFuncs .sync_barrier (self .collectiveArgs )
474- self .backendFuncs .record_event (
475- start_event , self .collectiveArgs
476- ) # record start event for blocking operation
473+
477474 start = time .monotonic () # available only in py3
478475 with paramStreamGuard (
479476 stream = self .backendFuncs .get_current_stream (
@@ -488,47 +485,29 @@ def runColl(self, comm_fn=None, dcheck=False):
488485 ]
489486 for _ in range (self .collectiveArgs .numCollPerIter ):
490487 comm_fn (self .collectiveArgs )
488+
491489 if is_blocking : # should be sychronous, wait for the collective
492- self .backendFuncs .record_event (
493- end_event , self .collectiveArgs
494- ) # record end event for blocking operation
495490 self .backendFuncs .complete_accel_ops (self .collectiveArgs )
496- elapsedDeviceTimeMs = self .backendFuncs .elapsed_time (
497- start_event , end_event
498- )
499- elapsedDeviceTimeNS += elapsedDeviceTimeMs * 1e6 # Convert ms to ns
491+
500492 # Measuring time.
501- elapsedCPUTimeNS += (
493+ elapsedTimeNS += (
502494 time .monotonic () - start
503495 ) * 1e9 # keeping time in NS, helps in divising data by nanosecond
496+
504497 start = time .monotonic () # available only in py3
505- # if not blocking, record second end event here
506- if not is_blocking :
507- self .backendFuncs .record_event (
508- end_event , self .collectiveArgs
509- ) # record end event for non-blocking operations
510498 self .backendFuncs .complete_accel_ops (self .collectiveArgs )
511499 end = time .monotonic () # available only in py3
500+
512501 ensureTensorFlush (self .collectiveArgs .opTensor )
513502
514- elapsedCPUTimeNS += (
503+ elapsedTimeNS += (
515504 end - start
516505 ) * 1e9 # keeping time in NS, helps in divising data by nanoseconds
517- if not is_blocking :
518- elapsedDeviceTimeMs = self .backendFuncs .elapsed_time (start_event , end_event )
519- elapsedDeviceTimeNS = elapsedDeviceTimeMs * 1e6 # Convert ms to ns
520506
521507 memSize = self .backendFuncs .get_mem_size (self .collectiveArgs )
522- logger .debug (
523- f"elapsedCPUTimeNS={ elapsedCPUTimeNS } , elapsedDeviceTimeNS={ elapsedDeviceTimeNS } ."
524- )
525- ElapsedTimeNS = (
526- elapsedDeviceTimeNS
527- if self .collectiveArgs .use_device_time
528- else elapsedCPUTimeNS
529- )
508+
530509 avgIterNS , algBW = comms_utils .getAlgBW (
531- ElapsedTimeNS ,
510+ elapsedTimeNS ,
532511 memSize ,
533512 self .collectiveArgs .numIters * self .collectiveArgs .numCollPerIter ,
534513 )
@@ -550,6 +529,13 @@ def runColl(self, comm_fn=None, dcheck=False):
550529 }
551530 return results
552531
def runColl(self, comm_fn=None, dcheck=False):
    """Run the collective benchmark via the CUDA-graph or the regular path.

    The CUDA-graph runner is selected only when
    ``self.collectiveArgs.graph_launches`` is strictly positive. That runner
    replays the captured graph ``range(graph_launches)`` times, so a zero or
    negative count would replay nothing and report timing for no work; such
    values are routed to the non-graph runner instead.

    Args:
        comm_fn: Collective callable, forwarded unchanged to the selected
            runner.
        dcheck: Data-validation flag, forwarded unchanged to the selected
            runner.

    Returns:
        Whatever the selected runner returns.
    """
    # Guard with "> 0" rather than "== 0": a negative launch count must not
    # fall through to the graph path, where it would mean zero replays.
    if self.collectiveArgs.graph_launches > 0:
        return self.run_coll_cuda_graph(comm_fn, dcheck)
    return self.run_coll_non_graph(comm_fn, dcheck)
538+
553539 def runPt2Pt (self ):
554540 self .backendFuncs .sync_barrier (self .collectiveArgs )
555541 # warm-up
@@ -884,7 +870,6 @@ def initCollectiveArgs(self, commsParams):
884870 self .collectiveArgs .asyncOp = False if commsParams .blockingFlag == 1 else True
885871 self .collectiveArgs .numCollPerIter = commsParams .num_coll
886872 self .collectiveArgs .include_0B = commsParams .include_0B
887- self .collectiveArgs .use_device_time = commsParams .use_device_time
888873 self .collectiveArgs .graph_launches = commsParams .graph_launches
889874
890875 if commsParams .bitwidth < 32 :
@@ -911,11 +896,7 @@ def gatherBenchTime(self, collectiveArgs, commsParams, timeUsElapsedList):
911896 # Push the list to device, then do an all-gather.
912897 timeElapsedTensor = torch .tensor (
913898 timeUsElapsedList ,
914- device = (
915- self .backendFuncs .get_device ()
916- if commsParams .backend == "nccl"
917- else torch .device ("cpu" )
918- ),
899+ device = (self .backendFuncs .get_device ()),
919900 )
920901 collectiveArgs .opTensor = None
921902 if commsParams .backend != "xla" :
@@ -1051,10 +1032,7 @@ def reportBenchTime(
10511032 dequantTimeTensorList ,
10521033 ):
10531034 # convert num_elements to # of elements per rank
1054- if commsParams .collective in (
1055- "all_to_all" ,
1056- "all_to_allv" ,
1057- "all_to_all_single" ,
1035+ if "all_to_all" in commsParams .collective or commsParams .collective in (
10581036 "reduce_scatter" ,
10591037 "reduce_scatter_v" ,
10601038 "reduce_scatter_base" ,
0 commit comments