Skip to content

Commit 677b91e

Browse files
committed
Initial implementations
1 parent 4708d72 commit 677b91e

File tree

8 files changed

+1121
-0
lines changed

8 files changed

+1121
-0
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
wandb/
132+
*.lmdb/

checkpoint/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.pt

dataset.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from io import BytesIO
2+
3+
import lmdb
4+
from PIL import Image
5+
from torch.utils.data import Dataset
6+
7+
8+
class MultiResolutionDataset(Dataset):
    """Image dataset backed by an LMDB file.

    The LMDB is expected to contain a ``'length'`` entry (utf-8 encoded
    integer) plus one encoded image per key of the form
    ``'<resolution>-<index zero-padded to 5 digits>'``.
    """

    def __init__(self, path, transform, resolution=256):
        # NOTE(review): the env is opened eagerly here; with
        # DataLoader(num_workers > 0) the handle is inherited by forked
        # workers — confirm this is intended for the target platform.
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Keys are '<resolution>-<index:05d>', matching the dataset
        # preparation script's naming scheme.
        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')

        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)

        img = Image.open(BytesIO(img_bytes))
        return self.transform(img)

distributed.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import math
2+
import pickle
3+
4+
import torch
5+
from torch import distributed as dist
6+
from torch.utils.data.sampler import Sampler
7+
8+
9+
def get_rank():
    """Return this process's distributed rank, or 0 when torch.distributed
    is unavailable or the process group has not been initialized."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()

    return 0
17+
18+
19+
def synchronize():
    """Block until all processes reach this point.

    A no-op when torch.distributed is unavailable, uninitialized, or the
    world consists of a single process.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return

    if dist.get_world_size() > 1:
        dist.barrier()
32+
33+
34+
def get_world_size():
    """Return the number of distributed processes, or 1 when
    torch.distributed is unavailable or uninitialized."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()

    return 1
42+
43+
44+
def reduce_sum(tensor):
    """Sum *tensor* across all processes and return the result.

    The input is cloned before the all-reduce so the caller's tensor is
    not mutated. Returns the input unchanged when torch.distributed is
    unavailable or uninitialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)

    return summed
55+
56+
57+
def all_gather(data):
    """Gather an arbitrary picklable object from every process.

    Each rank's object is pickled to bytes, moved to CUDA, padded to the
    maximum byte length across ranks, exchanged with ``dist.all_gather``,
    and unpickled. Returns a list with one entry per rank; with a single
    process, returns ``[data]`` without any communication.
    """
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    # Serialize to a CUDA byte tensor (dist.all_gather needs tensors).
    serialized = pickle.dumps(data)
    payload = torch.ByteTensor(torch.ByteStorage.from_buffer(serialized)).to('cuda')

    # Exchange byte lengths so every rank can pad to the common maximum.
    local_size = torch.IntTensor([payload.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(s.item()) for s in size_list]
    max_size = max(size_list)

    # all_gather requires equal-sized tensors; pad the local payload.
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        payload = torch.cat((payload, padding), 0)

    gathered = [torch.ByteTensor(size=(max_size,)).to('cuda') for _ in size_list]
    dist.all_gather(gathered, payload)

    # Strip each rank's padding before unpickling.
    results = []
    for size, tensor in zip(size_list, gathered):
        raw = tensor.cpu().numpy().tobytes()[:size]
        results.append(pickle.loads(raw))

    return results
90+
91+
92+
def reduce_loss_dict(loss_dict):
    """Average a dict of scalar loss tensors across processes.

    Keys are processed in sorted order so every rank stacks losses in the
    same sequence. The reduction targets rank 0 (``dist.reduce``), and
    only rank 0 divides by the world size — so only rank 0 holds the true
    averages; use its values for logging. Returns the input dict
    unchanged when the world size is below 2.
    """
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        names = sorted(loss_dict.keys())
        stacked = torch.stack([loss_dict[k] for k in names], 0)

        dist.reduce(stacked, dst=0)

        if dist.get_rank() == 0:
            stacked /= world_size

        return dict(zip(names, stacked))

0 commit comments

Comments
 (0)