@@ -183,12 +183,6 @@ def readArgs(self, parser):
183183 default = False ,
184184 help = "Select some ranks to send/receive 0B messages" ,
185185 )
186- parser .add_argument (
187- "--use-device-time" ,
188- action = "store_true" ,
189- default = False ,
190- help = "use device time measurement" ,
191- )
192186 parser .add_argument (
193187 "--graph-launches" ,
194188 type = int ,
@@ -454,12 +448,8 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
454448 def run_coll_non_graph (self , comm_fn = None , dcheck = False ):
455449 self .backendFuncs .sync_barrier (self .collectiveArgs , desc = "runColl_begin" )
456450
457- elapsedCPUTimeNS = 0.0
458- elapsedDeviceTimeNS = 0.0
451+ elapsedTimeNS = 0.0
459452 is_blocking = not self .collectiveArgs .asyncOp
460- # Initialize CUDA events for device timing
461- start_event = self .backendFuncs .create_event (self .collectiveArgs )
462- end_event = self .backendFuncs .create_event (self .collectiveArgs )
463453
464454 for nIter in range (
465455 self .collectiveArgs .numWarmupIters + self .collectiveArgs .numIters
@@ -471,22 +461,16 @@ def run_coll_non_graph(self, comm_fn=None, dcheck=False):
471461 self .backendFuncs .complete_accel_ops (self .collectiveArgs )
472462 ensureTensorFlush (self .collectiveArgs .opTensor )
473463 # Start measuring time after warmup iterations
474- elapsedCPUTimeNS = 0.0
475- elapsedDeviceTimeNS = 0.0
464+ elapsedTimeNS = 0.0
476465 self .collectiveArgs .quant_time .reset ()
477466 self .collectiveArgs .dequant_time .reset ()
478- self .backendFuncs .record_event (
479- start_event , self .collectiveArgs
480- ) # record start event for non-blocking operation
481467 # reset tensor values for data validation check
482468 if dcheck :
483469 self .setTensorVal (self .collectiveArgs .opTensor )
484470 # for blocking mode, do barrier before starting collective
485471 if is_blocking :
486472 self .backendFuncs .sync_barrier (self .collectiveArgs )
487- self .backendFuncs .record_event (
488- start_event , self .collectiveArgs
489- ) # record start event for blocking operation
473+
490474 start = time .monotonic () # available only in py3
491475 with paramStreamGuard (
492476 stream = self .backendFuncs .get_current_stream (
@@ -501,47 +485,29 @@ def run_coll_non_graph(self, comm_fn=None, dcheck=False):
501485 ]
502486 for _ in range (self .collectiveArgs .numCollPerIter ):
503487 comm_fn (self .collectiveArgs )
488+
504489 if is_blocking : # should be synchronous, wait for the collective
505- self .backendFuncs .record_event (
506- end_event , self .collectiveArgs
507- ) # record end event for blocking operation
508490 self .backendFuncs .complete_accel_ops (self .collectiveArgs )
509- elapsedDeviceTimeMs = self .backendFuncs .elapsed_time (
510- start_event , end_event
511- )
512- elapsedDeviceTimeNS += elapsedDeviceTimeMs * 1e6 # Convert ms to ns
491+
513492 # Measuring time.
514- elapsedCPUTimeNS += (
493+ elapsedTimeNS += (
515494 time .monotonic () - start
516495 ) * 1e9 # keeping time in NS, helps in dividing data by nanoseconds
496+
517497 start = time .monotonic () # available only in py3
518- # if not blocking, record second end event here
519- if not is_blocking :
520- self .backendFuncs .record_event (
521- end_event , self .collectiveArgs
522- ) # record end event for non-blocking operations
523498 self .backendFuncs .complete_accel_ops (self .collectiveArgs )
524499 end = time .monotonic () # available only in py3
500+
525501 ensureTensorFlush (self .collectiveArgs .opTensor )
526502
527- elapsedCPUTimeNS += (
503+ elapsedTimeNS += (
528504 end - start
529505 ) * 1e9 # keeping time in NS, helps in dividing data by nanoseconds
530- if not is_blocking :
531- elapsedDeviceTimeMs = self .backendFuncs .elapsed_time (start_event , end_event )
532- elapsedDeviceTimeNS = elapsedDeviceTimeMs * 1e6 # Convert ms to ns
533506
534507 memSize = self .backendFuncs .get_mem_size (self .collectiveArgs )
535- logger .debug (
536- f"elapsedCPUTimeNS={ elapsedCPUTimeNS } , elapsedDeviceTimeNS={ elapsedDeviceTimeNS } ."
537- )
538- ElapsedTimeNS = (
539- elapsedDeviceTimeNS
540- if self .collectiveArgs .use_device_time
541- else elapsedCPUTimeNS
542- )
508+
543509 avgIterNS , algBW = comms_utils .getAlgBW (
544- ElapsedTimeNS ,
510+ elapsedTimeNS ,
545511 memSize ,
546512 self .collectiveArgs .numIters * self .collectiveArgs .numCollPerIter ,
547513 )
@@ -904,7 +870,6 @@ def initCollectiveArgs(self, commsParams):
904870 self .collectiveArgs .asyncOp = False if commsParams .blockingFlag == 1 else True
905871 self .collectiveArgs .numCollPerIter = commsParams .num_coll
906872 self .collectiveArgs .include_0B = commsParams .include_0B
907- self .collectiveArgs .use_device_time = commsParams .use_device_time
908873 self .collectiveArgs .graph_launches = commsParams .graph_launches
909874
910875 if commsParams .bitwidth < 32 :
0 commit comments