diff --git a/torchrec/distributed/benchmark/README.md b/torchrec/distributed/benchmark/README.md index 9de7a09f2..287e642da 100644 --- a/torchrec/distributed/benchmark/README.md +++ b/torchrec/distributed/benchmark/README.md @@ -12,3 +12,15 @@ python -m torchrec.distributed.benchmark.benchmark_train_pipeline \ --yaml_config=fbcode/torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml \ --name=sparse_data_dist_base_$(git rev-parse --short HEAD || echo $USER) # overrides the yaml config ``` + +## benchmark_comms usage +- internal: +``` +buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \ + a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10) +``` +- oss: +``` +python -m torchrec.distributed.benchmark.benchmark_comms \ + a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER) +``` diff --git a/torchrec/distributed/benchmark/base.py b/torchrec/distributed/benchmark/base.py index 806197d0f..c9eaca711 100644 --- a/torchrec/distributed/benchmark/base.py +++ b/torchrec/distributed/benchmark/base.py @@ -11,11 +11,13 @@ #!/usr/bin/env python3 import argparse +import functools import inspect import json import logging import os import resource +import sys import time import timeit from dataclasses import dataclass, fields, is_dataclass, MISSING @@ -362,137 +364,210 @@ def set_embedding_config( return embedding_configs, pooling_configs -# pyre-ignore [24] -def cmd_conf(func: Callable) -> Callable: - - def _load_config_file(config_path: str, is_json: bool = False) -> Dict[str, Any]: - if not config_path: - return {} - - try: - with open(config_path, "r") as f: - if is_json: - return json.load(f) or {} - else: - return yaml.safe_load(f) or {} - except Exception as e: - logger.error(f"Failed to load config because {e}. Proceeding without it.") - return {} - - # pyre-ignore [3] - def wrapper() -> Any: - sig = inspect.signature(func) - parser = argparse.ArgumentParser(func.__doc__) - - parser.add_argument( - "--yaml_config", - type=str, - default=None, - help="YAML config file for benchmarking", - ) - - parser.add_argument( - "--json_config", - type=str, - default=None, - help="JSON config file for benchmarking", - ) - - # Add loglevel argument with current logger level as default - parser.add_argument( - "--loglevel", - type=str, - default=logging._levelToName[logger.level], - help="Set the logging level (e.g. info, debug, warning, error)", - ) - - pre_args, _ = parser.parse_known_args() - - yaml_defaults: Dict[str, Any] = ( - _load_config_file(pre_args.yaml_config, is_json=False) - if pre_args.yaml_config - else {} - ) - json_defaults: Dict[str, Any] = ( - _load_config_file(pre_args.json_config, is_json=True) - if pre_args.json_config - else {} - ) - # Merge the two dictionaries, JSON overrides YAML - merged_defaults = {**yaml_defaults, **json_defaults} - - seen_args = set() # track all -- we've added +class cmd_conf: + """ + Decorator for run functions in command line. + parse input arguments into the function's arguments and config (dataclass) + + Example 1: direct decorating (see the overloaded __new__ method below) + ``` + @cmd_conf # you might need "pyre-ignore [56]" + def main( + run_option: RunOptions, + table_config: EmbeddingTablesConfig, + model_selection: ModelSelectionConfig, + pipeline_config: PipelineConfig, + model_config: Optional[BaseModelConfig] = None, + integer: int + ) -> None: + pass + + if __name__ == "__main__": + main() + ``` + + Example 2: register multiple function + invoke with: -- (run1|run2) --arg1=... 
+ ``` + _cc = cmd_conf() + @_cc.register + def func1(input_config: CONF1): + pass + + @_cc.register + def func2(input_config: CONF2): + pass + + if __name__ == "__main__": + _cc.main() + ``` + """ - for _name, param in sig.parameters.items(): - cls = param.annotation - if not is_dataclass(cls): - continue + def __init__(self) -> None: + # pyre-ignore [24] + self.programs: Dict[str, Callable] = {} - for f in fields(cls): - arg_name = f.name - if arg_name in seen_args: - logger.warning(f"WARNING: duplicate argument {arg_name}") - continue - seen_args.add(arg_name) - - ftype = f.type - origin = get_origin(ftype) - - # Unwrapping Optional[X] to X - if origin is Union and type(None) in get_args(ftype): - non_none = [t for t in get_args(ftype) if t is not type(None)] - if len(non_none) == 1: - ftype = non_none[0] - origin = get_origin(ftype) - - # Handle default_factory value and allow config to override - default_value = merged_defaults.get( - arg_name, # flat lookup - merged_defaults.get(cls.__name__, {}).get( # hierarchy lookup - arg_name, - ( - f.default_factory() # pyre-ignore [29] - if f.default_factory is not MISSING - else f.default - ), - ), + @classmethod + # pyre-ignore [24] + def __new__(cls, _, func: Optional[Callable] = None) -> Union["cmd_conf", Callable]: + if not func: + return super().__new__(cls) + else: + return cmd_conf.call(func) + + @staticmethod + def call(func: Callable) -> Callable: # pyre-ignore [24] + + def _load_config_file( + config_path: str, is_json: bool = False + ) -> Dict[str, Any]: + if not config_path: + return {} + + try: + with open(config_path, "r") as f: + if is_json: + return json.load(f) or {} + else: + return yaml.safe_load(f) or {} + except Exception as e: + logger.error( + f"Failed to load config because {e}. Proceeding without it." ) + return {} + + @functools.wraps(func) + def wrapper() -> Any: # pyre-ignore [3] + sig = inspect.signature(func) + parser = argparse.ArgumentParser(func.__doc__) + + parser.add_argument( + "--yaml_config", + type=str, + default=None, + help="YAML config file for benchmarking", + ) - arg_kwargs = { - "default": default_value, - "help": f"({cls.__name__}) {arg_name}", - } - - if origin in (list, List): - elem_type = get_args(ftype)[0] - arg_kwargs.update(nargs="*", type=elem_type) - elif ftype is bool: - # Special handling for boolean arguments - arg_kwargs.update(type=lambda x: x.lower() in ["true", "1", "yes"]) - else: - arg_kwargs.update(type=ftype) + parser.add_argument( + "--json_config", + type=str, + default=None, + help="JSON config file for benchmarking", + ) - parser.add_argument(f"--{arg_name}", **arg_kwargs) + # Add loglevel argument with current logger level as default + parser.add_argument( + "--loglevel", + type=str, + default=logging._levelToName[logger.level], + help="Set the logging level (e.g. 
info, debug, warning, error)", + ) - args = parser.parse_args() - logger.setLevel(logging.INFO) + pre_args, _ = parser.parse_known_args() - # Build the dataclasses - kwargs = {} - for name, param in sig.parameters.items(): - cls = param.annotation - if is_dataclass(cls): - data = {f.name: getattr(args, f.name) for f in fields(cls)} - config_instance = cls(**data) # pyre-ignore [29] - kwargs[name] = config_instance - logger.info(config_instance) + yaml_defaults: Dict[str, Any] = ( + _load_config_file(pre_args.yaml_config, is_json=False) + if pre_args.yaml_config + else {} + ) + json_defaults: Dict[str, Any] = ( + _load_config_file(pre_args.json_config, is_json=True) + if pre_args.json_config + else {} + ) + # Merge the two dictionaries, JSON overrides YAML + merged_defaults = {**yaml_defaults, **json_defaults} - loglevel = logging._nameToLevel[args.loglevel.upper()] - logger.setLevel(loglevel) + seen_args = set() # track all -- we've added - return func(**kwargs) + for _name, param in sig.parameters.items(): + cls = param.annotation + if not is_dataclass(cls): + continue - return wrapper + for f in fields(cls): + arg_name = f.name + if arg_name in seen_args: + logger.warning(f"WARNING: duplicate argument {arg_name}") + continue + seen_args.add(arg_name) + + ftype = f.type + origin = get_origin(ftype) + + # Unwrapping Optional[X] to X + if origin is Union and type(None) in get_args(ftype): + non_none = [t for t in get_args(ftype) if t is not type(None)] + if len(non_none) == 1: + ftype = non_none[0] + origin = get_origin(ftype) + + # Handle default_factory value and allow config to override + default_value = merged_defaults.get( + arg_name, # flat lookup + merged_defaults.get(cls.__name__, {}).get( # hierarchy lookup + arg_name, + ( + f.default_factory() # pyre-ignore [29] + if f.default_factory is not MISSING + else f.default + ), + ), + ) + + arg_kwargs = { + "default": default_value, + "help": f"({cls.__name__}) {arg_name}", + } + + if origin in (list, List): + elem_type = get_args(ftype)[0] + arg_kwargs.update(nargs="*", type=elem_type) + elif ftype is bool: + # Special handling for boolean arguments + arg_kwargs.update( + type=lambda x: x.lower() in ["true", "1", "yes"] + ) + else: + arg_kwargs.update(type=ftype) + + parser.add_argument(f"--{arg_name}", **arg_kwargs) + + args = parser.parse_args() + logger.setLevel(logging.INFO) + + # Build the dataclasses + kwargs = {} + for name, param in sig.parameters.items(): + cls = param.annotation + if is_dataclass(cls): + data = {f.name: getattr(args, f.name) for f in fields(cls)} + config_instance = cls(**data) # pyre-ignore [29] + kwargs[name] = config_instance + logger.info(config_instance) + + loglevel = logging._nameToLevel[args.loglevel.upper()] + logger.setLevel(loglevel) + + return func(**kwargs) + + return wrapper + + # pyre-ignore [24] + def register(self, func: Callable) -> Callable: + wrapper = cmd_conf.call(func) + self.programs[func.__name__] = wrapper + return wrapper + + def main(self) -> None: + program = sys.argv[1] + if program in self.programs: + sys.argv[:] = [sys.argv[0]] + (sys.argv[2:] if len(sys.argv) > 2 else []) + self.programs[program]() + else: + print( + f"Invalid command. Please use select program from {', '.join(self.programs.keys())}." + ) def init_argparse_and_args() -> argparse.Namespace: @@ -534,12 +609,15 @@ def _run_benchmark_core( Args: name: Human-readable benchmark name. + run_iter_fn: Zero-arg callable that executes one measured iteration. 
profile_iter_fn: Optional callable that receives a ``torch.profiler`` instance and runs the iterations that should be captured. + world_size, rank: Distributed context to correctly reset / collect GPU stats. ``rank == -1`` means single-process mode. num_benchmarks: Number of measured iterations. + device_type: "cuda" or "cpu". output_dir: Where to write chrome traces / stack files. pre_gpu_load: Number of dummy matmul operations to run before the first @@ -597,8 +675,10 @@ def _run_benchmark_core( cpu_times_active_ns.append(cpu_end_active_ns - cpu_start_active_ns) # Convert to milliseconds and drop the first iteration - cpu_elapsed_time = torch.tensor( - [t / 1e6 for t in cpu_times_active_ns[1:]], dtype=torch.float + cpu_elapsed_time = ( + torch.tensor([t / 1e6 for t in cpu_times_active_ns[1:]], dtype=torch.float) + if num_benchmarks >= 2 + else torch.zeros(1, dtype=torch.float) ) # Make sure all kernels are finished before reading timers / stats @@ -608,8 +688,12 @@ def _run_benchmark_core( else: torch.cuda.synchronize(rank) - gpu_elapsed_time = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_events[1:], end_events[1:])] + gpu_elapsed_time = ( + torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_events[1:], end_events[1:])] + ) + if num_benchmarks >= 2 + else torch.zeros(1, dtype=torch.float) ) else: # For CPU-only benchmarks we fall back to wall-clock timing via ``timeit``. @@ -734,18 +818,42 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None: ) +@dataclass +class BenchFuncConfig: + name: str + world_size: int + num_profiles: int + num_benchmarks: int + profile_dir: str + device_type: str = "cuda" + pre_gpu_load: int = 0 + export_stacks: bool = False + + # pyre-ignore [2] + def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]: + return { + "name": self.name, + "world_size": self.world_size, + "num_profiles": self.num_profiles, + "num_benchmarks": self.num_benchmarks, + "profile_dir": self.profile_dir, + "device_type": self.device_type, + "pre_gpu_load": self.pre_gpu_load, + "export_stacks": self.export_stacks, + } | kwargs_to_override + + def benchmark_func( name: str, + rank: int, + world_size: int, + func_to_benchmark: Any, # pyre-ignore[2] bench_inputs: List[Dict[str, Any]], prof_inputs: List[Dict[str, Any]], - world_size: int, - profile_dir: str, - num_benchmarks: int, - num_profiles: int, - # pyre-ignore[2] - func_to_benchmark: Any, benchmark_func_kwargs: Optional[Dict[str, Any]], - rank: int, + num_profiles: int, + num_benchmarks: int, + profile_dir: str, device_type: str = "cuda", pre_gpu_load: int = 0, export_stacks: bool = False, @@ -753,19 +861,21 @@ def benchmark_func( """ Args: name: Human-readable benchmark name. + world_size, rank: Distributed context to correctly reset / collect GPU + stats. ``rank == -1`` means single-process mode. - bench_inputs: List[Dict[str, Any]] will be fed to the function at once - prof_inputs: List[Dict[str, Any]] will be fed to the function at once - benchmark_func_kwargs: kwargs to be passed to func_to_benchmark func_to_benchmark: Callable that executes one measured iteration. func_to_benchmark(batch_inputs, **kwargs) + bench_inputs, prof_inputs: List[Dict[str, Any]] this argument will be fed + to the function at once, and bench_inputs will be used for benchmarking + while prof_inputs will be used for profiling + benchmark_func_kwargs: kwargs to be passed to func_to_benchmark - world_size, rank: Distributed context to correctly reset / collect GPU - stats. ``rank == -1`` means single-process mode. 
- num_benchmarks: Number of measured iterations. - device_type: "cuda" or "cpu". + num_profiles, num_benchmarks: Number of measured iterations, i.e., how many + times the function will be called profile_dir: Where to write chrome traces / stack files. + device_type: "cuda" or "cpu". pre_gpu_load: Number of dummy matmul operations to run before the first measured iteration (helps simulating a loaded allocator). export_stacks: Whether to export flamegraph-compatible stack files. diff --git a/torchrec/distributed/benchmark/benchmark_comms.py b/torchrec/distributed/benchmark/benchmark_comms.py new file mode 100644 index 000000000..66725396d --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_comms.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +Example usage: + +Buck2 (internal): + buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- + +OSS (external): + python -m torchrec.distributed.benchmark.benchmark_comms + +""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.distributed as dist + +from torch.autograd.profiler import record_function + +from torchrec.distributed.benchmark.base import ( + BenchFuncConfig, + benchmark_func, + cmd_conf, +) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + run_multi_process_func, +) + +_cc = cmd_conf() + + +@dataclass +class AllToAllSingleRunConfig(BenchFuncConfig): + name: str = "all_to_all_single" + world_size: int = 2 + dim: int = 2048 + profile_dir: str = "." 
+ num_benchmarks: int = 1 + num_profiles: int = 2 + num_mul: int = 10 + num_concat: int = 100 + + +# all_to_all_single with sync and single stream +def a2a_sync_base( + batch_inputs: List[Dict[str, Any]], + dim: int, + num_mul: int, + num_concat: int, + ctx: MultiProcessContext, +) -> None: + with record_function("## pre-comms compute ##"): + pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5 + for _ in range(num_mul): + pre_comms = pre_comms @ pre_comms + pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms)) + pre_comms = torch.sigmoid(pre_comms).reshape(1, dim, dim) + ctx.rank + pre_comms = torch.concat([pre_comms] * num_concat) + + with record_function("## all_to_all_single ##"): + post_comms = torch.empty_like(pre_comms) + req = dist.all_to_all_single(output=post_comms, input=pre_comms, group=ctx.pg) + + with record_function("## comms validation ##"): + mixed_ranks = post_comms.to(torch.int).reshape(-1) + N = mixed_ranks.numel() // ctx.world_size + checks = [ + torch.all(mixed_ranks[i * N : (i + 1) * N] == i) + for i in range(ctx.world_size) + ] + + with record_function("## irrelevant compute ##"): + pre_comms = torch.rand(dim, dim, device=ctx.device) - 0.5 + for _ in range(num_mul): + pre_comms = pre_comms @ pre_comms + pre_comms = torch.sigmoid(pre_comms - torch.mean(pre_comms)) + pre_comms = torch.sigmoid(pre_comms) + ctx.rank + + with record_function("## post-comms compute ##"): + post_comms = post_comms[0] + for _ in range(num_mul): + post_comms = post_comms @ post_comms + post_comms = torch.sigmoid(pre_comms - torch.mean(post_comms)) + post_comms = torch.sigmoid(post_comms) + ctx.rank + + with record_function("## assert ##"): + assert all(checks) + + +# single-rank runner +def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None: + # Ensure GPUs are available and we have enough of them + assert ( + torch.cuda.is_available() and torch.cuda.device_count() >= world_size + ), "CUDA not available or insufficient GPUs for the requested world_size" + + torch.autograd.set_detect_anomaly(True) + with MultiProcessContext( + rank=rank, + world_size=world_size, + backend="nccl", + use_deterministic_algorithms=False, + ) as ctx: + + if arg.name.startswith("a2a_sync_base"): + func = a2a_sync_base + else: + func = a2a_sync_base + + result = benchmark_func( + bench_inputs=[], + prof_inputs=[], + benchmark_func_kwargs={ + "ctx": ctx, + "dim": arg.dim, + "num_mul": arg.num_mul, + "num_concat": arg.num_concat, + }, + func_to_benchmark=func, + rank=rank, + **arg.benchmark_func_kwargs() + ) + + if rank == 0: + print(result) + + +@_cc.register +def a2a_single(arg: AllToAllSingleRunConfig) -> None: + run_multi_process_func(func=a2a_single_runner, world_size=arg.world_size, arg=arg) + + +if __name__ == "__main__": + _cc.main() diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py index dfa3d360f..557d5dd9c 100644 --- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py +++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py @@ -27,6 +27,7 @@ from fbgemm_gpu.split_embedding_configs import EmbOptimType from torch import nn from torchrec.distributed.benchmark.base import ( + BenchFuncConfig, benchmark_func, BenchmarkResult, cmd_conf, @@ -58,7 +59,7 @@ @dataclass -class RunOptions: +class RunOptions(BenchFuncConfig): """ Configuration options for running sparse neural network benchmarks. 
@@ -102,6 +103,8 @@ class RunOptions: input_type: str = "kjt" name: str = "" profile_dir: str = "" + num_benchmarks: int = 5 + num_profiles: int = 2 planner_type: str = "embedding" pooling_factors: Optional[List[float]] = None num_poolings: Optional[List[float]] = None @@ -257,22 +260,17 @@ def _func_to_benchmark( opt=optimizer, device=ctx.device, ) - pipeline.progress(iter(bench_inputs)) + pipeline.progress(iter(bench_inputs)) # warmup + run_option.name = ( + type(pipeline).__name__ if run_option.name == "" else run_option.name + ) result = benchmark_func( - name=( - type(pipeline).__name__ if run_option.name == "" else run_option.name - ), bench_inputs=bench_inputs, # pyre-ignore prof_inputs=bench_inputs, # pyre-ignore - num_benchmarks=5, - num_profiles=2, - profile_dir=run_option.profile_dir, - world_size=run_option.world_size, func_to_benchmark=_func_to_benchmark, benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline}, - rank=rank, - export_stacks=run_option.export_stacks, + **run_option.benchmark_func_kwargs(rank=rank) ) if rank == 0: @@ -325,7 +323,7 @@ def run_pipeline( # command-line interface -@cmd_conf +@cmd_conf # pyre-ignore [56] def main( run_option: RunOptions, table_config: EmbeddingTablesConfig,
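
A minimal sketch of how the new `cmd_conf` registration flow and the `BenchFuncConfig.benchmark_func_kwargs` helper introduced in `base.py` fit together. The module name, `DemoConfig`, and `demo` below are illustrative only and not part of this patch; the sketch assumes the definitions added above.

```python
# demo_benchmark.py -- hypothetical module, not part of the patch.
from dataclasses import dataclass

from torchrec.distributed.benchmark.base import BenchFuncConfig, cmd_conf

_cc = cmd_conf()


@dataclass
class DemoConfig(BenchFuncConfig):
    # Inherited BenchFuncConfig fields (name, world_size, num_profiles,
    # num_benchmarks, profile_dir, device_type, ...) and any extra fields
    # defined here are all exposed as --flags by cmd_conf.
    name: str = "demo"
    world_size: int = 2
    num_profiles: int = 2
    num_benchmarks: int = 5
    profile_dir: str = "."
    dim: int = 1024  # benchmark-specific knob


@_cc.register
def demo(arg: DemoConfig) -> None:
    # `arg` is built from command-line flags, with --yaml_config /
    # --json_config values (JSON overriding YAML) used as defaults.
    # benchmark_func_kwargs() returns the shared benchmark_func arguments;
    # keyword overrides such as rank take precedence over the dataclass fields.
    print(arg.benchmark_func_kwargs(rank=0))


if __name__ == "__main__":
    # e.g. python -m demo_benchmark demo --dim=2048 --num_benchmarks=3
    _cc.main()
```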