Skip to content

Commit a437fce

Browse files
TonyTong999 authored and facebook-github-bot committed
backout D70007712
Summary: As titled Reviewed By: kingchc Differential Revision: D71912562 fbshipit-source-id: d230f1d7b94973ac036fcb682ba84288a571ad34
1 parent cd5528e commit a437fce

File tree

4 files changed

+11
-69
lines changed

4 files changed

+11
-69
lines changed

train/comms/pt/comms.py

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,6 @@ def readArgs(self, parser):
183183
default=False,
184184
help="Select some ranks to send/receive 0B messages",
185185
)
186-
parser.add_argument(
187-
"--use-device-time",
188-
action="store_true",
189-
default=False,
190-
help="use device time measurement",
191-
)
192186
parser.add_argument(
193187
"--graph-launches",
194188
type=int,
@@ -454,12 +448,8 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
454448
def run_coll_non_graph(self, comm_fn=None, dcheck=False):
455449
self.backendFuncs.sync_barrier(self.collectiveArgs, desc="runColl_begin")
456450

457-
elapsedCPUTimeNS = 0.0
458-
elapsedDeviceTimeNS = 0.0
451+
elapsedTimeNS = 0.0
459452
is_blocking = not self.collectiveArgs.asyncOp
460-
# Initialize CUDA events for device timing
461-
start_event = self.backendFuncs.create_event(self.collectiveArgs)
462-
end_event = self.backendFuncs.create_event(self.collectiveArgs)
463453

464454
for nIter in range(
465455
self.collectiveArgs.numWarmupIters + self.collectiveArgs.numIters
@@ -471,22 +461,16 @@ def run_coll_non_graph(self, comm_fn=None, dcheck=False):
471461
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
472462
ensureTensorFlush(self.collectiveArgs.opTensor)
473463
# Start measuring time after warmup iterations
474-
elapsedCPUTimeNS = 0.0
475-
elapsedDeviceTimeNS = 0.0
464+
elapsedTimeNS = 0.0
476465
self.collectiveArgs.quant_time.reset()
477466
self.collectiveArgs.dequant_time.reset()
478-
self.backendFuncs.record_event(
479-
start_event, self.collectiveArgs
480-
) # record start event for non-blocking operation
481467
# reset tensor values for data validation check
482468
if dcheck:
483469
self.setTensorVal(self.collectiveArgs.opTensor)
484470
# for blocking mode, do barrier before starting collective
485471
if is_blocking:
486472
self.backendFuncs.sync_barrier(self.collectiveArgs)
487-
self.backendFuncs.record_event(
488-
start_event, self.collectiveArgs
489-
) # record start event for blocking operation
473+
490474
start = time.monotonic() # available only in py3
491475
with paramStreamGuard(
492476
stream=self.backendFuncs.get_current_stream(
@@ -501,47 +485,29 @@ def run_coll_non_graph(self, comm_fn=None, dcheck=False):
501485
]
502486
for _ in range(self.collectiveArgs.numCollPerIter):
503487
comm_fn(self.collectiveArgs)
488+
504489
if is_blocking: # should be sychronous, wait for the collective
505-
self.backendFuncs.record_event(
506-
end_event, self.collectiveArgs
507-
) # record end event for blocking operation
508490
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
509-
elapsedDeviceTimeMs = self.backendFuncs.elapsed_time(
510-
start_event, end_event
511-
)
512-
elapsedDeviceTimeNS += elapsedDeviceTimeMs * 1e6 # Convert ms to ns
491+
513492
# Measuring time.
514-
elapsedCPUTimeNS += (
493+
elapsedTimeNS += (
515494
time.monotonic() - start
516495
) * 1e9 # keeping time in NS, helps in divising data by nanosecond
496+
517497
start = time.monotonic() # available only in py3
518-
# if not blocking, record second end event here
519-
if not is_blocking:
520-
self.backendFuncs.record_event(
521-
end_event, self.collectiveArgs
522-
) # record end event for non-blocking operations
523498
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
524499
end = time.monotonic() # available only in py3
500+
525501
ensureTensorFlush(self.collectiveArgs.opTensor)
526502

527-
elapsedCPUTimeNS += (
503+
elapsedTimeNS += (
528504
end - start
529505
) * 1e9 # keeping time in NS, helps in divising data by nanoseconds
530-
if not is_blocking:
531-
elapsedDeviceTimeMs = self.backendFuncs.elapsed_time(start_event, end_event)
532-
elapsedDeviceTimeNS = elapsedDeviceTimeMs * 1e6 # Convert ms to ns
533506

534507
memSize = self.backendFuncs.get_mem_size(self.collectiveArgs)
535-
logger.debug(
536-
f"elapsedCPUTimeNS={elapsedCPUTimeNS}, elapsedDeviceTimeNS={elapsedDeviceTimeNS}."
537-
)
538-
ElapsedTimeNS = (
539-
elapsedDeviceTimeNS
540-
if self.collectiveArgs.use_device_time
541-
else elapsedCPUTimeNS
542-
)
508+
543509
avgIterNS, algBW = comms_utils.getAlgBW(
544-
ElapsedTimeNS,
510+
elapsedTimeNS,
545511
memSize,
546512
self.collectiveArgs.numIters * self.collectiveArgs.numCollPerIter,
547513
)
@@ -904,7 +870,6 @@ def initCollectiveArgs(self, commsParams):
904870
self.collectiveArgs.asyncOp = False if commsParams.blockingFlag == 1 else True
905871
self.collectiveArgs.numCollPerIter = commsParams.num_coll
906872
self.collectiveArgs.include_0B = commsParams.include_0B
907-
self.collectiveArgs.use_device_time = commsParams.use_device_time
908873
self.collectiveArgs.graph_launches = commsParams.graph_launches
909874

910875
if commsParams.bitwidth < 32:

train/comms/pt/comms_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,6 @@ def __init__(self, args: Namespace) -> None:
818818
self.ibv_devices = args.ibv_devices
819819
self.init_only = args.init_only
820820
self.eager_init = args.eager_init
821-
self.use_device_time = args.use_device_time
822821

823822

824823
class commsDlrmParamsHolder(commsParamsHolderBase):

train/comms/pt/pytorch_backend_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ def __init__(self) -> None:
141141
self.use_ext_dist = False
142142

143143
self.include_0B = False
144-
self.use_device_time = False
145144
self.graph_launches = 0
146145

147146

train/comms/pt/pytorch_dist_backend.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -670,27 +670,6 @@ def complete_accel_ops(self, collectiveArgs, devSync=True):
670670
collectiveArgs.waitObj.clear()
671671
collectiveArgs.waitObjIds.clear()
672672

673-
def create_event(self, collectiveArgs):
674-
dev_str = (
675-
self.commsParams["device"]
676-
if isinstance(self.commsParams, dict)
677-
else self.commsParams.device
678-
)
679-
if dev_str == "cuda":
680-
return torch.cuda.Event(enable_timing=True)
681-
return None
682-
683-
def record_event(self, event, collectiveArgs):
684-
# Check if the start_event is not None, which means it's a CUDA event
685-
if event is not None:
686-
# Record the start event on the current CUDA stream
687-
event.record(self.get_current_stream(device=collectiveArgs.device))
688-
689-
def elapsed_time(self, start_event, end_event):
690-
if start_event is not None and end_event is not None:
691-
return start_event.elapsed_time(end_event)
692-
return 0
693-
694673
# retFlag not used
695674
def complete_single_op(self, collectiveArgs, retFlag=False):
696675
"""only wait on the first op in the queue"""

0 commit comments

Comments (0)