import pickle

import torch
from torch import distributed as dist

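# Note: when running distributed, these helpers assume the caller has already
# initialized the default process group. A minimal sketch (launcher-dependent;
# `local_rank` would come from the launcher):
#
#     dist.init_process_group(backend='nccl', init_method='env://')
#     torch.cuda.set_device(local_rank)
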
def get_rank():
    """Return the rank of the current process, or 0 when not running distributed."""
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    """Synchronize (barrier) across all processes; a no-op when not distributed."""
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    """Return the number of processes, or 1 when not running distributed."""
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()

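# Minimal usage sketch for the helpers above (not part of the original file):
# the same script then runs with or without torch.distributed.
#
#     if get_rank() == 0:
#         print(f'running on {get_world_size()} process(es)')
#
#     synchronize()  # every rank waits here before continuing

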
def reduce_sum(tensor):
    """Sum `tensor` over all processes; returns the input unchanged when not distributed."""
    if not dist.is_available():
        return tensor

    if not dist.is_initialized():
        return tensor

    # Clone first so the caller's tensor is not overwritten in place.
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    return tensor

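# Usage sketch for reduce_sum, assuming a CUDA tensor and an initialized
# process group (`n_correct` is a hypothetical per-rank count):
#
#     correct = torch.tensor([n_correct], device='cuda')
#     total_correct = reduce_sum(correct)  # global sum, available on every rank

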
def all_gather(data):
    """Gather an arbitrary picklable object from every process into a list."""
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    # Serialize the object into a CUDA byte tensor; the NCCL backend can only
    # communicate CUDA tensors.
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    # Exchange byte counts so every rank knows the largest payload.
    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # all_gather requires equally shaped tensors, so pad every payload up to
    # the largest size before gathering.
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    # Strip the padding from each payload and deserialize.
    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list

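# Usage sketch for all_gather: any picklable object works, for example a dict
# of per-rank statistics (`n_samples` is a hypothetical name):
#
#     stats = {'rank': get_rank(), 'samples': n_samples}
#     all_stats = all_gather(stats)  # list with one dict per rank, same on every rank

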
def reduce_loss_dict(loss_dict):
    """Reduce a dict of scalar loss tensors across processes.

    After the call, rank 0 holds the losses averaged over all processes; the
    values on other ranks are not averaged and should not be logged.
    """
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        # Iterate in sorted key order so every rank stacks the losses the
        # same way.
        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses
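
# Usage sketch for reduce_loss_dict in a training loop (`d_loss` and `g_loss`
# are hypothetical scalar CUDA tensors); only rank 0 should log the averages:
#
#     loss_reduced = reduce_loss_dict({'d': d_loss, 'g': g_loss})
#
#     if get_rank() == 0:
#         print({k: v.item() for k, v in loss_reduced.items()})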