import torch
import torch.cuda.comm as comm


def get_class_split(num_classes, num_gpus):
    """Evenly split `num_classes` classes across `num_gpus` gpus.

    The first `num_classes % num_gpus` gpus each take one extra class.
    """
    class_split = []
    for i in range(num_gpus):
        _class_num = num_classes // num_gpus
        if i < (num_classes % num_gpus):
            _class_num += 1
        class_split.append(_class_num)
    return class_split
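
# Worked example (illustrative values): with 1000 classes on 3 gpus,
# 1000 // 3 == 333 and 1000 % 3 == 1, so the first gpu takes one extra class
# and get_class_split(1000, 3) returns [334, 333, 333].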


def get_onehot_label(labels, num_gpus, num_classes, model_parallel=False, class_split=None):
    """Convert integer labels to a dense one-hot matrix.

    If `model_parallel` is True, the one-hot matrix is split column-wise
    according to `class_split` and scattered to the first `num_gpus` gpus.
    """
    labels = labels.view(-1, 1)
    labels_onehot = torch.zeros(len(labels), num_classes).cuda()
    labels_onehot.scatter_(1, labels, 1)
    if not model_parallel:
        return labels_onehot
    else:
        label_tuple = comm.scatter(labels_onehot, range(num_gpus), class_split, dim=1)
        return label_tuple
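
# Usage sketch (illustrative, assumes at least 2 visible gpus): with
#   labels = torch.tensor([0, 3]).cuda()
#   parts = get_onehot_label(labels, num_gpus=2, num_classes=4,
#                            model_parallel=True, class_split=[2, 2])
# parts[0] holds one-hot columns 0-1 on gpu 0 and parts[1] holds columns 2-3 on gpu 1.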


def get_sparse_onehot_label(labels, num_gpus, num_classes, model_parallel=False, class_split=None):
    '''Given the number of classes and the number of gpus, build the one-hot label sparse tensor.

    If `model_parallel` is True, the one-hot label sparse tensor is split
    column-wise and distributed to each gpu according to `class_split`.

    Args:
        labels (torch.Tensor): generated by the pytorch dataloader, shape (batch_size,)
        num_gpus (int): number of gpus
        num_classes (int): total number of classes
        model_parallel (bool, optional): whether model parallelism is used
        class_split (list, optional): list of ints whose length must equal `num_gpus`

    Returns:
        torch.sparse.LongTensor,
        or a tuple of torch.sparse.LongTensor if `model_parallel` is True
    '''
    labels_list = labels.tolist()
    batch_size = len(labels_list)
    if not model_parallel:
        sparse_index = torch.LongTensor([list(range(batch_size)), labels_list])
        sparse_value = torch.ones(batch_size, dtype=torch.long)
        labels_onehot = torch.sparse.LongTensor(sparse_index, sparse_value, torch.Size([batch_size, num_classes]))
        return labels_onehot
    else:
        assert num_gpus == len(class_split), "number of class splits does not equal the number of gpus!"
        # prepare a dict describing each column split for building the sparse tensors
        splits_dict = {}
        start_index = 0
        for i, num_splits in enumerate(class_split):
            end_index = start_index + num_splits
            splits_dict[i] = {
                "start_index": start_index,
                "end_index": end_index,
                "num_splits": num_splits,
                "index_list": [],
                "nums": 0
            }
            start_index = end_index
        # collect the valid (row, local column) indices that fall into each split
        for i, label in enumerate(labels_list):
            for j in range(num_gpus):
                if label >= splits_dict[j]["start_index"] and label < splits_dict[j]["end_index"]:
                    valid_index = [i, label - splits_dict[j]["start_index"]]
                    splits_dict[j]["index_list"].append(valid_index)
                    splits_dict[j]["nums"] += 1
                    break
        # finally build one sparse tensor per gpu
        label_tuple = []
        for i in range(num_gpus):
            if splits_dict[i]["nums"] == 0:
                sparse_tensor = torch.sparse.LongTensor(torch.Size([batch_size, splits_dict[i]["num_splits"]]))
                label_tuple.append(sparse_tensor.to(i))
            else:
                sparse_index = torch.LongTensor(splits_dict[i]["index_list"])
                sparse_value = torch.ones(splits_dict[i]["nums"], dtype=torch.long)
                sparse_tensor = torch.sparse.LongTensor(
                    sparse_index.t(),
                    sparse_value,
                    torch.Size([batch_size, splits_dict[i]["num_splits"]])
                )
                label_tuple.append(sparse_tensor.to(i))
        return tuple(label_tuple)
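
# Worked example of the index bookkeeping above (illustrative values): with
# labels = [5, 2], num_classes = 6 and class_split = [3, 3], label 2 falls in
# split 0 (columns 0-2) at local column 2, giving split 0 the sparse entry
# (row 1, col 2), while label 5 falls in split 1 (columns 3-5) at local column
# 5 - 3 = 2, giving split 1 the entry (row 0, col 2).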


def compute_batch_acc(outputs, labels, batch_size, model_parallel, step):
    '''Compute the batch accuracy according to the predictions and ground-truth labels.

    The complex case is when `model_parallel` is True: the prediction logits are
    located on different gpus, and to avoid concatenating them (which would
    increase gpu memory usage) we collect the per-split maxima one by one.

    Args:
        outputs (torch.Tensor or tuple of torch.Tensor): if `model_parallel` is False,
            `outputs` is a single torch.Tensor; if `model_parallel` is True, `outputs`
            is a tuple of torch.Tensor located on different gpus
        labels (torch.Tensor): generated by the pytorch dataloader
        batch_size (int): batch size
        model_parallel (bool): model parallel flag
        step (int): training step in each iteration

    Returns:
        accuracy (float)
    '''
    if model_parallel:
        # only evaluate every 10 steps to save time
        if not (step > 0 and step % 10 == 0):
            return 0
        outputs = [outputs]
        max_score = None
        max_preds = None
        base = 0
        for logit_same_tuple in zip(*outputs):
            _split = logit_same_tuple[0].size()[1]
            score, preds = torch.max(sum(logit_same_tuple).data, dim=1)
            score = score.to(0)
            preds = preds.to(0)
            if max_score is not None:
                cond = score > max_score
                max_preds = torch.where(cond, preds + base, max_preds)
                max_score = torch.where(cond, score, max_score)
            else:
                max_score = score
                max_preds = preds
            base += _split
        preds = max_preds
        batch_acc = torch.sum(preds == labels).item() / batch_size
    else:
        _, preds = torch.max(outputs.data, 1)
        batch_acc = torch.sum(preds == labels).item() / batch_size
    return batch_acc
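
# Sketch of the running-max logic above (illustrative values): with per-gpu
# logits [[0.1, 0.9]] on split 0 and [[0.8, 0.2]] on split 1, split 0 proposes
# class 1 with score 0.9 and split 1 proposes class 0 + base 2 == class 2 with
# score 0.8, so the global prediction stays class 1.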


if __name__ == "__main__":
    # import os
    # os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"
    # labels = torch.tensor([5, 2, 3, 4, 6, 9, 7, 1]).cuda()
    # label_tuple = get_onehot_label(labels, 4, 12, True, [3, 3, 3, 3])
    # for label in label_tuple:
    #     print(label.size())
    #     print(label)
    labels = torch.tensor([5, 2, 3, 4, 6, 9, 7, 1])
    print(labels)
    label_tuple = get_sparse_onehot_label(labels, 4, 12, True, [3, 3, 3, 3])
    for label in label_tuple:
        print(label.size())
        print(label.to_dense())
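
    # Hypothetical compute_batch_acc call (commented out because it needs 4 gpus;
    # the random logits here are only illustrative):
    # outputs = tuple(torch.randn(8, n).to(i) for i, n in enumerate([3, 3, 3, 3]))
    # acc = compute_batch_acc(outputs, labels.to(0), batch_size=8,
    #                         model_parallel=True, step=10)
    # print(acc)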