import collections
import contextlib
from dataclasses import dataclass
import enum
import inspect
import io
import itertools
import threading
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import warnings
import torch
from torch.distributed.distributed_c10d import _get_default_group
try:
import apex.contrib.nccl_allocator as nccl_allocator
except ImportError:
nccl_allocator = None
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
import distributed_adam_cuda
# Fallback to private functions if using PyTorch <1.13.0
try:
from torch.distributed.distributed_c10d import get_global_rank
except ImportError:
from torch.distributed.distributed_c10d import _get_global_rank
get_global_rank = _get_global_rank
try:
from torch.distributed.distributed_c10d import reduce_scatter_tensor
except ImportError:
from torch.distributed.distributed_c10d import _reduce_scatter_base
reduce_scatter_tensor = _reduce_scatter_base
try:
from torch.distributed.distributed_c10d import all_gather_into_tensor
except ImportError:
from torch.distributed.distributed_c10d import _all_gather_base
all_gather_into_tensor = _all_gather_base
# Import context manager to coalesce NCCL calls
# Note: Replace these backward compatibility shims once PyTorch
# exposes a stable public API for coalescing communication.
from torch.distributed.distributed_c10d import _coalescing_manager
if "device" not in inspect.signature(_coalescing_manager).parameters:
# PyTorch <=1.13.1 does not have device arg
_coalescing_manager_no_device_arg = _coalescing_manager
@contextlib.contextmanager
def _coalescing_manager(group, device, reqs):
with _coalescing_manager_no_device_arg(group, reqs):
yield
if "reqs" in inspect.signature(_coalescing_manager).parameters:
# PyTorch <=2.0.1 handles synchronization externally to coalescing
# manager
_coalescing_manager_with_reqs_arg = _coalescing_manager
class _CoalescingManager:
def __init__(self):
self.works: List[torch.distributed.Work] = []
def append(self, work: torch.distributed.Work) -> None:
if work:
self.works.append(work)
def wait(self) -> None:
for work in self.works:
work.wait()
@contextlib.contextmanager
def _coalescing_manager(
group: Optional[torch.distributed.ProcessGroup] = None,
device: Optional[torch.device] = None,
async_ops: bool = False,
) -> contextlib.AbstractContextManager:
assert device is not None
cm = _CoalescingManager()
with _coalescing_manager_with_reqs_arg(
group,
device,
cm.works,
):
yield cm
if not async_ops:
cm.wait()
def _coalescing_manager_append_work(
cm: _CoalescingManager,
work: torch.distributed.Work,
) -> None:
"""Add asynchronous request to coalescing manager"""
cm.append(work)
else:
# PyTorch >2.0.1 handles synchronization within coalescing
# manager
def _coalescing_manager_append_work(
cm: torch.distributed._CoalescingManager,
work: torch.distributed.Work,
) -> None:
"""Dummy function for backward compatibility
Coalescing manager already keeps track of asynchronous
communication.
"""
pass
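# Regardless of PyTorch version, the shims above are used with the same
# pattern (a sketch of the call sequence; see _broadcast_params below for a
# real use in this file):
#
#   with _coalescing_manager(group, device, async_ops=True) as cm:
#       for tensor in tensors:
#           _coalescing_manager_append_work(
#               cm,
#               torch.distributed.broadcast(
#                   tensor, src=0, group=group, async_op=True,
#               ),
#           )
#   cm.wait()
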
# Import optional CUDA kernels
_FOUND_DEPRECATED_FUSED_ADAM: bool = False
try:
import fused_adam_cuda
_FOUND_DEPRECATED_FUSED_ADAM = True
except ImportError:
warnings.warn(
"Could not find recommended CUDA kernels when importing "
"`DistributedFusedAdam`. "
"For best performance, Apex should be installed with "
"`--deprecated_fused_adam`."
)
def _round_to_multiple(
number: int,
multiple: int,
round_up: bool = True,
) -> int:
"""Assumes arguments are positive integers"""
return (number + multiple - 1 if round_up else number) // multiple * multiple
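# Example behavior of _round_to_multiple (illustrative values):
#   _round_to_multiple(100, 32)                  -> 128  (round up to multiple)
#   _round_to_multiple(100, 32, round_up=False)  -> 96   (round down to multiple)
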
def _devices_match(device1: torch.device, device2: torch.device) -> bool:
"""Whether two PyTorch devices are equivalent"""
device1 = torch.device(device1)
device2 = torch.device(device2)
if device1.type != device2.type:
return False
if device1.type == "cuda":
index1 = device1.index
index2 = device2.index
if index1 is None:
index1 = torch.cuda.current_device()
if index2 is None:
index2 = torch.cuda.current_device()
if index1 != index2:
return False
return True
def _multi_tensor_copy(
buffers_in: List[torch.Tensor],
buffers_out: List[torch.Tensor],
dummy_overflow_buf: Optional[torch.Tensor] = None,
) -> None:
"""Copy between corresponding buffers
Uses fused copy kernel if possible.
"""
# Group buffers by device and dtype
buffer_groups = collections.defaultdict(list)
for buf_in, buf_out in zip(buffers_in, buffers_out):
if buf_in.data_ptr() == buf_out.data_ptr() or buf_in.numel() == 0:
# Nothing to be done if input and output buffers are same
# or have no entries
continue
if buf_in.dtype == buf_out.dtype:
# Just copy bytes if dtypes are same
buf_in = buf_in.view(torch.uint8)
buf_out = buf_out.view(torch.uint8)
is_cuda = (
_devices_match(buf_in.device, "cuda")
and _devices_match(buf_out.device, "cuda")
)
is_contiguous = buf_in.is_contiguous() and buf_out.is_contiguous()
key = (
buf_in.dtype,
buf_out.dtype,
is_cuda,
is_contiguous,
)
buffer_groups[key].append((buf_in, buf_out))
# Copy each group of buffers
for key, buffers in buffer_groups.items():
# Check if buffers support fused kernel
dtype_in, dtype_out, is_cuda, is_contiguous = key
supported_dtypes = (torch.float32, torch.float16)
use_fused_kernel = (
dtype_in in supported_dtypes and dtype_out in supported_dtypes
) or (dtype_in == torch.uint8 and dtype_out == torch.uint8)
use_fused_kernel = use_fused_kernel and is_cuda and is_contiguous
# Copy buffers
if use_fused_kernel and _FOUND_DEPRECATED_FUSED_ADAM:
if dummy_overflow_buf is None:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device="cuda")
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
dummy_overflow_buf,
list(zip(*buffers)),
)
else:
            # Note: dummy_overflow_buf is not updated in this fallback path
for buf_in, buf_out in buffers:
buf_out.copy_(buf_in)
@contextlib.contextmanager
def _disable_pre_forward_hook(
param: torch.nn.Parameter,
) -> contextlib.AbstractContextManager:
"""Prevent parameter from calling pre-forward hook"""
hook_is_enabled = getattr(
param,
"_pre_forward_hook_is_enabled",
False,
)
if hook_is_enabled:
param._pre_forward_hook_is_enabled = False
try:
yield
finally:
if hook_is_enabled:
param._pre_forward_hook_is_enabled = True
@torch.no_grad()
def _bf16_rem_to_fp32(
bf16: torch.Tensor,
rem: torch.Tensor,
fp32: torch.Tensor,
) -> None:
"""Pack BF16 tensor and 16-bit remainders into FP32 tensor"""
# Check inputs
assert bf16.size() == rem.size() == fp32.size(), (
"Tensor dimensions do not match: "
f"bf16={list(bf16.size())}, "
f"rem={list(rem.size())}, "
f"fp32={list(fp32.size())}, "
)
assert bf16.dtype is torch.bfloat16, f"bf16 buffer has invalid dtype ({bf16.dtype})"
assert rem.dtype is torch.int16, f"rem buffer has invalid dtype ({rem.dtype})"
assert fp32.dtype is torch.float32, f"fp32 buffer has invalid dtype ({fp32.dtype})"
# Undo bf16 rounding
bf16 = bf16.view(torch.int16) - torch.where(rem < 0, 1, 0)
# Pack bf16 and remainder into little-endian fp32
fp32 = fp32.unsqueeze(-1).view(torch.int16)
fp32 = torch.stack((rem, bf16), dim=-1, out=fp32)
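# Conceptually, for each element the packing above is (on a little-endian
# system, which the stack into the int16 view of fp32 assumes):
#
#   fp32_bits = (bf16_bits << 16) | (rem & 0xFFFF)
#
# i.e. the bf16 value supplies the high 16 bits (after undoing any
# round-to-nearest increment, signalled by a negative remainder) and rem
# supplies the discarded low 16 bits, so the original fp32 value is
# recovered exactly.
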
class DistributedFusedAdam(torch.optim.Optimizer):
"""Adam optimizer with ZeRO algorithm.
Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext --distributed_adam --deprecated_fused_adam``.
This implements the ZeRO-2 algorithm, which distributes the
optimizer state and gradients between parallel processes. In
particular, the parameters are flattened, grouped into fixed-size
buckets, and the optimizer state for each bucket is sharded over
the parallel processes. Options are provided to overlap the
gradient synchronization with the backward pass compute.
Adam was proposed in `Adam: A Method for Stochastic
Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,
and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion
Parameter Models`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts
defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
bias_correction (bool, optional): apply correction factor to
moment estimates. (default: True)
betas (Tuple[float, float], optional): coefficients used for
computing running averages of gradient and its square.
(default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
adam_w_mode (boolean, optional): Decouple weight decay
regularization (also known as AdamW algorithm) (default:
True)
weight_decay (float, optional): weight decay (L2 penalty)
(default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad
variant of this algorithm from the paper
`On the Convergence of Adam and Beyond`_ (default: False).
This is not yet supported.
dtype (torch.dtype, optional): datatype for optimizer state
(default: torch.float32)
grad_sync_dtype (torch.dtype, optional): datatype for gradient
synchronization (default: same as dtype)
param_sync_dtype (torch.dtype, optional): datatype for
parameter synchronization (default: same as dtype)
device (torch.device, optional): device for optimizer state
(default: cuda). Currently only supports GPU with one GPU
per process.
process_group (torch.distributed.ProcessGroup, optional):
parallel processes participating in optimizer (default:
default group in torch.distributed). This group is
interpreted as a 2D grid with dimensions
distributed_size x redundant_size.
distributed_process_group (torch.distributed.ProcessGroup,
optional): parallel processes to distribute optimizer
state over (default: same as process_group)
redundant_process_group (torch.distributed.ProcessGroup,
optional): parallel processes to replicate optimizer state
over (default: group only containing calling process)
average_grad_sync (bool, optional): whether to use average
reduction for gradient synchronization rather than sum
(default: True)
overlap_grad_sync (boolean, optional): whether to overlap
gradient synchronization with backward pass compute
(default: True)
overlap_param_sync (boolean, optional): whether to overlap
parameter synchronization with forward pass compute
(default: False). This is an experimental feature.
bucket_cap_mb (float, optional): bucket size in megabytes
(default: 100)
pipeline_size (int, optional): number of buckets to process
simultaneously in optimizer step (default: 2)
contiguous_param_buffer (bool, optional): convert parameters
into views into large persistent buffers (default: False).
This enables some performance optimizations (e.g. avoiding
some memory copies), but may add memory overhead (e.g. if
the memory allocator can't reuse the original parameter
buffers).
contiguous_grad_buffer (bool, optional): allocate gradient
            buckets out of a large persistent buffer (default:
False). This allows individual parameter gradients to be
accessed externally (see grad_buffer_view function). It
enables some performance optimizations (e.g. avoiding some
memory copies), but prevents some memory optimizations
(e.g. the memory allocator can't reuse buffers for
gradient buckets).
store_params (bool, optional): store a distributed copy of the
parameters as optimizer state (default: True). This may be
desirable if the optimizer dtype has higher precision than
the parameter dtype.
store_param_remainders (bool, optional): if model is BF16 and
optimizer is FP32, store bits required to reconstruct FP32
params (default: False). This is an experimental feature.
with_scaled_states (bool, optional): apply per-tensor scaling
factors to the optimizer state (default: False). As
discussed in `FP8-LM: Training FP8 Large Language
Models`_, this helps maintain a reasonable dynamic range
even when the state is in a low-precision datatype like
FP16.
        nccl_ub (bool, optional): enable NCCL user buffers for zero-copy
            communication (default: False). This allows collectives to
            use only one SM when IB SHARP is enabled in a
            one-rank-per-node communication group, which helps speed up
            GEMMs that overlap with data-parallel communication.
capturable (bool, optional): whether to use the version of the
optimizer that can be used with CUDA Graphs. (default: False).
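
    Example (a minimal usage sketch, assuming ``torch.distributed`` is
    already initialized with a NCCL backend, ``model`` is a CUDA module,
    and ``inputs`` is a CUDA tensor; the optimizer otherwise follows the
    standard ``torch.optim`` interface)::

        optimizer = DistributedFusedAdam(
            model.parameters(),
            lr=1e-3,
            overlap_grad_sync=True,
            bucket_cap_mb=100,
        )
        loss = model(inputs).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
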
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
.. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
.. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models:
https://arxiv.org/abs/1910.02054
.. _FP8-LM\: Training FP8 Large Language Models:
https://arxiv.org/pdf/2310.18313v2.pdf
"""
@dataclass
class ParameterFragment:
"""Buffer ranges for a parameter fragment
Describes corresponding regions in parameter buffer and
parameter bucket.
"""
# Parameter group index
param_group_id: int
# Parameter index within parameter group
param_id: int
# Bucket index
bucket_id: int
# Range within flattened parameter buffer
param_range: Tuple[int, int]
# Range within bucket
bucket_range: Tuple[int, int]
# Whether fragment is in local shard of bucket
in_local_shard: bool
# Range within local shard
shard_range: Optional[Tuple[int, int]]
# Range of local fragment shard within bucket
shard_bucket_range: Optional[Tuple[int, int]]
# Range of local fragment shard within parameter
shard_param_range: Optional[Tuple[int, int]]
class StateBucket:
"""Optimizer state for a bucket"""
def __init__(
self,
bucket_size: int,
shard_size: int,
dtype: torch.dtype,
device: torch.device,
grad_sync_dtype: torch.dtype,
param_sync_dtype: torch.dtype,
contiguous_buffer_offset: int = 0,
store_params: bool = False,
store_param_remainders: bool = False,
):
# Size of parameter bucket
self.bucket_size: int = bucket_size
# Size of local shard of parameter bucket
self.shard_size: int = shard_size
# Data type for state
self.dtype = dtype
# Data type for gradient synchronization
self.grad_sync_dtype = grad_sync_dtype
# Data type for parameter synchronization
self.param_sync_dtype = param_sync_dtype
# Size of the filled region in the bucket
self.filled_size: int = 0
            # Whether the bucket can still accept new parameter fragments
self.able_to_fill: bool = True
# Offset to bucket in contiguous buffers
self.contiguous_buffer_offset: int = contiguous_buffer_offset
# Buffer ranges corresponding to parameter fragments
self.fragments: List[ParameterFragment] = []
# Local shard of parameters
self.params_shard: Optional[torch.Tensor] = None
if store_params:
self.params_shard = torch.zeros(
[shard_size],
dtype=self.dtype,
device=device,
)
# Local shard of parameter remainders
self.param_remainders_shard: Optional[torch.Tensor] = None
if store_param_remainders:
self.param_remainders_shard = torch.zeros(
[shard_size],
dtype=torch.int16,
device=device,
)
# Local shard of first moment estimate
self.exp_avg_shard: torch.Tensor = torch.zeros(
[shard_size],
dtype=self.dtype,
device=device,
)
# Local shard of second moment estimate
self.exp_avg_sq_shard: torch.Tensor = torch.zeros(
[shard_size],
dtype=self.dtype,
device=device,
)
def dtypes(self) -> Tuple[torch.dtype, torch.dtype, torch.dtype]:
"""Datatypes for the bucket's compute and communication"""
return (
self.dtype,
self.grad_sync_dtype,
self.param_sync_dtype,
)
class GradientStatus(enum.Enum):
"""Status of gradients within a bucket"""
# Gradients are ready to use
READY = enum.auto()
# Bucket is partially filled with unreduced gradients
PARTIALLY_FILLED = enum.auto()
# Bucket is fully filled with unreduced gradients
FULLY_FILLED = enum.auto()
# Asynchronous reduction is in progress
SYNCING = enum.auto()
class GradientBucket:
"""Gradient buffers and state for a bucket"""
def __init__(self):
# Local shard of gradients
self.grads_shard: Optional[torch.Tensor] = None
# Local contribution to gradients
self.grads_bucket: Optional[torch.Tensor] = None
# Buffer for gradient reduce-scatter
self.sync_grads_shard: Optional[torch.Tensor] = None
# Status of gradients
self.status: GradientStatus = DistributedFusedAdam.GradientStatus.READY
# Params that have generated grads
self.grads_generated: Set[torch.nn.Parameter] = set()
class ParameterStatus(enum.Enum):
"""Status of parameters within a bucket"""
# Parameters are sharded between processes
SHARDED = enum.auto()
# Asynchronous communication is in progress
SYNCING = enum.auto()
# Parameters are ready to use
READY = enum.auto()
class ParameterBucket:
"""Parameter buffers and state for a bucket"""
def __init__(self):
# Local shard of parameters
self.params_shard: Optional[torch.Tensor] = None
# Gathered parameter values
self.params_bucket: Optional[torch.Tensor] = None
# Status of parameters
self.status: ParameterStatus = DistributedFusedAdam.ParameterStatus.SHARDED
# Params that have been updated
self.params_updated: Set[torch.nn.Parameter] = set()
# Enable custom logic for AMP grad scaling
_step_supports_amp_scaling: bool = True
_custom_amp_unscale_grads: bool = True
def __init__(
self,
params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
lr: float = 1e-3,
bias_correction: bool = True,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
adam_w_mode: bool = True,
weight_decay: float = 0.0,
amsgrad: bool = False,
dtype: torch.dtype = torch.float32,
grad_sync_dtype: Optional[torch.dtype] = None,
param_sync_dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = "cuda",
process_group: Optional[torch.distributed.ProcessGroup] = None,
distributed_process_group: Optional[torch.distributed.ProcessGroup] = None,
redundant_process_group: Optional[torch.distributed.ProcessGroup] = None,
average_grad_sync: bool = True,
overlap_grad_sync: bool = True,
overlap_param_sync: bool = False,
bucket_cap_mb: float = 100.0,
pipeline_size: int = 2,
contiguous_param_buffer: bool = False,
contiguous_grad_buffer: bool = False,
store_params: bool = True,
store_param_remainders: bool = False,
with_scaled_states: bool = False,
nccl_ub: bool = False,
capturable: bool = False,
):
if (with_scaled_states or store_param_remainders) and capturable:
raise Exception(f"{self.__class__.__name__} with scaled states "
"or storing param remainders doesn't support CUDA graph yet.")
if capturable and not _FOUND_DEPRECATED_FUSED_ADAM:
raise Exception(f"Capturable {self.__class__.__name__} relies on "
"multi_tensor_copy to set dummy_overflow_buf to indicate "
"whether there's gradient Inf/NaN, build APEX with "
"`--deprecated_fused_adam` is essential.")
# If capturable for CUDA graph
self.capturable: bool = capturable
# If the optimizer is capturable then LR should be a tensor (on GPU)
if capturable:
lr = torch.tensor(lr, dtype=torch.float32, device=device)
defaults = dict(
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)
super().__init__(params, defaults)
# Adam options
self.adam_w_mode: bool = adam_w_mode
self.amsgrad: bool = amsgrad
if amsgrad:
raise RuntimeError(
"DistributedFusedAdam does not support the AMSGrad variant."
)
# Datatype options
if grad_sync_dtype is None:
grad_sync_dtype = dtype
if param_sync_dtype is None:
param_sync_dtype = dtype
supported_dtypes = (torch.float32, torch.float16, torch.bfloat16)
if (
dtype not in supported_dtypes
or grad_sync_dtype not in supported_dtypes
):
raise ValueError(
"Unsupported dtypes for DistributedFusedAdam "
f"(dtype={dtype}, "
f"grad_sync_dtype={grad_sync_dtype}, "
f"param_sync_dtype={param_sync_dtype}))"
)
self.dtype: torch.dtype = dtype
self.grad_sync_dtype: torch.dtype = grad_sync_dtype
self.param_sync_dtype: torch.dtype = param_sync_dtype
# Device options
if not _devices_match(device, "cuda"):
raise RuntimeError(
"Invalid device for DistributedFusedAdam " f"(device={device})"
)
self.device: torch.device = torch.device("cuda", torch.cuda.current_device())
# Process groups
self.process_group: torch.distributed.ProcessGroup = (
_get_default_group() if process_group is None else process_group
)
self.distributed_process_group: torch.distributed.ProcessGroup = (
self.process_group
if distributed_process_group is None
else distributed_process_group
)
self.redundant_process_group: Optional[
torch.distributed.ProcessGroup
] = redundant_process_group
self.process_group_size: int = torch.distributed.get_world_size(
self.process_group
)
self.distributed_rank: int = torch.distributed.get_rank(
self.distributed_process_group
)
self.distributed_size: int = torch.distributed.get_world_size(
self.distributed_process_group
)
self.redundant_size: int = (
1
if self.redundant_process_group is None
else torch.distributed.get_world_size(self.redundant_process_group)
)
if self.process_group_size != self.distributed_size * self.redundant_size:
raise RuntimeError(
"Invalid process group configuration "
f"(process group size = {self.process_group_size}, "
f"distributed process group size = {self.distributed_size}, "
f"redundant process group size = {self.redundant_size})"
)
self.process_group_root: int = get_global_rank(self.process_group, 0)
# Use average reduction for grad sync
self.average_grad_sync: bool = average_grad_sync
# Copy param grads to bucket as soon as available
self.greedy_grad_copy: bool = True
# Synchronize grad buckets as soon as their grads are available
self.overlap_grad_sync: bool = overlap_grad_sync
# Try synchronizing param buckets just before param is needed
self.overlap_param_sync: bool = overlap_param_sync
# Number of buckets to synchronize at a time
self.pipeline_size: int = pipeline_size
# Store params or param remainders
if store_param_remainders:
if store_params:
raise RuntimeError(
"Attempted to construct DistributedFusedAdam "
"with store_params=True and store_param_remainders=True"
)
if self.dtype != torch.float32 or self.param_sync_dtype != torch.bfloat16:
raise RuntimeError(
"DistributedFusedAdam requires "
"BF16 params and FP32 optimizer state "
"when storing parameter remainders "
f"(dtype={self.dtype}, "
f"param_sync_dtype={self.param_sync_dtype}))"
)
self.store_params: bool = store_params
self.store_param_remainders: bool = store_param_remainders
# Whether to scale optimizer state
self.with_scaled_states: bool = with_scaled_states
if self.with_scaled_states:
if not self.store_params:
raise RuntimeError(
"Attempted to construct DistributedFusedAdam "
"with with_scaled_state=True and store_params=False"
)
if self.store_param_remainders:
raise RuntimeError(
"Attempted to construct DistributedFusedAdam "
"with with_scaled_state=True and store_params_remainders=True"
)
if self.dtype not in (torch.float16, torch.bfloat16):
raise RuntimeError(
"Attempted to construct DistributedFusedAdam "
f"with with_scaled_state=True and dtype={self.dtype} "
"(only fp16 and bf16 are supported)"
)
if self.param_sync_dtype == torch.float32:
# _local_step_with_scaled_states applies Adam kernel
# to fp32 workspace buffer and relies on
# _check_params_shard_dtypes to copy to param sync
# workspace buffer. However,
# _check_params_shard_dtypes does nothing if
# param_sync_dtype is fp32.
raise RuntimeError(
"Attempted to construct DistributedFusedAdam "
f"with with_scaled_state=True and param_sync_dtype={self.param_sync_dtype}"
)
# Scaling factors to apply to recover unscaled optimizer state
self._state_scales: dict = {}
# Determine bucket sizes
dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
self.alignment: int = 128 // dtype_size
self.bucket_cap_mb: float = bucket_cap_mb
bucket_size = 1024 * 1024 * bucket_cap_mb / dtype_size
shard_size = int(bucket_size / self.distributed_size)
shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False)
shard_size = max(shard_size, self.alignment)
self.default_shard_size: int = shard_size
# Optimizer state
self.state["buckets"]: List[StateBucket] = []
self.state["step"]: torch.Tensor | int = torch.tensor([0], dtype=torch.int,
device=self.device) if self.capturable else 0
# Gradient state
self._grads_buckets: Dict[int, GradientBucket] = collections.defaultdict(
self.GradientBucket
)
# Param state
self._params_buckets: Dict[int, ParameterBucket] = collections.OrderedDict()
# Whether to allocate contiguous buffers for parameters
self.contiguous_param_buffer: bool = contiguous_param_buffer
# Whether to allocate contiguous buffers for gradients
self.contiguous_grad_buffer: bool = contiguous_grad_buffer
# Whether to use NCCL User Buffer
self.nccl_ub: bool = nccl_ub
# Contiguous buffers for parameters
self._param_buffers: Dict[
Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor
] = {}
# Contiguous buffers for gradients
self._grad_buffers: Dict[
Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor
] = {}
# Output buffer for gradient shards, only required for NCCL user buffer
if self.nccl_ub:
if not nccl_allocator:
raise RuntimeError("NCCL allocator importing failed but nccl ub is still requested")
elif not self.contiguous_grad_buffer:
raise RuntimeError("NCCL user buffers require contiguous grad buffers")
else:
self._shard_grad_buffers: Dict[
Tuple[torch.dtype, torch.dtype, torch.dtype], torch.Tensor
] = {}
# Side streams for state dict communication
self._pipeline_streams: List[torch.cuda.Stream] = [
torch.cuda.Stream() for _ in range(self.pipeline_size)
]
# Side streams for gradients and parameters communication
self._comm_streams: List[torch.cuda.Stream] = [
torch.cuda.Stream() for _ in range(self.pipeline_size)
]
self._last_comm_stream_id: int = -1
# Scale by factor before optimizer step. Used for grad
# clipping and gradient scaler.
self._grad_scale: torch.Tensor = torch.full(
[], 1.0, dtype=torch.float32, device=self.device
)
# Norm of parameter gradients. Used for gradient clipping and
# gradient scaler.
self._grad_norm: Optional[torch.Tensor] = None
# Dummy flag for multi-tensor kernels
# Note: Apex multi-tensor kernels have a noop_flag argument
# that is intended to detect non-finite values. It shouldn't
# have any effect with the kernels used in the optimizer, but
# we still set it to zero out of an abundance of caution.
self._dummy_overflow_buf: torch.Tensor = torch.zeros(
[1], dtype=torch.int32, device=self.device
)
# Check if collectives have no_copy option
self._gather_no_copy: bool = (
"no_copy" in inspect.getfullargspec(torch.distributed.gather).args
)
# Make sure parameter values are same across processes
self._broadcast_params()
# Lock for callbacks
self._lock: threading.Lock = threading.Lock()
# Attach hooks for gradient synchronization
self._register_post_backward_hooks()
# Attach hooks for param synchronization
if self.overlap_param_sync:
self._register_pre_forward_hooks()
# Move LR to device
if capturable:
for idx, group in enumerate(self.param_groups):
if len(group['params']) == 0:
continue
for item in ['lr']:
if torch.is_tensor(group[item]):
self.param_groups[idx][item] = group[item].to(device=self.device)
else:
self.param_groups[idx][item] = torch.tensor(group[item],
device=self.device)
        # Record constructor arguments to build a more informative repr string
arg_names = inspect.getfullargspec(DistributedFusedAdam.__init__).args
arg_names.remove('self')
arg_names.remove('params')
for i, group in enumerate(self.param_groups):
for key in sorted(group.keys()):
if key in arg_names:
arg_names.remove(key)
self.args_dict = {name: getattr(self, name) for name in arg_names}
def __repr__(self) -> str:
# Based on: https://github.com/pytorch/pytorch/blob/v2.3.0-rc12/torch/optim/optimizer.py#L315
format_string = self.__class__.__name__ + ' ('
for i, group in enumerate(self.param_groups):
format_string += '\n'
format_string += f'Parameter Group {i}\n'
for key in sorted(group.keys()):
if key != 'params':
format_string += f' {key}: {group[key]}\n'
for key, val in self.args_dict.items():
if 'process_group' in key and val:
format_string += f'{key}: {hex(id(val))}, world size {val.size()}\n'
else:
format_string += f'{key}: {val}\n'
format_string += ')'
return format_string
@torch.no_grad()
def _broadcast_params(self) -> None:
"""Broadcast parameter values from root rank"""
process_group = self.process_group
with _coalescing_manager(process_group, self.device, async_ops=True) as cm:
for param_group in self.param_groups:
for param in param_group["params"]:
_coalescing_manager_append_work(
cm,
torch.distributed.broadcast(
param,
src=self.process_group_root,
group=process_group,
async_op=True,
),
)
cm.wait()
def _make_post_backward_hook(
self,
param: torch.nn.Parameter,
param_group_id: int,
param_id: int,
) -> Callable:
"""Create callback function to call after param generates grad
Lazily initialize parameter and try launching grad sync.
"""
def post_backward_hook(*unused) -> None:
if getattr(param, "_pre_forward_hook_is_enabled", False):
raise RuntimeError(
"A parameter called its post-backward hook "
"before its pre-forward hook. "
"Please manually interact with the parameter "
"before the forward pass (e.g. by calling data_ptr) "
"or run DistributedFusedAdam with overlap_param_sync=False."
)
with self._lock:
need_to_initialize = "fragments" not in self.state[param]
if need_to_initialize:
self._init_param_state(param, param_group_id, param_id)
if self.greedy_grad_copy:
self._grad_copy(param)
if self.overlap_grad_sync:
self._try_start_bucket_grad_sync(
params=[param],
ignore_last_bucket=need_to_initialize,
)
return post_backward_hook
def _register_post_backward_hooks(self) -> None:
"""Attach hooks for gradient synchronization"""
self._grad_accs = []
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group["params"]):
if param.requires_grad:
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
hook = self._make_post_backward_hook(
param,
param_group_id,
param_id,
)
grad_acc.register_hook(hook)
self._grad_accs.append(grad_acc)
def _make_pre_forward_hook(
self,
param: torch.nn.Parameter,
param_group_id: int,
param_id: int,
) -> Callable:
"""Create callback function to call before param forward pass
Make sure param has been synchronized and try launching next
param sync.
"""
def pre_forward_hook(*unused) -> None:
with self._lock:
if "fragments" not in self.state[param]:
return
self._param_copy(param)
if self.overlap_param_sync:
self._try_start_bucket_param_sync()
return pre_forward_hook
def _register_pre_forward_hooks(self) -> None:
"""Attach hooks for parameter synchronization
If _pre_forward_hook_is_enabled is set in a parameter, then
the callback will be called the first time any of its
attributes are accessed. This is hackily done by
monkey-patching the parameter class, so proceed with caution.
"""
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group["params"]):
# Monkey-patch parameter class
cls = param.__class__
if not getattr(cls, "_has_pre_forward_hook", False):
# Monkey-patch magic methods to call __getattribute__
special_funcs = [
"__abs__",
"__add__",
"__and__",
"__bool__",
"__complex__",
"__contains__",
"__deepcopy__",
"__delitem__",
"__div__",
"__eq__",
"__float__",
"__floordiv__",
"__ge__",
"__getitem__",
"__gt__",
"__iadd__",
"__iand__",
"__idiv__",
"__ifloordiv__",
"__ilshift__",
"__imod__",
"__imul__",
"__index__",
"__int__",
"__invert__",
"__ior__",
"__ipow__",
"__irshift__",
"__isub__",
"__iter__",
"__itruediv__",