diff --git a/autoparallel/_testing/__init__.py b/autoparallel/_testing/__init__.py new file mode 100644 index 0000000..6677db0 --- /dev/null +++ b/autoparallel/_testing/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/autoparallel/_testing/_local_tensor.py b/autoparallel/_testing/_local_tensor.py new file mode 100644 index 0000000..278160d --- /dev/null +++ b/autoparallel/_testing/_local_tensor.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.nn as nn +from autoparallel.graph_pp_runner import GraphPipelineStage +from torch._C._distributed_c10d import FakeWork, PythonCallbackWork +from torch.distributed import DeviceMesh +from torch.distributed._local_tensor import ( + local_tensor_mode, + LocalIntNode, + LocalRunnerMode, + LocalTensor, + LocalTensorMode, + maybe_disable_local_tensor_mode, +) +from torch.distributed._local_tensor._c10d import local_p2p_op +from torch.distributed.pipelining.stage import InputInfo, PipelineStage +from torch.distributed.tensor import DTensor +from torch.export._unlift import _assign_attr +from torch.export.unflatten import _AttrKind + + +_pg_groups: list[list[int]] = [] + + +def create_local_tensor_mode(dp_ep_mesh: DeviceMesh, pp_rank: int) -> LocalTensorMode: + dp_ep_full_mesh = dp_ep_mesh._layout.remap_to_tensor(dp_ep_mesh._rank_map) + dp_ep_ranks = dp_ep_full_mesh[pp_rank].flatten().tolist() + print(f"Creating local tensor mode for ranks {dp_ep_ranks}") + return LocalTensorMode(frozenset(dp_ep_ranks)) + + +def cache_pp_groups(pp_mesh: DeviceMesh) -> list[list[int]]: + pp_full_mesh = pp_mesh._layout.remap_to_tensor(pp_mesh._rank_map) + pp_groups = [] + for i in range(pp_full_mesh.size(dim=0)): + pp_group = pp_full_mesh[i].tolist() + pp_groups.append(pp_group) + global _pp_groups + _pp_groups = pp_groups + return pp_groups + + +def combine_works(works: list[dist.Work], ctx: str | None = None) -> dist.Work: + def _wait_all(timeout) -> bool: + for w in works: + w.wait() + return True + + return PythonCallbackWork(_wait_all) + + +def get_pp_peer(self: int, peer: int) -> torch.SymInt: + pp_ret = {} + global _pp_groups + for pp_group in _pp_groups: + global_rank = pp_group[self] + global_peer = pp_group[peer] + pp_ret[global_rank] = global_peer + return torch.SymInt(LocalIntNode(pp_ret)) + + +def expand_p2p_ops( + ops: list[dist.P2POp], pp_rank: int, ctx: str | None = None +) -> list[dist.P2POp]: + # Ops where generated from a perspective of pp group where rank 0 is present. 
+ + def multi_isend(tensor, dst=None, group=None, tag=0, group_src=None): + assert group_src is not None, "Expected group rank" + peer = get_pp_peer(pp_rank, group_src) + if not isinstance(tensor, LocalTensor): + tensor = maybe_make_tensor_local(tensor) + works = local_p2p_op(peer, tensor, dist.isend) + return FakeWork() + + def multi_irecv(tensor, src=None, group=None, tag=0, group_src=None): + assert group_src is not None, "Expected group rank" + peer = get_pp_peer(pp_rank, group_src) + assert isinstance(tensor, LocalTensor), "Expected LocalTensor" + works = local_p2p_op(peer, tensor, dist.irecv) + return combine_works(works) + + send_ops = [] + recv_ops = [] + for p2p_op in ops: + op = p2p_op.op + if op is dist.isend: + p2p_op.op = multi_isend + send_ops.append(p2p_op) + elif op is dist.irecv: + p2p_op.op = multi_irecv + recv_ops.append(p2p_op) + else: + raise AssertionError(f"Unexpected op {op}") + + # Execute send ops first and then recv because the latter are blocking + return send_ops + recv_ops + + +class LocalGraphPipelineStage(GraphPipelineStage): + def log_name(self) -> str: + return ( + f"PP rank {self.group_rank} Stage {self.stage_index} of {self.num_stages}" + ) + + def _get_recv_ops(self, recv_infos: tuple[InputInfo, ...]) -> list[dist.P2POp]: + ops = super()._get_recv_ops(recv_infos) + ops = expand_p2p_ops(ops, self.group_rank, self.log_name() + " _get_recv_ops") + return ops + + def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + ops = super().get_fwd_send_ops(fwd_chunk_id) + ops = expand_p2p_ops( + ops, self.group_rank, self.log_name() + " get_fwd_send_ops" + ) + return ops + + def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + ops = super().get_bwd_send_ops(bwd_chunk_id) + ops = expand_p2p_ops( + ops, self.group_rank, self.log_name() + " get_bwd_send_ops" + ) + return ops + + def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: + ops = super()._get_init_p2p_neighbors_ops() + ops = expand_p2p_ops( + ops, self.group_rank, self.log_name() + " _get_init_p2p_neighbors_ops" + ) + return ops + + +def local_tensor_mode_if_enabled( + ltm: LocalTensorMode | None = None, +) -> LocalTensorMode | None: + + for _ in range(2): + if ltm is not None and not ltm._disable: + return ltm + ltm = local_tensor_mode() + + return None + + +def maybe_make_tensor_local( + tensor: torch.Tensor, + ltm: LocalTensorMode | None = None, +) -> torch.Tensor: + ltm = local_tensor_mode_if_enabled(ltm) + if ltm is None: + return tensor + + if isinstance(tensor, LocalTensor): + return tensor + + if isinstance(tensor, DTensor): + tensor._local_tensor = maybe_make_tensor_local(tensor._local_tensor, ltm) + return tensor + + local_tensor = ltm.rank_map(lambda r: tensor.clone().detach()) + local_tensor.requires_grad = tensor.requires_grad + return local_tensor + + +def maybe_make_module_local( + module: nn.Module, + ltm: LocalTensorMode | None = None, +) -> None: + ltm = local_tensor_mode_if_enabled(ltm) + if ltm is None: + return + print(f"maybe_make_module_local {ltm.ranks}") + + for k, v in module.named_parameters(): + _assign_attr( + nn.Parameter( + data=maybe_make_tensor_local(v.data, ltm), + requires_grad=v.requires_grad, + ), + module, + k, + attr_kind=_AttrKind.PARAMETER, + ) + + for k, v in module.named_buffers(): + _assign_attr( + maybe_make_tensor_local(v, ltm), module, k, attr_kind=_AttrKind.BUFFER + ) diff --git a/autoparallel/api.py b/autoparallel/api.py index e902132..0f98592 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -7,11 +7,13 @@ import
functools import itertools import warnings -from contextlib import ExitStack, contextmanager +from contextlib import contextmanager, ExitStack from types import MethodType from typing import Any, Callable, Optional, Union import torch + +from autoparallel._passes.graph_partition import partition_joint_with_descriptors from torch._dynamo.functional_export import _dynamo_graph_capture_for_export from torch._functorch.aot_autograd import ( aot_compile_joint_with_descriptors, @@ -29,8 +31,6 @@ from torch.export.unflatten import _AttrKind from torch.fx.experimental.symbolic_shapes import ShapeEnv -from autoparallel._passes.graph_partition import partition_joint_with_descriptors - from .activation_checkpointing import ac_joint_pass from .apply_sharding import apply_sharding_to_model from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast @@ -44,9 +44,9 @@ from .init_weights import hook_params_setters from .optimize_sharding import ShardingOptimizer from .utils import ( - NumericsLogger, _get_device_from_mesh, debug_boxed_nop_preserve_node_meta, + NumericsLogger, ) _APPLY_VIEW_MM_VIEW_PATTERN = False @@ -120,7 +120,7 @@ def _move_to_fake(module, k, device, parameter=True): # can patch the verification logic. @contextmanager def monkey_patch_export_verifier(): - from torch._export.verifier import SpecViolationError, Verifier, final + from torch._export.verifier import final, SpecViolationError, Verifier prior = Verifier._check_graph_module diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 6feeae9..0a17b04 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -8,6 +8,9 @@ import torch import torch.utils._pytree as pytree + +from autoparallel.propagation_rules import generate_dummy_redistribute_costs +from torch.distributed._local_tensor import LocalTensor from torch.distributed._tensor.placement_types import Placement, TensorMeta from torch.distributed.device_mesh import _get_device_handle from torch.distributed.tensor._dtensor_spec import DTensorSpec @@ -22,14 +25,12 @@ from torch.distributed.tensor.placement_types import Replicate from torch.utils._pytree import tree_flatten, tree_map_only -from autoparallel.propagation_rules import generate_dummy_redistribute_costs - from .dtensor_util import get_op_strategy, with_implicit_strategies from .propagation_rules import ( - TENSOR_FACTORY_OPS, _op_partial_rules, _op_rules, remove_invalid_configs, + TENSOR_FACTORY_OPS, ) @@ -466,7 +467,7 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_lo if name not in real_params: continue param = real_params[name] - param_logs.append(f"{name=} hash={hash_tensor(param)}") + param_logs.append(f"rank={name=} hash={hash_tensor(param)}") with open(path, "a") as f: f.write("\n".join(param_logs) + "\n") torch.distributed.barrier() @@ -490,7 +491,6 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, should_lo def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, should_log): path = self.dir / "diff.log" - for i in range(num_world_stages): if should_log and i in stage_mods: grad_logs = [] diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py index 1f682fd..a42b9b6 100644 --- a/examples/example_ds3_pp.py +++ b/examples/example_ds3_pp.py @@ -7,48 +7,35 @@ import logging import os from contextlib import nullcontext +from enum import Enum from typing import Callable, Optional import torch import torch.distributed._tools.fake_collectives import torch.nn as nn -from torch._logging import 
trace_structured -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.distributed.pipelining.schedules import ( - BACKWARD_INPUT, - BACKWARD_WEIGHT, - FORWARD, - FULL_BACKWARD, - OVERLAP_F_B, - REDUCE_GRAD, - RESHARD, - UNSHARD, - PipelineScheduleMulti, - _PipelineSchedule, - _PipelineScheduleRuntime, - get_schedule_class, +from autoparallel._testing._local_tensor import ( + cache_pp_groups, + create_local_tensor_mode, + LocalGraphPipelineStage, + maybe_make_module_local, + maybe_make_tensor_local, ) -from torch.distributed.pipelining.stage import PipelineStage -from torch.distributed.tensor.placement_types import Replicate, Shard -from torch.fx.experimental.symbolic_shapes import ShapeEnv -from torch.testing._internal.distributed.fake_pg import FakeStore - from autoparallel._testing.models.dsv3 import ( DeepSeekV3Model, DeepSeekV3ModelArgs, DeepSeekV3Stage0, DeepSeekV3StageI, DeepSeekV3StageN, - MoEArgs, dsv3_loss_fn, + MoEArgs, ) from autoparallel.api import AutoParallelPP from autoparallel.graph_pp_runner import ( + get_multiplexed_graph_callables, GraphCallables, GraphMeta, GraphPipelineStage, GraphPPRunner, - get_multiplexed_graph_callables, overlap_fw_bw, stage_backward_input, stage_backward_weight, @@ -59,6 +46,27 @@ stage_unshard, ) from autoparallel.utils import NumericsLogger +from torch._logging import trace_structured +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._local_tensor import LocalRunnerMode, LocalTensorMode +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + _PipelineScheduleRuntime, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + FORWARD, + FULL_BACKWARD, + get_schedule_class, + OVERLAP_F_B, + PipelineScheduleMulti, + REDUCE_GRAD, + RESHARD, + UNSHARD, +) +from torch.distributed.pipelining.stage import PipelineStage +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.testing._internal.distributed.fake_pg import FakeStore # Configure logging to show DEBUG messages logging.basicConfig( @@ -128,14 +136,23 @@ def build_pipeline_schedule( return schedule +class RunMode(Enum): + MULTI_PROCESS = "multi_process" + LOCAL_TENSOR = "local_tensor" + FAKE_EVALUATE = "fake_evaluate" + + def __str__(self): + return self.value + + def run_test( - fake_evaluate: bool, + run_mode: RunMode, use_loss_fn: bool, schedule_name: str, rng_seed: Optional[int], logs_dir: str, ): - if not fake_evaluate: + if run_mode is not RunMode.FAKE_EVALUATE: pp_degree = 2 dp_mod_ep_degree = 2 ep_degree = 2 @@ -148,7 +165,7 @@ def run_test( world_size = pp_degree * dp_mod_ep_degree * ep_degree # Initialize process group based on evaluation mode - if fake_evaluate: + if run_mode is RunMode.FAKE_EVALUATE: assert ( "WORLD_SIZE" in os.environ ), "run with torchrun --standalone --nproc-per-node 4" @@ -165,6 +182,16 @@ def run_test( world_size=world_size, ) pp_rank = rank + elif run_mode is RunMode.LOCAL_TENSOR: + assert ( + "WORLD_SIZE" in os.environ + ), "run with torchrun --standalone --nproc-per-node 1" + device = torch.device(f"cuda") + default_pg = torch.distributed.init_process_group( + "fake", + rank=0, + world_size=world_size, + ) else: assert ( "WORLD_SIZE" in os.environ @@ -187,10 +214,6 @@ def run_test( ), ) - # Set pp_rank based on evaluation mode - if not fake_evaluate: - pp_rank = world_mesh["pp"].get_local_rank() - stages_per_rank = 2 total_pp_stages = pp_degree * stages_per_rank @@ -211,7 +234,7 @@ def run_test( seq_len = 1024 - if 
fake_evaluate: + if run_mode is RunMode.FAKE_EVALUATE: config = DeepSeekV3ModelArgs( vocab_size=102400, max_seq_len=seq_len, @@ -383,7 +406,7 @@ def last_stage_inp_with_loss_fn(): assert len(pp_rank_to_stage_indices) == pp_degree for stages in pp_rank_to_stage_indices.values(): assert len(stages) * pp_degree == len(virtual_pp_stages) - stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank] + if rng_seed: # Compute the ranks to log from # 1. for fw_outs, log from coord [pp_rank_containing_last_stage, 0, 0] @@ -404,21 +427,33 @@ def last_stage_inp_with_loss_fn(): # 2. for weights, log from coords [:, 0, 0] pp_world_size = world_mesh.shape[world_mesh._get_mesh_dim_by_name("pp")] - log_weights_rank_coordinates = [(i, 0, 0) for i in range(pp_world_size)] - should_log_weights = ( - tuple(world_mesh.get_coordinate()) in log_weights_rank_coordinates - ) + # log_weights_rank_coordinates = [(i, 0, 0) for i in range(pp_world_size)] + # should_log_weights = ( + # tuple(world_mesh.get_coordinate()) in log_weights_rank_coordinates + # ) + should_log_weights = True stage_mods: dict[int, torch.nn.Module] = {} stage_graphs: dict[int, GraphCallables] = {} stage_graph_metas: dict[int, GraphMeta] = {} # Step 3. Apply AutoParallel to each logical stage assigned to this pp rank - use_cache = fake_evaluate + use_cache = run_mode is RunMode.FAKE_EVALUATE root_cache = "tmp" os.makedirs(root_cache, exist_ok=True) from autoparallel.api import AutoParallelPPModule - for stage_idx in stage_indices_current_pp_rank: + # Set pp_rank based on evaluation mode + if run_mode is RunMode.LOCAL_TENSOR: + stage_indices = list(range(len(virtual_pp_stages))) + else: + pp_rank = world_mesh["pp"].get_local_rank() + stage_indices = pp_rank_to_stage_indices[pp_rank] + + numerics_logger = None + if rng_seed is not None: + numerics_logger = NumericsLogger(logs_dir) + + for stage_idx in stage_indices: trace_structured( "artifact", metadata_fn=lambda: { @@ -428,7 +463,7 @@ def last_stage_inp_with_loss_fn(): payload_fn=lambda: "placeholder text", ) stage_mod = virtual_pp_stages[stage_idx] - eval_mode = "fake" if fake_evaluate else "real" + eval_mode = "fake" if run_mode is RunMode.FAKE_EVALUATE else "real" stage_file = os.path.join(root_cache, f"stage_{eval_mode}_{stage_idx}.pth") if os.path.exists(stage_file) and use_cache: cache = torch.load(stage_file, weights_only=False) @@ -462,6 +497,7 @@ def last_stage_inp_with_loss_fn(): if use_loss_fn and stage_idx == total_pp_stages - 1 else None ), + numerics_logger=numerics_logger, ) as autop: autop.add_parameter_memory_constraint(low=None, high=None) @@ -494,8 +530,6 @@ def last_stage_inp_with_loss_fn(): torch.save(cache, stage_file) pp_mod.to_empty(device=device) - # run weight init on our sharded DTensor params - pp_mod.init_weights(buffer_device=device, seed=rng_seed) # Store each stage's information in stage_mods, stage_graphs, and stage_graph_metas stage_mods[stage_idx] = pp_mod @@ -526,119 +560,175 @@ def last_stage_inp_with_loss_fn(): # Two stages per pp rank assert ( - len(stage_indices_current_pp_rank) + len(stage_indices) == len(stage_mods) == len(stage_graphs) == len(stage_graph_metas) ) + if run_mode is RunMode.LOCAL_TENSOR: + cache_pp_groups(world_mesh["pp"]) + world_size = torch.distributed.get_world_size() num_world_stages = world_size * len(stage_mods) - numerics_logger = None - if rng_seed is not None: - numerics_logger = NumericsLogger(logs_dir) - numerics_logger.log_pp_model_weights( - model, stage_mods, num_world_stages, should_log=should_log_weights - ) - 
torch.manual_seed(rng_seed) - - stages = [] - # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata - for pp_stage_idx, pp_stage_mod in stage_mods.items(): - stage = GraphPipelineStage( - pp_stage_mod, - stage_graphs[pp_stage_idx], - stage_graph_metas[pp_stage_idx], - stage_index=pp_stage_idx, - num_stages=len(virtual_pp_stages), - device=device, - input_args=( - shape_inference_input_fn_first_stage() - if pp_stage_idx == 0 - else shape_inference_fn_intermediate_stage() - ), - output_args=( - shape_inference_output_fn_last_stage() - if pp_stage_idx == (total_pp_stages - 1) - else shape_inference_fn_intermediate_stage() - ), - group=world_mesh.get_group("pp"), - numerics_logger=numerics_logger, - should_log_fw_outs=should_log_fw_outs, - ) - stages.append(stage) - - # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank - schedule = build_pipeline_schedule( - stages=stages, - loss_fn=None, - pipeline_parallel_schedule=schedule_name, - microbatch_size=microbatch_size, - local_batch_size=local_batch_size, - pipeline_parallel_degree=pp_degree, - backward_requires_autograd=False, - scale_grads=rng_seed is None, # In determinism mode, don't scale grads - ) - assert isinstance(schedule, _PipelineScheduleRuntime) - - # Step 6. Override the pipeline runner's action implementations - schedule.register_custom_function(FORWARD, stage_forward) - schedule.register_custom_function(FULL_BACKWARD, stage_full_backward) - schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad) - schedule.register_custom_function(RESHARD, stage_reshard) - schedule.register_custom_function(UNSHARD, stage_unshard) - schedule.register_custom_function(BACKWARD_INPUT, stage_backward_input) - schedule.register_custom_function(BACKWARD_WEIGHT, stage_backward_weight) - if schedule_name == "DualPipeV": - multiplexed_graph_callables = get_multiplexed_graph_callables(stage_graphs) - schedule.register_custom_function( - OVERLAP_F_B, functools.partial(overlap_fw_bw, multiplexed_graph_callables) + def run_pp_rank(pp_rank: int): + maybe_local_context = ( + create_local_tensor_mode(mesh, pp_rank) + if run_mode is RunMode.LOCAL_TENSOR + else nullcontext() ) + with maybe_local_context: + # Step 4. 
Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata + stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank] + stages = [] + rank_stage_mods = {} + for pp_stage_idx in stage_indices_current_pp_rank: + pp_stage_mod = stage_mods[pp_stage_idx] + + # Convert module to local if running under local tensor mode + if run_mode is RunMode.LOCAL_TENSOR: + maybe_make_module_local(pp_stage_mod) + should_log_fw_outs = True + + # run weight init on our sharded DTensor params + pp_stage_mod.init_weights(buffer_device=device, seed=rng_seed) + + pipeline_stage_class = ( + LocalGraphPipelineStage + if run_mode is RunMode.LOCAL_TENSOR + else GraphPipelineStage + ) + + stage = pipeline_stage_class( + pp_stage_mod, + stage_graphs[pp_stage_idx], + stage_graph_metas[pp_stage_idx], + stage_index=pp_stage_idx, + num_stages=len(virtual_pp_stages), + device=device, + input_args=( + shape_inference_input_fn_first_stage() + if pp_stage_idx == 0 + else shape_inference_fn_intermediate_stage() + ), + output_args=( + shape_inference_output_fn_last_stage() + if pp_stage_idx == (total_pp_stages - 1) + else shape_inference_fn_intermediate_stage() + ), + group=world_mesh.get_group("pp"), + numerics_logger=numerics_logger, + should_log_fw_outs=should_log_fw_outs, + ) - # Step 7. Register the schedule with the graph runner + # NB: This is clearly a hack. The purpose of it is to override pp rank + # that the stage obtained from the process group. Stage computes peers to + # work with based on group rank. + if run_mode is RunMode.LOCAL_TENSOR: + stage.group_rank = pp_rank + + stages.append(stage) + rank_stage_mods[pp_stage_idx] = pp_stage_mod + + # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank + schedule = build_pipeline_schedule( + stages=stages, + loss_fn=None, + pipeline_parallel_schedule=schedule_name, + microbatch_size=microbatch_size, + local_batch_size=local_batch_size, + pipeline_parallel_degree=pp_degree, + backward_requires_autograd=False, + scale_grads=rng_seed is None, # In determinism mode, don't scale grads + ) + assert isinstance(schedule, _PipelineScheduleRuntime) + + # Step 6. Override the pipeline runner's action implementations + schedule.register_custom_function(FORWARD, stage_forward) + schedule.register_custom_function(FULL_BACKWARD, stage_full_backward) + schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad) + schedule.register_custom_function(RESHARD, stage_reshard) + schedule.register_custom_function(UNSHARD, stage_unshard) + schedule.register_custom_function(BACKWARD_INPUT, stage_backward_input) + schedule.register_custom_function(BACKWARD_WEIGHT, stage_backward_weight) + if schedule_name == "DualPipeV": + multiplexed_graph_callables = get_multiplexed_graph_callables( + stage_graphs + ) + schedule.register_custom_function( + OVERLAP_F_B, + functools.partial(overlap_fw_bw, multiplexed_graph_callables), + ) - graph_pp_runner = GraphPPRunner(schedule) + if rng_seed is not None: + numerics_logger.log_pp_model_weights( + model, + rank_stage_mods, + num_world_stages, + should_log=should_log_weights, + ) + torch.manual_seed(rng_seed) - # Step 8. Run the whole pipeline once using the graph runner - has_last_stage = (total_pp_stages - 1) in stage_mods - with ( - FakeTensorMode( - allow_non_fake_inputs=True, - shape_env=ShapeEnv(), - ) - if fake_evaluate - else nullcontext() - ): - with torch.no_grad(): - target, losses = ( - (runtime_target_fn(), []) - if has_last_stage and use_loss_fn - else (None, None) + # Step 7. 
Register the schedule with the graph runner + graph_pp_runner = GraphPPRunner(schedule) + + # Step 8. Run the whole pipeline once using the graph runner + has_last_stage = (total_pp_stages - 1) in rank_stage_mods + maybe_fake_context = ( + FakeTensorMode( + allow_non_fake_inputs=True, + shape_env=ShapeEnv(), + ) + if run_mode is RunMode.FAKE_EVALUATE + else nullcontext() ) - if pp_rank == 0: - x = runtime_input_fn_first_stage() - if rng_seed: - numerics_logger.log_diff( - x.to(torch.float32), prefix="full batch input" + with maybe_fake_context: + with torch.no_grad(): + target, losses = ( + (runtime_target_fn(), []) + if has_last_stage and use_loss_fn + else (None, None) + ) + if pp_rank == 0: + x = runtime_input_fn_first_stage() + if rng_seed: + numerics_logger.log_diff( + x.to(torch.float32), prefix="full batch input" + ) + graph_pp_runner.step( + x, target=target, losses=losses, return_outputs=False + ) + else: + graph_pp_runner.step( + target=target, losses=losses, return_outputs=False + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "pipeline_step_losses", + "encoding": "string", + }, + payload_fn=lambda: f"losses: {losses}", ) - graph_pp_runner.step( - x, target=target, losses=losses, return_outputs=False + + numerics_logger.log_pp_grads( + model, + stage_mods, + num_world_stages, + should_log=should_log_weights, ) - else: - graph_pp_runner.step(target=target, losses=losses, return_outputs=False) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "pipeline_step_losses", - "encoding": "string", - }, - payload_fn=lambda: f"losses: {losses}", - ) - numerics_logger.log_pp_grads( - model, stage_mods, num_world_stages, should_log=should_log_weights - ) + if run_mode is RunMode.LOCAL_TENSOR: + with LocalRunnerMode( + world_size, + pp_degree, + run_pp_rank, + ): + pass + else: + pp_rank = world_mesh["pp"].get_local_rank() + run_pp_rank(pp_rank) print("All good!") @@ -655,9 +745,10 @@ def last_stage_inp_with_loss_fn(): description="Run DeepSeek V3 pipeline parallel example" ) parser.add_argument( - "--fake-evaluate", - action="store_true", - default=False, + "--run-mode", + type=RunMode, + choices=list(RunMode), + default=RunMode.MULTI_PROCESS, help="Use fake evaluation mode with FakeTensorMode (default: False)", ) parser.add_argument( @@ -692,7 +783,7 @@ def last_stage_inp_with_loss_fn(): torch.manual_seed(args.rng_seed) run_test( - fake_evaluate=args.fake_evaluate, + run_mode=args.run_mode, use_loss_fn=args.use_loss_fn, schedule_name=args.schedule_name, rng_seed=args.rng_seed, diff --git a/examples/example_ds3_pp_local_tensor.py b/examples/example_ds3_pp_local_tensor.py new file mode 100644 index 0000000..181f50c --- /dev/null +++ b/examples/example_ds3_pp_local_tensor.py @@ -0,0 +1,617 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import functools +import logging +import os +from contextlib import nullcontext +from typing import Callable, Optional + +import torch +import torch.distributed as dist +import torch.distributed._tools.fake_collectives +import torch.nn as nn + +from autoparallel._testing.models.dsv3 import ( + DeepSeekV3Model, + DeepSeekV3ModelArgs, + DeepSeekV3Stage0, + DeepSeekV3StageI, + DeepSeekV3StageN, + MoEArgs, +) +from autoparallel.api import AutoParallelPP +from autoparallel.graph_pp_runner import ( + GraphCallables, + GraphMeta, + GraphPipelineStage, + GraphPPRunner, + stage_forward, + stage_full_backward, + stage_reduce_grad, + stage_reshard, + stage_unshard, +) +from autoparallel.utils import print_rank_by_rank +from examples.example_ds3_pp import build_pipeline_schedule +from torch._C._distributed_c10d import FakeProcessGroup, FakeWork, PythonCallbackWork + +from torch._logging import trace_structured +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed import DeviceMesh +from torch.distributed._local_tensor import ( + local_tensor_mode, + LocalIntNode, + LocalRunnerMode, + LocalTensor, + LocalTensorMode, + maybe_disable_local_tensor_mode, +) +from torch.distributed._local_tensor._c10d import local_p2p_op +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + _PipelineScheduleRuntime, + FORWARD, + FULL_BACKWARD, + get_schedule_class, + PipelineScheduleMulti, + REDUCE_GRAD, + RESHARD, + UNSHARD, +) +from torch.distributed.pipelining.stage import InputInfo, PipelineStage +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import Shard +from torch.export._unlift import _assign_attr +from torch.export.unflatten import _AttrKind +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.testing._internal.distributed.fake_pg import FakeStore + + +# Configure logging to show DEBUG messages +logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +_pg_groups: list[list[int]] = [] + + +def enumerate_pp_groups(pp_mesh: DeviceMesh) -> list[list[int]]: + pp_full_mesh = pp_mesh._layout.remap_to_tensor(pp_mesh._rank_map) + pp_groups = [] + for i in range(pp_full_mesh.size(dim=0)): + pp_group = pp_full_mesh[i].tolist() + pp_groups.append(pp_group) + return pp_groups + + +def combine_works(works: list[dist.Work], ctx: str | None = None) -> dist.Work: + def _wait_all(timeout) -> bool: + for w in works: + w.wait() + return True + + return PythonCallbackWork(_wait_all) + + +def get_pp_peer(self: int, peer: int) -> torch.SymInt: + pp_ret = {} + global _pp_groups + for pp_group in _pp_groups: + global_rank = pp_group[self] + global_peer = pp_group[peer] + pp_ret[global_rank] = global_peer + return torch.SymInt(LocalIntNode(pp_ret)) + + +def expand_p2p_ops( + ops: list[dist.P2POp], pp_rank: int, ctx: str | None = None +) -> list[dist.P2POp]: + # Ops where generated from a perspective of pp group where rank 0 is present. 
+ + def multi_isend(tensor, dst=None, group=None, tag=0, group_src=None): + assert group_src is not None, "Expected group rank" + peer = get_pp_peer(pp_rank, group_src) + if not isinstance(tensor, LocalTensor): + tensor = maybe_make_tensor_local(tensor) + works = local_p2p_op(peer, tensor, dist.isend) + return FakeWork() + + def multi_irecv(tensor, src=None, group=None, tag=0, group_src=None): + assert group_src is not None, "Expected group rank" + peer = get_pp_peer(pp_rank, group_src) + assert isinstance(tensor, LocalTensor), "Expected LocalTensor" + works = local_p2p_op(peer, tensor, dist.irecv) + return combine_works(works) + + send_ops = [] + recv_ops = [] + for p2p_op in ops: + op = p2p_op.op + if op is dist.isend: + p2p_op.op = multi_isend + send_ops.append(p2p_op) + elif op is dist.irecv: + p2p_op.op = multi_irecv + recv_ops.append(p2p_op) + else: + raise AssertionError(f"Unexpected op {op}") + + # Execute send ops first and then recv because the latter are blocking + return send_ops + recv_ops + + +class LocalGraphPipelineStage(GraphPipelineStage): + def log_name(self) -> str: + return ( + f"PP rank {self.group_rank} Stage {self.stage_index} of {self.num_stages}" + ) + + def _get_recv_ops(self, recv_infos: tuple[InputInfo, ...]) -> list[dist.P2POp]: + ops = super()._get_recv_ops(recv_infos) + ops = expand_p2p_ops(ops, self.group_rank, self.log_name() + " _get_recv_ops") + return ops + + def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + ops = super().get_fwd_send_ops(fwd_chunk_id) + ops = expand_p2p_ops( + ops, self.group_rank, self.log_name() + " get_fwd_send_ops" + ) + return ops + + def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + ops = super().get_bwd_send_ops(bwd_chunk_id) + ops = expand_p2p_ops( + ops, self.group_rank, self.log_name() + " get_bwd_send_ops" + ) + return ops + + def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: + ops = super()._get_init_p2p_neighbors_ops() + ops = expand_p2p_ops( + ops, self.group_rank, self.log_name() + " _get_init_p2p_neighbors_ops" + ) + return ops + + +def run_test(run_local: bool, debug_numerics: Optional[bool]): + pp_degree = 2 + dp_mod_ep_degree = 2 + ep_degree = 2 + + dp_degree = dp_mod_ep_degree * ep_degree + world_size = pp_degree * dp_mod_ep_degree * ep_degree + + # Initialize process group based on evaluation mode + if run_local: + assert ( + "WORLD_SIZE" in os.environ + ), "run with torchrun --standalone --nproc-per-node 1" + device = torch.device(f"cuda") + default_pg = torch.distributed.init_process_group( + "fake", + rank=0, + world_size=world_size, + ) + else: + assert ( + "WORLD_SIZE" in os.environ + ), "run with torchrun --standalone --nproc-per-node 8" + assert ( + int(os.getenv("WORLD_SIZE")) == world_size + ), "Need at least 8 GPUs for real evaluation" + local_rank = int(os.getenv("LOCAL_RANK")) + device = torch.device(f"cuda:{local_rank}") + default_pg = torch.distributed.init_process_group(backend="nccl") + + # Initialize device mesh (common for both modes) + world_mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (pp_degree, dp_mod_ep_degree, ep_degree), + mesh_dim_names=( + "pp", + "dp_mod_ep", + "ep", + ), + ) + + stages_per_rank = 2 + total_pp_stages = pp_degree * stages_per_rank + + # This is the spmd mesh to be used for tracing + mesh = world_mesh[("dp_mod_ep", "ep")] + + global_batch_size = 32 * dp_degree + # Batch size that will be supplied to the schedule and will be broken down into microbatches + local_batch_size = global_batch_size // dp_degree + 
n_microbatches = 16 + # Batch size with which the spmd graphs will actually be executed + microbatch_size = local_batch_size // n_microbatches + assert ( + microbatch_size >= 1 + ), f"invalid config {local_batch_size=}, {n_microbatches=}" + # Batch size to be used for spmd tracing + spmd_batch_size = microbatch_size * dp_degree + + seq_len = 1024 + + config = DeepSeekV3ModelArgs( + vocab_size=2048, + max_seq_len=seq_len, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=4, + n_dense_layers=0, # 1, + n_heads=16, + moe_args=MoEArgs( + num_experts=4, + num_shared_experts=2, + top_k=2, + score_func="softmax", + route_norm=False, + score_before_experts=False, + mesh=mesh, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ) + + # Step 0. Construct the model and extract its layers to create stages from. + with torch.device("meta"): + model = DeepSeekV3Model(config).bfloat16() + embed, layers, norm, output = list(model.children()) + items = list(layers.items()) + assert len(items) == config.n_layers + n_layers_per_rank = len(items) // total_pp_stages + layers = [ + nn.ModuleDict(items[i : i + n_layers_per_rank]) + for i in range(0, len(items), n_layers_per_rank) + ] + assert len(layers) == total_pp_stages + for lst in layers: + assert len(lst) * len(layers) == config.n_layers + + def tracing_input_fn(): + return torch.randint( + 0, + config.vocab_size, + (spmd_batch_size, seq_len), + device=device, + ) + + def tracing_input_fn_after_first_stage(): + return torch.randn( + (spmd_batch_size, seq_len, config.dim), + device=device, + dtype=torch.bfloat16, + requires_grad=True, + ) + + def runtime_input_fn(): + return torch.randint( + 0, + config.vocab_size, + (local_batch_size, seq_len), + device=device, + ) + + def shape_inference_input_fn(): + return torch.randint( + 0, + config.vocab_size, + (microbatch_size, seq_len), + device="meta", + ) + + def shape_inference_input_fn_after_first_stage(): + return torch.randn( + (microbatch_size, seq_len, config.dim), + device="meta", + dtype=torch.bfloat16, + requires_grad=True, + ) + + def shape_inference_output_fn_last_stage(): + return torch.randn( + (microbatch_size, seq_len, config.vocab_size), + device="meta", + dtype=torch.bfloat16, + requires_grad=True, + ) + + # Step 1. Construct the logical pipeline stages + with torch.device("meta"): + virtual_pp_stages = [DeepSeekV3Stage0(embed, layers[0], config)] + for i in range(1, total_pp_stages - 1): + virtual_pp_stages.append(DeepSeekV3StageI(layers[i], config)) + virtual_pp_stages.append( + DeepSeekV3StageN(layers[total_pp_stages - 1], norm, output, config) + ) + + # Step 2. Assign each logical stage(s) to pp ranks for Interleaved1F1B schedule + pp_rank_to_stage_indices: dict[int, list[int]] = { + rank: [rank + i * pp_degree for i in range(stages_per_rank)] + for rank in range(pp_degree) + } + assert len(pp_rank_to_stage_indices) == pp_degree + for stages in pp_rank_to_stage_indices.values(): + assert len(stages) * pp_degree == len(virtual_pp_stages) + + stage_mods: dict[int, torch.nn.Module] = {} + stage_graphs: dict[int, GraphCallables] = {} + stage_graph_metas: dict[int, GraphMeta] = {} + + # Step 3. 
Apply AutoParallel to each logical stage + from autoparallel.api import AutoParallelPPModule + + for stage_idx, stage_mod in enumerate(virtual_pp_stages): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"begin_tracing_stage_{stage_idx}", + "encoding": "string", + }, + payload_fn=lambda: "placeholder text", + ) + + if stage_idx == 0: + input_fn = tracing_input_fn + else: + input_fn = tracing_input_fn_after_first_stage + with AutoParallelPP( + stage_mod, input_fn, mesh, dynamic=True, compile=False + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + # x_sharding = (Shard(0), Replicate()) + x_sharding = (Shard(0), Shard(0)) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + + sharding_placement = autop.optimize_placement(verbose=False) + cache = autop.apply_placement_pp(sharding_placement) + graph_callables = cache["graph_callables"] + graph_meta = cache["graph_meta"] + pp_mod = AutoParallelPPModule( + cache["sharded_param_dict"], + cache["sharded_buffer_dict"], + autop.init_weights_model, + ) + + pp_mod.to_empty(device=device) + pp_mod.init_weights(buffer_device=device) + + # Store each stage's information in stage_mods, stage_graphs, and stage_graph_metas + stage_mods[stage_idx] = pp_mod + stage_graphs[stage_idx] = GraphCallables( + fw=graph_callables["fw"], + full_bw=graph_callables["full_bw"], + bw_dI=graph_callables["bw_dI"], + bw_dW=graph_callables["bw_dW"], + unshard=graph_callables["unshard"], + reduce_grad=graph_callables["reduce_grad"], + ) + stage_graph_metas[stage_idx] = GraphMeta( + num_mutate_inputs=graph_meta["num_mutate_inputs"], + num_user_outputs=graph_meta["num_user_outputs"], + num_symints_saved_for_bw=graph_meta["num_symints_saved_for_bw"], + num_params=graph_meta["num_params"], + num_buffers=graph_meta["num_buffers"], + num_input_grads=graph_meta["num_input_grads"], + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"end_tracing_stage_{stage_idx}", + "encoding": "string", + }, + payload_fn=lambda: "placeholder text", + ) + + # At this point all stages have been compiles and parallelized. + # NB: PP rank code + if run_local: + global _pp_groups + _pp_groups = enumerate_pp_groups(world_mesh["pp"]) + + def run_pp_rank(pp_rank: int): + maybe_local_context = ( + LocalTensorMode(world_size) if run_local else nullcontext() + ) + with maybe_local_context: + # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata + stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank] + stages = [] + for pp_stage_idx in stage_indices_current_pp_rank: + pp_stage_mod = stage_mods[pp_stage_idx] + + # Convert module to local if running under local tensor mode + maybe_make_module_local(pp_stage_mod) + + args = ( + pp_stage_mod, + stage_graphs[pp_stage_idx], + stage_graph_metas[pp_stage_idx], + ) + kwargs = { + "stage_index": pp_stage_idx, + "num_stages": len(virtual_pp_stages), + "device": device, + "input_args": ( + shape_inference_input_fn() + if pp_stage_idx == 0 + else shape_inference_input_fn_after_first_stage() + ), + "output_args": ( + shape_inference_output_fn_last_stage() + if pp_stage_idx == (len(virtual_pp_stages) - 1) + else shape_inference_input_fn_after_first_stage() + ), + "group": world_mesh.get_group("pp"), + } + stage = ( + LocalGraphPipelineStage(*args, **kwargs) + if run_local + else GraphPipelineStage(*args, **kwargs) + ) + + # NB: This is clearly a hack. 
The purpose of it is to override pp rank + # that the stage obtained from the process group. Stage computes peers to + # work with based on group rank. + if run_local: + stage.group_rank = pp_rank + + stages.append(stage) + + # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank + schedule = build_pipeline_schedule( + stages=stages, + loss_fn=None, + pipeline_parallel_schedule="Interleaved1F1B", + microbatch_size=microbatch_size, + local_batch_size=local_batch_size, + pipeline_parallel_degree=pp_degree, + backward_requires_autograd=False, + ) + + assert isinstance(schedule, _PipelineScheduleRuntime) + + # Step 6. Override the pipeline runner's action implementations + numerics_logs = [] + schedule.register_custom_function( + FORWARD, + functools.partial(stage_forward, numerics_logs=numerics_logs), + ) + schedule.register_custom_function(FULL_BACKWARD, stage_full_backward) + schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad) + schedule.register_custom_function(RESHARD, stage_reshard) + schedule.register_custom_function(UNSHARD, stage_unshard) + + # Step 7. Register the schedule with the graph runner + graph_pp_runner = GraphPPRunner(schedule) + + # Step 8. Run the whole pipeline once using the graph runner + with torch.no_grad(): + if pp_rank == 0: + x = runtime_input_fn() + graph_pp_runner.step(x) + else: + graph_pp_runner.step() + + if debug_numerics: + print_rank_by_rank("\n".join(numerics_logs)) + + # breakpoint() + if run_local: + with LocalRunnerMode( + world_size, + pp_degree, + run_pp_rank, + ): + pass + else: + pp_rank = world_mesh["pp"].get_local_rank() + run_pp_rank(pp_rank) + + print("All good!") + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.cuda.synchronize() + torch.distributed.destroy_process_group() + + +def local_tensor_mode_if_enabled( + ltm: LocalTensorMode | None = None, +) -> LocalTensorMode | None: + + for _ in range(2): + if ltm is not None and not ltm._disable: + return ltm + ltm = local_tensor_mode() + + return None + + +def maybe_make_tensor_local( + tensor: torch.Tensor, + ltm: LocalTensorMode | None = None, +) -> torch.Tensor: + ltm = local_tensor_mode_if_enabled(ltm) + if ltm is None: + return tensor + + if isinstance(tensor, LocalTensor): + return tensor + + if isinstance(tensor, DTensor): + tensor._local_tensor = maybe_make_tensor_local(tensor._local_tensor, ltm) + return tensor + + local_tensor = ltm.rank_map(lambda r: tensor.clone().detach()) + local_tensor.requires_grad = tensor.requires_grad + return local_tensor + + +def maybe_make_module_local( + module: nn.Module, + ltm: LocalTensorMode | None = None, +) -> None: + ltm = local_tensor_mode_if_enabled(ltm) + if ltm is None: + return + + for k, v in module.named_parameters(): + _assign_attr( + nn.Parameter( + data=maybe_make_tensor_local(v.data, ltm), + requires_grad=v.requires_grad, + ), + module, + k, + attr_kind=_AttrKind.PARAMETER, + ) + + for k, v in module.named_buffers(): + _assign_attr( + maybe_make_tensor_local(v, ltm), module, k, attr_kind=_AttrKind.BUFFER + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run DeepSeek V3 pipeline parallel example" + ) + parser.add_argument( + "--run-local", + action="store_true", + default=False, + help="Use local tensor mode (default: False)", + ) + parser.add_argument( + "--rng-seed", + type=int, + default=None, + help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).", + ) + args 
= parser.parse_args() + + if args.rng_seed is not None: + torch.use_deterministic_algorithms(True) + torch.manual_seed(args.rng_seed) + + run_test(run_local=args.run_local, debug_numerics=args.rng_seed is not None) + +# PYTHONPATH=. torchrun --standalone --nproc-per-node 8 examples/example_ds3_pp_local_tensor.py -- --rng-seed 1 +# PYTHONPATH=. torchrun --standalone --nproc-per-node 1 examples/example_ds3_pp_local_tensor.py -- --rng-seed 1 --run-local diff --git a/examples/run_ds3_numerics_check.py b/examples/run_ds3_numerics_check.py index 65075de..9249aad 100644 --- a/examples/run_ds3_numerics_check.py +++ b/examples/run_ds3_numerics_check.py @@ -6,6 +6,7 @@ """ Script to run DS3 numerics check by comparing outputs from local_map and pipeline parallel. """ +import os import shutil import subprocess import tempfile @@ -32,6 +33,11 @@ def main(args): # Create a temporary directory temp_dir = tempfile.mkdtemp(prefix="ds3_numerics_check_") print(f"Created temporary directory: {temp_dir}") + repo_dir = Path(__file__).parent.parent + if "PYTHONPATH" in os.environ: + os.environ["PYTHONPATH"] = os.environ["PYTHONPATH"] + f":{repo_dir}" + else: + os.environ["PYTHONPATH"] = f"{repo_dir}" try: examples_dir = Path(__file__).parent @@ -48,6 +54,12 @@ def main(args): cmd2 = f"torchrun --standalone --nproc-per-node 8 {examples_dir}/example_ds3_pp.py --rng-seed 42 --schedule-name={schedule_name}" run_command(cmd2, temp_dir) + print("\n" + "=" * 80) + print("Running PP example with Local Tensor...") + print("=" * 80) + cmd3 = f"torchrun --standalone --nproc-per-node 1 {examples_dir}/example_ds3_pp.py --rng-seed 42 --schedule-name={schedule_name} --run-mode=local_tensor" + run_command(cmd3, temp_dir) + out_dir = Path(temp_dir) / "out" if not out_dir.exists(): raise RuntimeError(f"Output directory {out_dir} does not exist") @@ -56,11 +68,13 @@ def main(args): print("Comparing weights.log files...") print("=" * 80) run_command("diff out/0/weights.log out/1/pp_weights.log", temp_dir) + run_command("diff out/0/weights.log out/2/pp_weights.log", temp_dir) print("\n" + "=" * 80) print("Comparing diff.log files...") print("=" * 80) run_command("diff out/0/diff.log out/1/diff.log", temp_dir) + run_command("diff out/0/diff.log out/2/diff.log", temp_dir) print("\n" + "=" * 80) print("Numerics check completed successfully!")
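Note on the single-process local-tensor path: the sketch below is a minimal, hypothetical driver showing how the pieces added in this patch are meant to fit together. It assumes the private torch.distributed._local_tensor APIs behave as they are used in examples/example_ds3_pp.py above (LocalTensorMode makes plain tensors carry one shard per simulated rank, and LocalRunnerMode(world_size, pp_degree, fn) invokes fn once per pp rank inside a single process backed by the fake process group); run_pp_rank here is only a stand-in for the per-rank body defined in the example.

import torch.distributed as dist
from torch.distributed._local_tensor import LocalRunnerMode, LocalTensorMode
from torch.testing._internal.distributed.fake_pg import FakeStore  # noqa: F401  imported for its side effect, as in the examples above

world_size, pp_degree = 8, 2

def run_pp_rank(pp_rank: int) -> None:
    # Stand-in for the per-rank body: everything inside LocalTensorMode sees
    # LocalTensors that hold one shard per simulated rank, so SPMD collectives
    # and the rewritten P2P ops can be emulated without extra processes.
    with LocalTensorMode(frozenset(range(world_size))):
        print(f"simulating pp rank {pp_rank}")

dist.init_process_group("fake", rank=0, world_size=world_size)  # single process
with LocalRunnerMode(world_size, pp_degree, run_pp_rank):
    pass  # the runner drives run_pp_rank for each pp rank
dist.destroy_process_group()

In the real example the per-rank body additionally builds LocalGraphPipelineStage objects and a pipeline schedule, and expand_p2p_ops reroutes the stage-to-stage isend/irecv calls through local_p2p_op so that all pipeline communication stays inside the one process.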
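For the 8-rank mesh used throughout these examples (pp=2, dp_mod_ep=2, ep=2), and assuming row-major rank ordering in the device mesh, the pp groups cached by cache_pp_groups / enumerate_pp_groups would come out as [[0, 4], [1, 5], [2, 6], [3, 7]]. A hypothetical worked illustration of get_pp_peer under that assumed layout:

# For a send from pp position 0 to pp position 1, get_pp_peer builds a per-rank
# map from each global rank at position 0 to its position-1 peer:
pp_groups = [[0, 4], [1, 5], [2, 6], [3, 7]]
peer_of = {group[0]: group[1] for group in pp_groups}
print(peer_of)  # {0: 4, 1: 5, 2: 6, 3: 7}
# get_pp_peer wraps this mapping in a LocalIntNode-backed SymInt, so when
# local_p2p_op runs under LocalTensorMode each simulated source rank resolves
# its own destination rank rather than a single scalar peer.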