Skip to content

Commit 6b86d24

Browse files
authored
Merge branch 'main' into sanshang/fix_all2all
2 parents 263ac16 + a437fce commit 6b86d24

File tree

5 files changed

+114
-119
lines changed

5 files changed

+114
-119
lines changed

et_replay/execution_trace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ def __init__(self, json):
357357
input_tensors = self.nodes[id].get_input_tensors()
358358
output_tensors = self.nodes[id].get_output_tensors()
359359

360-
# track the various process and threads we have
361-
if x["name"] == "__ROOT_THREAD__":
360+
# track annotation to get thread ids of root nodes
361+
if x["name"] == "[pytorch|profiler|execution_trace|thread]":
362362
tid = self.nodes[id].tid
363363
self.proc_group[pid][tid] = id
364364

train/comms/pt/comms.py

Lines changed: 38 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,6 @@ def readArgs(self, parser):
183183
default=False,
184184
help="Select some ranks to send/receive 0B messages",
185185
)
186-
parser.add_argument(
187-
"--use-device-time",
188-
action="store_true",
189-
default=False,
190-
help="use device time measurement",
191-
)
192186
parser.add_argument(
193187
"--graph-launches",
194188
type=int,
@@ -368,7 +362,9 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
368362
self.backendFuncs.sync_barrier(
369363
self.collectiveArgs, desc="run_coll_cuda_graph_begin"
370364
)
371-
elapsedTimeNS = 0.0
365+
elapsedCPUTimeNS = 0.0
366+
start_event = self.backendFuncs.create_event(self.collectiveArgs)
367+
end_event = self.backendFuncs.create_event(self.collectiveArgs)
372368

373369
# 1. Warmup phase
374370
# launch collective on a separate stream and sync with current_stream
@@ -393,21 +389,34 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
393389

394390
# 3. Replay
395391
start = time.monotonic() # available only in py3
392+
self.backendFuncs.record_event(start_event, self.collectiveArgs)
396393
for _ in range(self.collectiveArgs.graph_launches):
397394
if self.collectiveArgs.enable_profiler:
398395
comms_utils.sampleProfiler()
399396

400397
# [optional] we can feed new input data to ipTensor for each replay
401398
g.replay()
402399

400+
self.backendFuncs.record_event(end_event, self.collectiveArgs)
403401
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
402+
404403
end = time.monotonic() # available only in py3
405404

406405
ensureTensorFlush(self.collectiveArgs.opTensor)
407406

408-
elapsedTimeNS += (
407+
elapsedCPUTimeNS += (
409408
end - start
410409
) * 1e9 # keeping time in NS, helps in divising data by nanoseconds
410+
elapsedDeviceTimeMs = self.backendFuncs.elapsed_time(start_event, end_event)
411+
elapsedDeviceTimeNS = elapsedDeviceTimeMs * 1e6
412+
elapsedTimeNS = (
413+
elapsedDeviceTimeNS
414+
if self.collectiveArgs.use_device_time
415+
else elapsedCPUTimeNS
416+
)
417+
logger.debug(
418+
f"elapsedCPUTimeNS={elapsedCPUTimeNS}, elapsedDeviceTimeNS={elapsedDeviceTimeNS}."
419+
)
411420

412421
memSize = self.backendFuncs.get_mem_size(self.collectiveArgs)
413422

@@ -436,17 +445,11 @@ def run_coll_cuda_graph(self, comm_fn=None, dcheck=False):
436445
}
437446
return results
438447

439-
def runColl(self, comm_fn=None, dcheck=False):
440-
if self.collectiveArgs.graph_launches > 0:
441-
return self.run_coll_cuda_graph(comm_fn, dcheck)
448+
def run_coll_non_graph(self, comm_fn=None, dcheck=False):
442449
self.backendFuncs.sync_barrier(self.collectiveArgs, desc="runColl_begin")
443450

444-
elapsedCPUTimeNS = 0.0
445-
elapsedDeviceTimeNS = 0.0
451+
elapsedTimeNS = 0.0
446452
is_blocking = not self.collectiveArgs.asyncOp
447-
# Initialize CUDA events for device timing
448-
start_event = self.backendFuncs.create_event(self.collectiveArgs)
449-
end_event = self.backendFuncs.create_event(self.collectiveArgs)
450453

451454
for nIter in range(
452455
self.collectiveArgs.numWarmupIters + self.collectiveArgs.numIters
@@ -458,22 +461,16 @@ def runColl(self, comm_fn=None, dcheck=False):
458461
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
459462
ensureTensorFlush(self.collectiveArgs.opTensor)
460463
# Start measuring time after warmup iterations
461-
elapsedCPUTimeNS = 0.0
462-
elapsedDeviceTimeNS = 0.0
464+
elapsedTimeNS = 0.0
463465
self.collectiveArgs.quant_time.reset()
464466
self.collectiveArgs.dequant_time.reset()
465-
self.backendFuncs.record_event(
466-
start_event, self.collectiveArgs
467-
) # record start event for non-blocking operation
468467
# reset tensor values for data validation check
469468
if dcheck:
470469
self.setTensorVal(self.collectiveArgs.opTensor)
471470
# for blocking mode, do barrier before starting collective
472471
if is_blocking:
473472
self.backendFuncs.sync_barrier(self.collectiveArgs)
474-
self.backendFuncs.record_event(
475-
start_event, self.collectiveArgs
476-
) # record start event for blocking operation
473+
477474
start = time.monotonic() # available only in py3
478475
with paramStreamGuard(
479476
stream=self.backendFuncs.get_current_stream(
@@ -488,47 +485,29 @@ def runColl(self, comm_fn=None, dcheck=False):
488485
]
489486
for _ in range(self.collectiveArgs.numCollPerIter):
490487
comm_fn(self.collectiveArgs)
488+
491489
if is_blocking: # should be sychronous, wait for the collective
492-
self.backendFuncs.record_event(
493-
end_event, self.collectiveArgs
494-
) # record end event for blocking operation
495490
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
496-
elapsedDeviceTimeMs = self.backendFuncs.elapsed_time(
497-
start_event, end_event
498-
)
499-
elapsedDeviceTimeNS += elapsedDeviceTimeMs * 1e6 # Convert ms to ns
491+
500492
# Measuring time.
501-
elapsedCPUTimeNS += (
493+
elapsedTimeNS += (
502494
time.monotonic() - start
503495
) * 1e9 # keeping time in NS, helps in divising data by nanosecond
496+
504497
start = time.monotonic() # available only in py3
505-
# if not blocking, record second end event here
506-
if not is_blocking:
507-
self.backendFuncs.record_event(
508-
end_event, self.collectiveArgs
509-
) # record end event for non-blocking operations
510498
self.backendFuncs.complete_accel_ops(self.collectiveArgs)
511499
end = time.monotonic() # available only in py3
500+
512501
ensureTensorFlush(self.collectiveArgs.opTensor)
513502

514-
elapsedCPUTimeNS += (
503+
elapsedTimeNS += (
515504
end - start
516505
) * 1e9 # keeping time in NS, helps in divising data by nanoseconds
517-
if not is_blocking:
518-
elapsedDeviceTimeMs = self.backendFuncs.elapsed_time(start_event, end_event)
519-
elapsedDeviceTimeNS = elapsedDeviceTimeMs * 1e6 # Convert ms to ns
520506

521507
memSize = self.backendFuncs.get_mem_size(self.collectiveArgs)
522-
logger.debug(
523-
f"elapsedCPUTimeNS={elapsedCPUTimeNS}, elapsedDeviceTimeNS={elapsedDeviceTimeNS}."
524-
)
525-
ElapsedTimeNS = (
526-
elapsedDeviceTimeNS
527-
if self.collectiveArgs.use_device_time
528-
else elapsedCPUTimeNS
529-
)
508+
530509
avgIterNS, algBW = comms_utils.getAlgBW(
531-
ElapsedTimeNS,
510+
elapsedTimeNS,
532511
memSize,
533512
self.collectiveArgs.numIters * self.collectiveArgs.numCollPerIter,
534513
)
@@ -550,6 +529,13 @@ def runColl(self, comm_fn=None, dcheck=False):
550529
}
551530
return results
552531

532+
def runColl(self, comm_fn=None, dcheck=False):
533+
return (
534+
self.run_coll_non_graph(comm_fn, dcheck)
535+
if self.collectiveArgs.graph_launches == 0
536+
else self.run_coll_cuda_graph(comm_fn, dcheck)
537+
)
538+
553539
def runPt2Pt(self):
554540
self.backendFuncs.sync_barrier(self.collectiveArgs)
555541
# warm-up
@@ -884,7 +870,6 @@ def initCollectiveArgs(self, commsParams):
884870
self.collectiveArgs.asyncOp = False if commsParams.blockingFlag == 1 else True
885871
self.collectiveArgs.numCollPerIter = commsParams.num_coll
886872
self.collectiveArgs.include_0B = commsParams.include_0B
887-
self.collectiveArgs.use_device_time = commsParams.use_device_time
888873
self.collectiveArgs.graph_launches = commsParams.graph_launches
889874

890875
if commsParams.bitwidth < 32:
@@ -911,11 +896,7 @@ def gatherBenchTime(self, collectiveArgs, commsParams, timeUsElapsedList):
911896
# Push the list to device, then do an all-gather.
912897
timeElapsedTensor = torch.tensor(
913898
timeUsElapsedList,
914-
device=(
915-
self.backendFuncs.get_device()
916-
if commsParams.backend == "nccl"
917-
else torch.device("cpu")
918-
),
899+
device=(self.backendFuncs.get_device()),
919900
)
920901
collectiveArgs.opTensor = None
921902
if commsParams.backend != "xla":
@@ -1051,10 +1032,7 @@ def reportBenchTime(
10511032
dequantTimeTensorList,
10521033
):
10531034
# convert num_elements to # of elements per rank
1054-
if commsParams.collective in (
1055-
"all_to_all",
1056-
"all_to_allv",
1057-
"all_to_all_single",
1035+
if "all_to_all" in commsParams.collective or commsParams.collective in (
10581036
"reduce_scatter",
10591037
"reduce_scatter_v",
10601038
"reduce_scatter_base",

train/comms/pt/comms_utils.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,23 @@
1717
from collections.abc import Callable
1818
from contextlib import ContextDecorator
1919
from io import StringIO
20-
from typing import Any, Dict, List, Optional, Tuple, Union
20+
from typing import Any
21+
22+
import torch
23+
24+
from param_bench.train.comms.pt.param_profile import paramTimer
25+
from param_bench.train.comms.pt.pytorch_backend_utils import (
26+
backendFunctions,
27+
collectiveArgsHolder,
28+
customized_backend,
29+
supportedC10dBackends,
30+
supportedDevices,
31+
)
32+
from torch._C._distributed_c10d import ProcessGroup
33+
34+
random.seed()
35+
36+
logger = logging.getLogger(__name__)
2137

2238
try:
2339
from param_bench.train.comms.pt.fb.internals import (
@@ -29,24 +45,28 @@
2945
)
3046

3147
has_internal_libs = True
48+
logger.info("Successfully import internal libs")
3249
except ImportError:
3350
has_internal_libs = False
51+
logger.info("Iinternal libs not found.")
52+
53+
try:
54+
from param_bench.train.comms.pt.fb.mixins import (
55+
name_aliases_ext,
56+
ParamCommsBenchMixin,
57+
)
3458

59+
logger.info("Successfully imported ParamCommsBenchMixin")
60+
except ImportError:
61+
logger.warning(
62+
"ParamCommsBenchMixin does not exist or module not found. Default to empty class."
63+
)
3564

36-
import torch
37-
from param_bench.train.comms.pt.param_profile import paramTimer
38-
from param_bench.train.comms.pt.pytorch_backend_utils import (
39-
backendFunctions,
40-
collectiveArgsHolder,
41-
customized_backend,
42-
supportedC10dBackends,
43-
supportedDevices,
44-
)
45-
from torch._C._distributed_c10d import ProcessGroup
65+
class ParamCommsBenchMixin:
66+
pass # Define empty class if it does not exist
4667

47-
random.seed()
68+
name_aliases_ext = {}
4869

49-
logger = logging.getLogger(__name__)
5070

5171
default_master_ip = "127.0.0.1"
5272
default_master_port = "29500"
@@ -206,10 +226,7 @@ def fixBeginSize(commsParams: commsParamsHolder, world_size: int) -> None:
206226
None
207227
"""
208228
# ensures we will have at least one member/rank
209-
if commsParams.collective in (
210-
"all_to_all",
211-
"all_to_allv",
212-
"all_to_all_single",
229+
if "all_to_all" in commsParams.collective or commsParams.collective in (
213230
"all_gather",
214231
"all_gather_base",
215232
"gather",
@@ -392,17 +409,14 @@ def checkQuantArgs(
392409
Returns:
393410
None
394411
"""
395-
if collective not in (
396-
"all_to_all",
397-
"all_to_allv",
398-
"all_to_all_single",
412+
if "all_to_all" not in collective and collective not in (
399413
"reduce",
400414
"all_reduce",
401415
):
402416
raise NotImplementedError(
403417
f"quantized communication for {collective} is currently unsupported."
404418
)
405-
if collective in ("all_to_all", "all_to_allv", "all_to_all_single"):
419+
if "all_to_all" in collective:
406420
if (beginSize // 4) % quant_a2a_embedding_dim != 0:
407421
logger.warning(
408422
f"begin size {beginSize} must be a multiple of --quant-a2a-embedding-dim {quant_a2a_embedding_dim} for all_to_all operation"
@@ -452,6 +466,7 @@ def paramToCommName(name: str, supported_comms: list[str] = None) -> str:
452466
"reducescatterbase": "reduce_scatter_base",
453467
"recvanysource": "recv",
454468
}
469+
name_aliases.update(name_aliases_ext)
455470

456471
new_name = name.lower()
457472

@@ -803,7 +818,6 @@ def __init__(self, args: Namespace) -> None:
803818
self.ibv_devices = args.ibv_devices
804819
self.init_only = args.init_only
805820
self.eager_init = args.eager_init
806-
self.use_device_time = args.use_device_time
807821

808822

809823
class commsDlrmParamsHolder(commsParamsHolderBase):
@@ -934,7 +948,7 @@ def __init__(
934948
self.bag_size = args.bag_size
935949

936950

937-
class paramCommsBench(ABC):
951+
class ParamCommsBenchBase(ABC):
938952
"""Abstract class for any param comms benchmark."""
939953

940954
def __init__(self, supportedNwstacks: list[str] = None) -> None:
@@ -1570,6 +1584,8 @@ def prepComm(
15701584
"scatter": self._prep_reduce_scatter,
15711585
"pt2pt": self._prep_pt2pt,
15721586
}
1587+
if hasattr(self, "dispatchDictExt") and self.dispatchDictExt is not None:
1588+
dispatchDict.update(self.dispatchDictExt)
15731589

15741590
function_to_call = dispatchDict.get(commOp)
15751591
if function_to_call is not None:
@@ -1816,6 +1832,11 @@ def checkArgs(self, args: Namespace) -> None:
18161832
os.environ["MASTER_PORT"] = args.master_port
18171833

18181834

1835+
class paramCommsBench(ParamCommsBenchMixin, ParamCommsBenchBase):
1836+
def __init__(self, supportedNwstacks: list[str] = None) -> None:
1837+
super().__init__(supportedNwstacks)
1838+
1839+
18191840
def init_emb_lookup(collectiveArgs, commsParams, backendFuncs):
18201841
"""
18211842
Initialize embedding table op

0 commit comments

Comments
 (0)