Commit a8ee556

checkpointing/PGTransport: add NCCL/Gloo transport for checkpoints
1 parent ca8c540 commit a8ee556

5 files changed: +463 additions, -13 deletions

torchft/checkpointing/http_transport_test.py

Lines changed: 32 additions & 11 deletions
@@ -6,26 +6,23 @@
 
 import urllib.error
 from datetime import timedelta
-from typing import Any, Dict
-from unittest import TestCase
+from typing import Dict
+from unittest import TestCase, skipUnless
 from unittest.mock import MagicMock
 
 import torch
 from parameterized import parameterized
 
 from torchft.checkpointing.http_transport import HTTPTransport
 from torchft.checkpointing.http_transport_bench import main as bench_main
+from torchft.checkpointing.transport import CheckpointTransport
+from torchft.checkpointing.transport_test import (
+    assertStateDictEqual,
+    run_multi_recovery_test,
+)
 
 
 class TestHTTPTransport(TestCase):
-    def assertStateDictEqual(self, a: Dict[str, object], b: Dict[str, object]) -> None:
-        for k, v1 in a.items():
-            v2 = b[k]
-            if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
-                torch.testing.assert_close(v1.cpu(), v2.cpu())
-            else:
-                self.assertEqual(v1, v2)
-
     @parameterized.expand(
         [
             ("no chunks", 0),
@@ -59,7 +56,7 @@ def test_checkpoint_server(self, name: str, num_chunks: int) -> None:
         out = server.recv_checkpoint(
             src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
         )
-        self.assertStateDictEqual(out, expected)
+        assertStateDictEqual(self, out, expected)
 
         # test timeout
         with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"):
@@ -114,6 +111,30 @@ def test_checkpoint_server_locking(self) -> None:
 
         server.shutdown()
 
+    def test_multi_http_transport_cpu(self) -> None:
+        device = torch.device("cpu")
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
+            return HTTPTransport(
+                timeout=timedelta(seconds=10),
+                num_chunks=0,
+            )
+
+        run_multi_recovery_test(self, init, device=device)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
+    @skipUnless(torch.cuda.is_available(), "CUDA is not available")
+    def test_multi_http_transport_cuda(self) -> None:
+        device = torch.device("cuda")
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
+            return HTTPTransport(
+                timeout=timedelta(seconds=10),
+                num_chunks=0,
+            )
+
+        run_multi_recovery_test(self, init, device=device)
+
     def test_benchmark(self) -> None:
         bench_main(
             [
torchft/checkpointing/pg_transport.py (new file)

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
import logging
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Generator, List, Tuple, TypeVar, cast

import torch
from torch.distributed import Work
from torch.distributed.tensor import DTensor
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchft.checkpointing.transport import CheckpointTransport
from torchft.process_group import ProcessGroup

logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T")


@dataclass
class _TensorMeta:
    shape: torch.Size
    dtype: torch.dtype
    storage_offset: int
    stride: Tuple[int, ...]
    nbytes: int


@dataclass
class _DTensorMeta:
    local: _TensorMeta
    spec: object


@dataclass
class _StateDictMeta:
    step: int
    spec: object
    non_tensors: List[object]
    tensor_metas: List[_TensorMeta]


@contextmanager
def _timeit(name: str) -> Generator[None, None, None]:
    start = time.perf_counter()
    yield
    dur = time.perf_counter() - start
    logger.info(f"{name} took {dur}s")


def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
    return (
        _cast_tensor(tensor, torch.uint8),
        _TensorMeta(
            shape=tensor.shape,
            dtype=tensor.dtype,
            storage_offset=cast(int, tensor.storage_offset()),
            stride=tensor.stride(),
            nbytes=tensor.untyped_storage().nbytes(),
        ),
    )


def _prepare_state_dict(
    state_dict: object,
    step: int,
    device: torch.device,
) -> Tuple[_StateDictMeta, List[torch.Tensor]]:
    start = time.perf_counter()
    values, spec = tree_flatten(state_dict)

    non_tensors = []
    tensors = []
    tensor_metas = []
    for v in values:
        if isinstance(v, DTensor):
            tensor, tensor_meta = _prepare_tensor(v._local_tensor)

            tensor_metas.append(tensor_meta)
            tensors.append(tensor)

            non_tensors.append(
                _DTensorMeta(
                    local=tensor_meta,
                    spec=v._spec,
                )
            )
        elif isinstance(v, torch.Tensor):
            tensor, tensor_meta = _prepare_tensor(v)
            tensors.append(tensor)
            non_tensors.append(tensor_meta)
            tensor_metas.append(tensor_meta)
        else:
            non_tensors.append(v)

    total_size = sum(t.nbytes for t in tensors)

    dur = time.perf_counter() - start
    logger.info(
        f"prepared state_dict {total_size=} {len(tensors)=} {len(non_tensors)=} in {dur}s"
    )

    return (
        _StateDictMeta(
            step=step,
            spec=spec,
            non_tensors=non_tensors,
            tensor_metas=tensor_metas,
        ),
        tensors,
    )


def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    storage = tensor.untyped_storage()
    ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
    assert ret.untyped_storage() is storage, "storage should be the same"
    return ret


class PGTransport(CheckpointTransport[T]):
    """
    This is a checkpoint transport that uses the process group to transfer checkpoints.
    This allows for fast recovery of workers by fetching the current weights
    from an existing worker.

    Args:
        pg: the process group used to send/receive checkpoints
        timeout: the default timeout for communication operations
        device: the device used to stage tensors for communication
    """

    def __init__(
        self, pg: ProcessGroup, timeout: timedelta, device: torch.device
    ) -> None:
        self._work: List[Work] = []
        self._pg = pg
        self._timeout = timeout
        self._device = device

    def metadata(self) -> str:
        return "<n/a>"

    def disallow_checkpoint(self) -> None:
        pass

    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)

        work = []

        with _timeit("send pickle"):
            buf = pickle.dumps(meta)
            len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device)
            buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device)
            for dst_rank in dst_ranks:
                work.append(self._pg.send([len_t], dst_rank, tag=1))
                work.append(self._pg.send([buf_t], dst_rank, tag=2))

        with _timeit("send tensors"):
            for i, t in enumerate(tensors):
                t = t.to(self._device)
                for dst_rank in dst_ranks:
                    work.append(self._pg.send([t], dst_rank, tag=3 + i))

                # allow 3 concurrent transfers at a time
                while len(work) > (3 * len(dst_ranks)):
                    work.pop(0).wait(timeout)

            for w in work:
                w.wait(timeout)

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
        self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
        length = cast(int, len_t.item())

        assert length > 0, f"invalid metadata length {length=}"

        buf = torch.empty(length, dtype=torch.uint8, device=self._device)
        self._pg.recv([buf], src_rank, tag=2).wait(timeout)

        meta = pickle.loads(buf.cpu().numpy().tobytes())
        assert meta.step == step

        i: int = 0

        def recv(v: _TensorMeta) -> torch.Tensor:
            nonlocal i

            t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
            # TODO: parallelize receives
            self._pg.recv([t], src_rank, tag=3 + i).wait(timeout)
            i += 1

            # TODO: allow in place receives to avoid having to copy to cpu to
            # avoid OOMs
            t = t.cpu()

            return torch.as_strided(
                _cast_tensor(t, v.dtype),
                size=v.shape,
                stride=v.stride,
                storage_offset=v.storage_offset,
            )

        values = []
        for v in meta.non_tensors:
            if isinstance(v, _TensorMeta):
                values.append(recv(v))
            elif isinstance(v, _DTensorMeta):
                tensor = recv(v.local)
                # pyre-fixme[29]: DTensor is not a function
                values.append(DTensor(tensor, v.spec, requires_grad=False))
            else:
                values.append(v)

        return tree_unflatten(values, meta.spec)
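
The wire protocol implemented above is: the sender pickles a _StateDictMeta describing the flattened state dict, sends its length on tag 1 and the pickled bytes on tag 2, then streams each tensor's raw uint8 storage on tag 3 + i; the receiver reverses the process and rebuilds the original structure with tree_unflatten. The sketch below is not part of this commit: it is one hypothetical way to drive PGTransport between two Gloo ranks from threads in a single process (similar to the test in the next file), and the state dict contents and thread setup are illustrative assumptions.

# Hypothetical usage sketch (not from this commit): rank 0 sends its weights,
# rank 1 recovers them via PGTransport over a Gloo process group.
from datetime import timedelta
from threading import Thread
from typing import Dict

import torch
from torch.distributed import TCPStore

from torchft.checkpointing.pg_transport import PGTransport
from torchft.process_group import ProcessGroupGloo

store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
device = torch.device("cpu")


def make_transport(rank: int, world_size: int) -> PGTransport[Dict[str, object]]:
    # Same wiring as PGTransportTest.test_pg_transport_gloo below.
    pg = ProcessGroupGloo()
    pg.configure(
        store_addr=f"localhost:{store.port}/prefix",
        rank=rank,
        world_size=world_size,
    )
    return PGTransport[Dict[str, object]](
        pg, timeout=timedelta(seconds=10), device=device
    )


def sender() -> None:
    transport = make_transport(rank=0, world_size=2)
    state_dict = {"weight": torch.randn(4, 4), "step_count": 10}  # illustrative
    # Sends metadata length (tag 1), pickled metadata (tag 2), then raw tensor
    # bytes (tags 3 + i) to every destination rank.
    transport.send_checkpoint(
        dst_ranks=[1], step=10, state_dict=state_dict, timeout=timedelta(seconds=10)
    )


def receiver() -> None:
    transport = make_transport(rank=1, world_size=2)
    # The metadata argument is unused by PGTransport (metadata() returns "<n/a>").
    recovered = transport.recv_checkpoint(
        src_rank=0, metadata="<n/a>", step=10, timeout=timedelta(seconds=10)
    )
    print(recovered["weight"].shape)  # torch.Size([4, 4])


threads = [Thread(target=sender), Thread(target=receiver)]
for t in threads:
    t.start()
for t in threads:
    t.join()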
torchft/checkpointing/pg_transport_test.py (new file)

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
from datetime import timedelta
from typing import Dict
from unittest import TestCase, skipUnless

import torch
from torch.distributed import TCPStore

from torchft.checkpointing.pg_transport import PGTransport
from torchft.checkpointing.transport import CheckpointTransport
from torchft.checkpointing.transport_test import run_multi_recovery_test
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo


class PGTransportTest(TestCase):
    def test_pg_transport_gloo(self) -> None:
        store: TCPStore = TCPStore(
            host_name="localhost", port=0, is_master=True, wait_for_workers=False
        )
        device: torch.device = torch.device("cpu")

        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
            pg = ProcessGroupGloo()
            pg.configure(
                store_addr=f"localhost:{store.port}/prefix",
                rank=rank,
                world_size=world_size,
            )

            return PGTransport[Dict[str, object]](
                pg, timeout=timedelta(seconds=10), device=device
            )

        run_multi_recovery_test(self, init, device=device)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
    def test_pg_transport_baby_nccl(self) -> None:
        store: TCPStore = TCPStore(
            host_name="localhost", port=0, is_master=True, wait_for_workers=False
        )
        device: torch.device = torch.device("cuda")

        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
            torch.cuda.set_device(rank)

            pg = ProcessGroupBabyNCCL()
            pg.configure(
                store_addr=f"localhost:{store.port}/prefix",
                rank=rank,
                world_size=world_size,
            )

            return PGTransport[Dict[str, object]](
                pg, timeout=timedelta(seconds=10), device=device
            )

        run_multi_recovery_test(self, init, device=device)
