diff --git a/.gitignore b/.gitignore
index a230a78a..93163afa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 .venv/
+.vscode/
 __pycache__/
diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py
index 9fb708f6..a81b0096 100644
--- a/et_replay/comm/backend/base_backend.py
+++ b/et_replay/comm/backend/base_backend.py
@@ -127,7 +127,6 @@ class BaseBackend(ABC):
     def __init__(self) -> None:
         self.tcp_store = None
         self.collectiveFunc = {
-            "all_to_all_single": self.all_to_all_single,  # pyre-ignore[16]:
             "all_to_all": self.all_to_all,
             "all_to_allv": self.all_to_allv,
             "all_reduce": self.all_reduce,
diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index b80f6e7d..124ff4e9 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -233,7 +233,7 @@ def all_to_all(
             group=self.get_collective_group(collectiveArgs),
             async_op=collectiveArgs.asyncOp,
         )
-
+
         if collectiveArgs.asyncOp:
             collectiveArgs.waitObj.append(work)
 
@@ -241,6 +241,7 @@ def all_to_all(
             return work
 
     def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
+        # the cpp-layer all_to_allv corresponds to the python-layer all_to_all_single
         # pair=True mode does not support quantization
         if (
             collectiveArgs.all2all_qcomm
@@ -301,25 +302,6 @@ def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
         if retFlag:
             return work
 
-    def all_to_all_single(self, collectiveArgs, retFlag=False, pair=False):
-        # does not support quantization
-        if collectiveArgs.all2all_qcomm:
-            logger.warn("all_to_all_single does not support quantization")
-            return
-
-        work = dist.all_to_all_single(
-            collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
-            collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
-            group=collectiveArgs.group,
-            async_op=collectiveArgs.asyncOp,
-        )
-
-        if collectiveArgs.asyncOp:
-            collectiveArgs.waitObj.append(work)
-
-        if retFlag:
-            return work
-
     def all_gather(self, collectiveArgs, retFlag=False, pair=False):
         if self.use_ext_dist:
             retObj = collectiveArgs.group.all_gather(
diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py
index a466aa24..7a91f739 100644
--- a/et_replay/comm/commsTraceParser.py
+++ b/et_replay/comm/commsTraceParser.py
@@ -1,6 +1,7 @@
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
 from __future__ import annotations
 
+import math
 import json
 import logging
 
@@ -202,7 +203,7 @@ def _parse_comms_op_node( # noqa: C901
             comm_args.worldSize = total_ranks
             comm_args.inSplit = json.loads(node.commArgs.in_split_size)
             comm_args.outSplit = json.loads(node.commArgs.out_split_size)
-
+
         comms_op_list.append(comm_args)
 
     return comms_op_list
diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index 62b843dd..7566b91d 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -107,7 +107,6 @@ def fixBeginSize(commsParams: commsParamsHolder, world_size: int) -> None:
     if commsParams.collective in (
         "all_to_all",
         "all_to_allv",
-        "all_to_all_single",
         "all_gather",
         "all_gather_base",
         "gather",
@@ -293,14 +292,13 @@ def checkQuantArgs(
     if collective not in (
         "all_to_all",
         "all_to_allv",
-        "all_to_all_single",
         "reduce",
         "all_reduce",
     ):
         raise NotImplementedError(
             f"quantized communication for {collective} is currently unsupported."
         )
-    if collective in ("all_to_all", "all_to_allv", "all_to_all_single"):
+    if collective in ("all_to_all", "all_to_allv"):
         if (beginSize // 4) % quant_a2a_embedding_dim != 0:
             logger.warning(
                 f"begin size {beginSize} must be a multiple of --quant-a2a-embedding-dim {quant_a2a_embedding_dim} for all_to_all operation"
@@ -342,7 +340,6 @@ def paramToCommName(name: str, supported_comms: list[str] | None = None) -> str:
         "alltoall": "all_to_all",
         "alltoallv": "all_to_allv",
         "alltoallbase": "all_to_allv",
-        "alltoallsingle": "all_to_all_single",
         "allreduce": "all_reduce",
         "allgather": "all_gather",
         "allgatherbase": "all_gather_base",
@@ -873,56 +870,17 @@ def _prep_all_to_allv(
         opTensor = torch.Tensor()
         if allocate:
             # all_to_allv requires two tensors
+            # ipTensor has been allocated outside of this function; just pass it in
             opTensor = self.backendFuncs.alloc_random(
                 [numElementsOut], curDevice, dtype, scaleFactor
             )
         # recorded splits in trace is only for dim 0, but tensor in replay has been flattened.
         # need to recalculate the splits for flattened 1D tensor
-        self.collectiveArgs.opTensor_split = (
-            [numElementsOut // sum(curComm.outSplit) * i for i in curComm.outSplit]
-            if curComm.outSplit
-            else None
-        )
-        self.collectiveArgs.ipTensor_split = (
-            [numElementsIn // sum(curComm.inSplit) * i for i in curComm.inSplit]
-            if curComm.inSplit
-            else None
-        )
-        return (ipTensor, opTensor)
-
-    def _prep_all_to_all_single(
-        self,
-        ipTensor: torch.Tensor,
-        curComm: commsArgs,
-        commsParams: commsParamsHolderBase,
-        numElementsIn: int,
-        numElementsOut: int,
-        world_size: int,
-        curDevice: str,
-        dtype: torch.dtype,
-        scaleFactor: float,
-        allocate: bool = True,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        ipTensor = torch.Tensor()
-        opTensor = torch.Tensor()
-        if allocate:
-            if commsParams.dcheck == 1:
-                ipTensor = self.backendFuncs.alloc_ones(
-                    [numElementsIn],
-                    curDevice,
-                    commsParams.dtype,
-                    self.initVal,
-                )
-            else:
-                ipTensor = self.backendFuncs.alloc_random(
-                    [numElementsIn],
-                    curDevice,
-                    commsParams.dtype,
-                    scaleFactor,
-                )
-            opTensor = self.backendFuncs.alloc_random(
-                [numElementsOut], curDevice, dtype, scaleFactor
-            )
+        # corner case: one rank sends zero data out but receives data from other ranks, or vice versa.
+        self.collectiveArgs.opTensor_split = \
+            [numElementsOut // max(sum(curComm.outSplit), 1) * i for i in curComm.outSplit] if curComm.outSplit else None
+        self.collectiveArgs.ipTensor_split = \
+            [numElementsIn // max(sum(curComm.inSplit), 1) * i for i in curComm.inSplit] if curComm.inSplit else None
         return (ipTensor, opTensor)
 
     def _prep_all_to_all(
@@ -941,19 +899,11 @@ def _prep_all_to_all(
         ipTensor = []
         opTensor = []
         if allocate:
-            alloc_func = (
-                self.backendFuncs.alloc_ones
-                if commsParams.dcheck == 1
-                else self.backendFuncs.alloc_random
-            )
-            ipTensor = [
-                alloc_func(i, curDevice, commsParams.dtype, self.initVal)
-                for i in curComm.inSplit
-            ]
-            opTensor = [
-                alloc_func(i, curDevice, commsParams.dtype, self.initVal)
-                for i in curComm.outSplit
-            ]
+            i_alloc_func = self.backendFuncs.alloc_ones if commsParams.dcheck == 1 else self.backendFuncs.alloc_random
+            i_scale_factor = self.initVal if commsParams.dcheck == 1 else scaleFactor
+            ipTensor = [i_alloc_func([i], curDevice, commsParams.dtype, i_scale_factor) for i in curComm.inSplit]
+
+            opTensor = [self.backendFuncs.alloc_random([i], curDevice, commsParams.dtype, scaleFactor) for i in curComm.outSplit]
         return (ipTensor, opTensor)
 
     def _prep_all_gather(
@@ -1240,7 +1190,6 @@ def prepComm(
 
         # TODO: consider using this dictionary to check valid keywords rather than silently defaulting
        dispatchDict = {
-            "all_to_all_single": self._prep_all_to_all_single,
             "all_to_allv": self._prep_all_to_allv,
             "all_to_all": self._prep_all_to_all,
             "all_gather": self._prep_all_gather,
diff --git a/et_replay/comm/profiler_trace_analysis.py b/et_replay/comm/profiler_trace_analysis.py
index dd5170d2..22e4f8e3 100644
--- a/et_replay/comm/profiler_trace_analysis.py
+++ b/et_replay/comm/profiler_trace_analysis.py
@@ -2,9 +2,12 @@
 import json
 import logging
 import os
+import ast
 import pathlib
 from collections import defaultdict
 from typing import Any, Callable, Dict
+import functools
+import time
 
 import numpy as np
 from intervaltree import Interval, IntervalTree
@@ -12,6 +15,17 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
+def timer_decorator(func):
+    """Decorator that prints the execution time of a function"""
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        print(f"{func.__name__} took {end_time - start_time:.2f} seconds")
+        return result
+    return wrapper
+
 # refer to:
 # https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/c10/core/ScalarType.h#L61
 _dtype_size_map: Dict[str, int] = {
@@ -138,8 +152,48 @@ def _get_event_busbw_factor(evt):
 
     return correction_factor_func(group_size)
 
-
-def calculate_bw_(trace_data):
+def _is_uneven_all_to_all_evt(evt):
+    coll_name = _get_dict_value(
+        evt["args"],
+        "Collective name",
+        f'Missing "Collective name" in event: {evt}'
+    )
+    return (coll_name in ["all_to_all", "all_to_allv"]
+        and (ast.literal_eval(evt['args']['In split size'])
+            or ast.literal_eval(evt['args']['Out split size']))
+    )
+
+def _get_uneven_all_to_all_data_size(evt, global_rank):
+    group_size = evt["args"]["Group size"]
+    local_rank = _parse_ranks(evt["args"]["Process Group Ranks"], group_size).index(global_rank)
+    in_elems_count = evt["args"]["In msg nelems"]
+    out_elems_count = evt["args"]["Out msg nelems"]
+    in_split_size = ast.literal_eval(evt["args"]["In split size"])
+    out_split_size = ast.literal_eval(evt["args"]["Out split size"])
+    dtype_size = _dtype_size_map[evt["args"]["dtype"]]
+
+    if (in_split_size and in_split_size[-1] == Ellipsis) or \
+       (out_split_size and out_split_size[-1] == Ellipsis):
+        in_split_size = []
+        out_split_size = []
+        logger.warning(f'Fallback to even all2all bw calculation for event: {evt}')
+
+    if in_split_size:
+        send_elems = in_elems_count - in_split_size[local_rank]
+    else:
+        send_elems = in_elems_count / group_size * (group_size - 1)
+
+    if out_split_size:
+        recv_elems = out_elems_count - out_split_size[local_rank]
+    else:
+        recv_elems = out_elems_count / group_size * (group_size - 1)
+
+    return max(send_elems, recv_elems) * dtype_size
+
+def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
+    return round(_get_uneven_all_to_all_data_size(evt, global_rank) / evt["dur"] * 1e-3, 2)
+
+def calculate_bw_(trace_data, global_rank):
     nccl_events = [
         i
         for i in trace_data["traceEvents"]
@@ -163,7 +217,11 @@ def calculate_bw_(trace_data):
 
             algbw = _calculate_algbw(evt)
             busbw_factor = _get_event_busbw_factor(evt)
-            busbw = round(algbw * busbw_factor, 2)
+            if _is_uneven_all_to_all_evt(evt):
+                # calculate busbw for uneven all_to_all
+                busbw = _calculate_busbw_for_uneven_all_to_all(evt, global_rank)
+            else:
+                busbw = round(algbw * busbw_factor, 2)
 
             evt["args"]["algbw (GB/sec)"] = algbw
             evt["args"]["busbw (GB/sec)"] = busbw
@@ -178,7 +236,7 @@ def calculate_bw_(trace_data):
         logger.error(f"- Error: {err_msg}")
 
 
-def calculate_sbw(trace_data):
+def calculate_sbw(trace_data, global_rank):
     # calculate shared bw per rank
     nccl_events = [
         i
@@ -193,6 +251,8 @@ def calculate_sbw(trace_data):
 
     total_data_size = sum(
         _calculate_event_data_size(evt) * _get_event_busbw_factor(evt)
+        if not _is_uneven_all_to_all_evt(evt)
+        else _get_uneven_all_to_all_data_size(evt, global_rank)
         for evt in nccl_events
     )
 
@@ -232,6 +292,13 @@ def pick_iter_e2e_time_(trace_data, tl):
 
 def pick_comm_bw_(trace_data, comm_bw_data):
     rank = trace_data["distributedInfo"]["rank"]
+
+    group_ranks_to_pg_id = defaultdict(list)
+    for pg in trace_data["distributedInfo"]["pg_config"]:
+        group_ranks_to_pg_id[tuple(pg["ranks"])].append(int(pg["pg_name"]))
+    for ranks in group_ranks_to_pg_id:
+        group_ranks_to_pg_id[ranks].sort()
+
     nccl_events = [
         i
         for i in trace_data["traceEvents"]
@@ -239,18 +306,20 @@ def pick_comm_bw_(trace_data, comm_bw_data):
         and i["name"].startswith(("ncclDevKernel_", "ncclKernel_"))
         and "algbw (GB/sec)" in i["args"]
     ]
 
+    pg_name2config = {pg["pg_name"]: pg for pg in trace_data["distributedInfo"]["pg_config"]}
     for evt in nccl_events:
         knl_name = evt["name"][: evt["name"].index("(")]
         coll_name = evt["args"]["Collective name"]
         data_size = _calculate_event_data_size(evt)
-        ranks_count = evt["args"]["Group size"]
-        ranks = _parse_ranks(evt["args"]["Process Group Ranks"], ranks_count)
+        ranks_count = evt["args"]["Group size"]
         pg_id = int(evt["args"]["Process Group Name"])
-        pg = (*ranks, pg_id) if ranks and rank == min(ranks) else None
+        ranks = pg_name2config[evt["args"]["Process Group Name"]]['ranks']
+
+        # If there are multiple process groups with the same ranks, the last element
+        # of this tuple is an index, identical across ranks, that differentiates them.
+        pg = (*ranks, group_ranks_to_pg_id[tuple(ranks)].index(pg_id))
 
-        # TODO: calculation of unbalanced all2all bw needs to be improved
-        # all2all is implemented by single ncclDevKernel_SendRecv() in NCCL
         comm_bw_data[(knl_name, coll_name, data_size, ranks_count)].append(
             [
                 evt["dur"],
@@ -260,7 +329,7 @@
             ]
         )
 
-
+@timer_decorator
 def analyze_profiler_trace(trace_dir: str, report_dir: str):
     """
     Analyse input PyTorch profiler trace (i.e. Kineto trace) and generate report.
@@ -282,24 +351,26 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
 
     # list of shared bw
     sbw_lst = []
-    # key is (kernel_name, data size, ranks number)
+    # key is (kernel_name, coll name, data size, ranks count)
     # value is list of [dur, algbw, busbw, pg]
     comm_bw_data = defaultdict(list)
 
     for fpath in os.scandir(trace_dir):
         if not fpath.is_file():
             continue
-
+
         with open(fpath.path, "r", encoding="utf-8") as f:
             trace = json.load(f)
-
-        calculate_bw_(trace)
+
+        global_rank = trace["distributedInfo"]["rank"]
+        calculate_bw_(trace, global_rank)
+
         with open(
             os.path.join(processed_trace_dir, fpath.name), "w", encoding="utf-8"
         ) as f:
             json.dump(trace, f)
 
-        sbw_lst.append(calculate_sbw(trace))
+        sbw_lst.append(calculate_sbw(trace, global_rank))
         pick_iter_e2e_time_(trace, iter_e2e_time)
         pick_comm_bw_(trace, comm_bw_data)
@@ -330,7 +401,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
             f"avg. E2ETime of iters among all ranks: {sum(iter_e2e_time) / len(iter_e2e_time) / 1e3 :.3f} ms\n"
         )
         f.write(
-            f"avg. SharedBW (i.e. sum(data_size * busbw_factor) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
+            f"avg. SharedBW (i.e. sum(busbw_data_size) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n"
         )
 
         f.write(
@@ -352,9 +423,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
             f.write("\n")
 
         for k, v in comm_bw_summary.items():
-            f.write(
-                f"{k[0]:>50s} {k[1]:>15s} {k[2]:>12d} {k[3]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} "
-            )
+            f.write(f"{k[0]:>50s} {k[1]:>15s} {k[2]:>12d} {k[3]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} ")
             for i in range(2, len(v)):
                 f.write(f"{v[i]:>8.2f}|")
             f.write("\n")
diff --git a/et_replay/pyproject.toml b/et_replay/pyproject.toml
index 19dbf0ea..00811ee1 100644
--- a/et_replay/pyproject.toml
+++ b/et_replay/pyproject.toml
@@ -8,6 +8,7 @@ version = "0.5.0"
 dependencies = [
     "numpy",
     "intervaltree",
+    "pydot",
 ]
 
 [tool.setuptools.package-dir]