From 05202f34992c4b6fb0d769b2f177166e58136ba1 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Thu, 13 Feb 2025 13:00:52 -0800 Subject: [PATCH] checkpointing/PGTransport: add NCCL/Gloo transport for checkpoints --- torchft/checkpointing/http_transport_test.py | 43 +++- torchft/checkpointing/pg_transport.py | 247 +++++++++++++++++++ torchft/checkpointing/pg_transport_test.py | 57 +++++ torchft/checkpointing/transport_test.py | 148 +++++++++++ torchft/process_group.py | 8 +- 5 files changed, 490 insertions(+), 13 deletions(-) create mode 100644 torchft/checkpointing/pg_transport.py create mode 100644 torchft/checkpointing/pg_transport_test.py create mode 100644 torchft/checkpointing/transport_test.py diff --git a/torchft/checkpointing/http_transport_test.py b/torchft/checkpointing/http_transport_test.py index 6c297730..00d379e0 100644 --- a/torchft/checkpointing/http_transport_test.py +++ b/torchft/checkpointing/http_transport_test.py @@ -6,8 +6,8 @@ import urllib.error from datetime import timedelta -from typing import Any, Dict -from unittest import TestCase +from typing import Dict +from unittest import TestCase, skipUnless from unittest.mock import MagicMock import torch @@ -15,17 +15,14 @@ from torchft.checkpointing.http_transport import HTTPTransport from torchft.checkpointing.http_transport_bench import main as bench_main +from torchft.checkpointing.transport import CheckpointTransport +from torchft.checkpointing.transport_test import ( + assertStateDictEqual, + run_multi_recovery_test, +) class TestHTTPTransport(TestCase): - def assertStateDictEqual(self, a: Dict[str, object], b: Dict[str, object]) -> None: - for k, v1 in a.items(): - v2 = b[k] - if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): - torch.testing.assert_close(v1.cpu(), v2.cpu()) - else: - self.assertEqual(v1, v2) - @parameterized.expand( [ ("no chunks", 0), @@ -59,7 +56,7 @@ def test_checkpoint_server(self, name: str, num_chunks: int) -> None: out = server.recv_checkpoint( src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10) ) - self.assertStateDictEqual(out, expected) + assertStateDictEqual(self, out, expected) # test timeout with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"): @@ -114,6 +111,30 @@ def test_checkpoint_server_locking(self) -> None: server.shutdown() + def test_multi_http_transport_cpu(self) -> None: + device = torch.device("cpu") + + def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]: + return HTTPTransport( + timeout=timedelta(seconds=10), + num_chunks=0, + ) + + run_multi_recovery_test(self, init, device=device) + + # pyre-fixme[56]: Pyre was not able to infer the type of the decorator + @skipUnless(torch.cuda.is_available(), "CUDA is not available") + def test_multi_http_transport_cuda(self) -> None: + device = torch.device("cuda") + + def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]: + return HTTPTransport( + timeout=timedelta(seconds=10), + num_chunks=0, + ) + + run_multi_recovery_test(self, init, device=device) + def test_benchmark(self) -> None: bench_main( [ diff --git a/torchft/checkpointing/pg_transport.py b/torchft/checkpointing/pg_transport.py new file mode 100644 index 00000000..cd07771f --- /dev/null +++ b/torchft/checkpointing/pg_transport.py @@ -0,0 +1,247 @@ +import logging +import pickle +import time +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import Generator, List, Tuple, TypeVar, Union, cast + 
+import torch +from torch.distributed import Work +from torch.distributed.tensor import DTensor, _DTensorSpec +from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten + +from torchft.checkpointing.transport import CheckpointTransport +from torchft.process_group import ProcessGroup + +logger: logging.Logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass +class _TensorMeta: + """ + This is the metadata for a tensor that is used to transfer checkpoints. + It contains the shape, the dtype, the storage offset and the stride of the + tensor. + + This must be pickleable so that it can be sent over the wire. + """ + + shape: torch.Size + dtype: torch.dtype + storage_offset: int + stride: Tuple[int, ...] + nbytes: int + + +@dataclass +class _DTensorMeta: + """ + This is the metadata for a DTensor that is used to transfer checkpoints. + It contains the metadata for the local tensor and the spec of the DTensor. + + This must be pickleable so that it can be sent over the wire. + """ + + local: _TensorMeta + spec: _DTensorSpec + + +@dataclass +class _StateDictMeta: + """ + This is the metadata for a state dict that is used to transfer checkpoints. + It contains the step, the pytree spec of the state dict and the metadata for + each tensor in the state dict. + + This must be pickleable so that it can be sent over the wire. + + Args: + step: the step of the checkpoint to verify consistency + treespec: the pytree spec of the state dict + non_tensor_leaves: the metadata for each tensor in the state dict and any + non-tensor leaves in the state dict + """ + + step: int + treespec: TreeSpec + non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]] + + +@contextmanager +def _timeit(name: str) -> Generator[None, None, None]: + start = time.perf_counter() + yield + dur = time.perf_counter() - start + logger.info(f"{name} took {dur}s") + + +def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]: + return ( + _cast_tensor(tensor, torch.uint8), + _TensorMeta( + shape=tensor.shape, + dtype=tensor.dtype, + storage_offset=cast(int, tensor.storage_offset()), + stride=tensor.stride(), + nbytes=tensor.untyped_storage().nbytes(), + ), + ) + + +def _prepare_state_dict( + state_dict: object, + step: int, + device: torch.device, +) -> Tuple[_StateDictMeta, List[torch.Tensor]]: + leaves, treespec = tree_flatten(state_dict) + + non_tensor_leaves = [] + tensors = [] + for v in leaves: + if isinstance(v, DTensor): + tensor, tensor_meta = _prepare_tensor(v._local_tensor) + + tensors.append(tensor) + + non_tensor_leaves.append( + _DTensorMeta( + local=tensor_meta, + spec=v._spec, + ) + ) + elif isinstance(v, torch.Tensor): + tensor, tensor_meta = _prepare_tensor(v) + tensors.append(tensor) + non_tensor_leaves.append(tensor_meta) + else: + non_tensor_leaves.append(v) + + return ( + _StateDictMeta( + step=step, + treespec=treespec, + non_tensor_leaves=non_tensor_leaves, + ), + tensors, + ) + + +def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """ + Casts the underlying storage to a tensor of the given dtype. + + The returned tensor will be of size ``storage.nbytes``. + + This works for all datatypes and supports strided/offset tensors with the + caveat that the cast tensor may be larger than the original tensor due to + the differences in striding. 
+ """ + storage = tensor.untyped_storage() + ret = torch.tensor(storage, dtype=dtype, device=tensor.device) + assert ret.untyped_storage() is storage, "storage should be the same" + return ret + + +class PGTransport(CheckpointTransport[T]): + """ + This is a checkpoint transport that uses the process group to transfer checkpoints. + This allows for fast recovery of workers by fetching the current weights + from an existing worker. + Args: + state_dict: a callable that returns the state dict to be transferred + """ + + def __init__( + self, pg: ProcessGroup, timeout: timedelta, device: torch.device + ) -> None: + self._work: List[Work] = [] + self._pg = pg + self._timeout = timeout + self._device = device + + def metadata(self) -> str: + return "" + + def disallow_checkpoint(self) -> None: + pass + + def send_checkpoint( + self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta + ) -> None: + with _timeit("preparing state_dict"): + meta, tensors = _prepare_state_dict(state_dict, step, device=self._device) + + work = [] + + with _timeit("send pickle"): + buf = pickle.dumps(meta) + len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device) + buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device) + for dst_rank in dst_ranks: + work.append(self._pg.send([len_t], dst_rank, tag=1)) + work.append(self._pg.send([buf_t], dst_rank, tag=2)) + + with _timeit("send tensors"): + for i, t in enumerate(tensors): + t = t.to(self._device) + for dst_rank in dst_ranks: + work.append(self._pg.send([t], dst_rank, tag=3 + i)) + + # allow 3 concurrent transfers at a time to avoid OOMs + while len(work) > (3 * len(dst_ranks)): + work.pop(0).wait(timeout) + + for w in work: + w.wait(timeout) + + def recv_checkpoint( + self, src_rank: int, metadata: str, step: int, timeout: timedelta + ) -> T: + len_t = torch.zeros(1, dtype=torch.int64, device=self._device) + self._pg.recv([len_t], src_rank, tag=1).wait(timeout) + length = cast(int, len_t.item()) + + assert length > 0, f"invalid metadata length {length=}" + + buf = torch.empty(length, dtype=torch.uint8, device=self._device) + self._pg.recv([buf], src_rank, tag=2).wait(timeout) + + meta: _StateDictMeta = pickle.loads(buf.cpu().numpy().tobytes()) + assert meta.step == step + + i: int = 0 + + def recv(v: _TensorMeta) -> torch.Tensor: + nonlocal i + + t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device) + # TODO: parallelize receives + self._pg.recv([t], src_rank, tag=3 + i).wait(timeout) + i += 1 + + # TODO: allow in place receives to avoid having to copy to cpu to + # avoid OOMs + t = t.cpu() + + return torch.as_strided( + t.view(v.dtype), + size=v.shape, + stride=v.stride, + storage_offset=v.storage_offset, + ) + + values = [] + for v in meta.non_tensor_leaves: + if isinstance(v, _TensorMeta): + values.append(recv(v)) + elif isinstance(v, _DTensorMeta): + tensor = recv(v.local) + # pyre-fixme[29]: DTensor is not a function + values.append(DTensor(tensor, v.spec, requires_grad=False)) + else: + values.append(v) + + return tree_unflatten(values, meta.treespec) diff --git a/torchft/checkpointing/pg_transport_test.py b/torchft/checkpointing/pg_transport_test.py new file mode 100644 index 00000000..a7b9c123 --- /dev/null +++ b/torchft/checkpointing/pg_transport_test.py @@ -0,0 +1,57 @@ +from datetime import timedelta +from typing import Dict +from unittest import TestCase, skipUnless + +import torch +from torch.distributed import TCPStore + +from torchft.checkpointing.pg_transport import PGTransport +from 
torchft.checkpointing.transport import CheckpointTransport +from torchft.checkpointing.transport_test import run_multi_recovery_test +from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo + + +class PGTransportTest(TestCase): + def test_pg_transport_gloo(self) -> None: + store: TCPStore = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + device: torch.device = torch.device("cpu") + + def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]: + pg = ProcessGroupGloo() + pg.configure( + store_addr=f"localhost:{store.port}/prefix", + rank=rank, + world_size=world_size, + ) + + return PGTransport[Dict[str, object]]( + pg, timeout=timedelta(seconds=10), device=device + ) + + run_multi_recovery_test(self, init, device=device) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices") + def test_pg_transport_baby_nccl(self) -> None: + store: TCPStore = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + device: torch.device = torch.device("cuda") + + def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]: + torch.cuda.set_device(rank) + + pg = ProcessGroupBabyNCCL() + pg.configure( + store_addr=f"localhost:{store.port}/prefix", + rank=rank, + world_size=world_size, + ) + + return PGTransport[Dict[str, object]]( + pg, timeout=timedelta(seconds=10), device=device + ) + + run_multi_recovery_test(self, init, device=device) diff --git a/torchft/checkpointing/transport_test.py b/torchft/checkpointing/transport_test.py new file mode 100644 index 00000000..5601db6b --- /dev/null +++ b/torchft/checkpointing/transport_test.py @@ -0,0 +1,148 @@ +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import timedelta +from typing import Callable, Dict, List +from unittest import TestCase + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DeviceMesh, DTensor, distribute_tensor + +from torchft.checkpointing.transport import CheckpointTransport + +TIMEOUT_REGEX = r"(Timed out|timed out|timeout|time out)" + + +def assertStateDictEqual( + self: TestCase, a: Dict[str, object], b: Dict[str, object] +) -> None: + for k, v1 in a.items(): + v2 = b[k] + if isinstance(v1, DTensor) and isinstance(v2, DTensor): + torch.testing.assert_close(v1._local_tensor, v2._local_tensor) + self.assertEqual(v1._spec, v2._spec) + elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): + torch.testing.assert_close(v1.cpu(), v2.cpu()) + else: + self.assertEqual(v1, v2) + + +def run_multi_recovery_test( + self: TestCase, + init_transport: Callable[[int, int], CheckpointTransport[Dict[str, object]]], + device: torch.device, +) -> None: + """ + This runs multi node recovery tests for a given transport function. + + This tests send/recv in a 3 node setup, with all and some workers recovering + and also tests timeout behavior. 
+ """ + WORLD_SIZE: int = 3 + + # barrier is used to simulate quorum/allreduce barriers + barrier: threading.Barrier = threading.Barrier(WORLD_SIZE) + metadata: str = "" + + dist.init_process_group( + backend="gloo", rank=0, world_size=1, store=dist.HashStore() + ) + + device_mesh = DeviceMesh("cpu", 1) + tensor = torch.randn(4, 4) + dtensor: DTensor = distribute_tensor(tensor, device_mesh, []) + + def run(rank: int) -> CheckpointTransport[Dict[str, object]]: + transport = init_transport(rank, WORLD_SIZE) + + if rank == 0: + nonlocal metadata + metadata = transport.metadata() + + barrier.wait() + + state_dict: Dict[str, object] = { + "rank": torch.tensor([1, 2, 3], device=device), + "str": "str", + "int": 1234, + "dtensor": dtensor, + } + + # 3 node recovery + if rank == 0: + transport.send_checkpoint( + dst_ranks=[1, 2], + step=1, + state_dict=state_dict, + timeout=timedelta(seconds=10), + ) + else: + got = transport.recv_checkpoint( + src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=10) + ) + assertStateDictEqual(self, got, state_dict) + + barrier.wait() + transport.disallow_checkpoint() + + # 2 node recovery + if rank == 0: + transport.send_checkpoint( + dst_ranks=[2], + step=2, + state_dict=state_dict, + timeout=timedelta(seconds=10), + ) + elif rank == 2: + got = transport.recv_checkpoint( + src_rank=0, metadata=metadata, step=2, timeout=timedelta(seconds=10) + ) + assertStateDictEqual(self, got, state_dict) + + barrier.wait() + transport.disallow_checkpoint() + + # timeout test + if rank == 2: + with self.assertRaisesRegex(Exception, TIMEOUT_REGEX): + transport.recv_checkpoint( + src_rank=0, + metadata=metadata, + step=3, + timeout=timedelta(milliseconds=10), + ) + + # Make sure send completes quickly. + # If the transport is async (such as with HTTP) this may just return + # immediately. + try: + transport.send_checkpoint( + dst_ranks=[0], + step=4, + state_dict=state_dict, + timeout=timedelta(seconds=10), + ) + except Exception: + with self.assertRaisesRegex(Exception, TIMEOUT_REGEX): + raise + + return transport + + with ThreadPoolExecutor(max_workers=WORLD_SIZE) as executor: + results = [] + for i in range(WORLD_SIZE): + results.append(executor.submit(run, i)) + + transports = [] + + try: + for fut in as_completed(results, timeout=10.0): + transports.append(fut.result()) + except Exception as e: + print(e) + raise + + for transport in transports: + transport.shutdown() + + dist.destroy_process_group() diff --git a/torchft/process_group.py b/torchft/process_group.py index c1b67f0c..6e368f81 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -966,13 +966,17 @@ def _worker( next_op_id += 1 elif cmd == "wait": op_id: int = op[1] + timeout: Optional[timedelta] = op[2] metadata = work[op_id] with metadata.set_stream(): # With WorkNCCL this makes the stream wait not the CPU when # no timeout is passed. - metadata.work.wait() + if timeout is not None: + metadata.work.wait(timeout) + else: + metadata.work.wait() # Register event on the stream that we can pass to the main # process. @@ -1051,7 +1055,7 @@ def _wait(self, op_id: int, timeout: Optional[timedelta] = None) -> bool: self._assert_alive() assert self._tx is not None - self._tx.put(("wait", op_id), timeout=self._timeout) + self._tx.put(("wait", op_id, timeout), timeout=self._timeout) assert self._rx is not None op_id, event = cast(