-
Notifications
You must be signed in to change notification settings - Fork 23
/
dataloader.py
47 lines (42 loc) · 1.43 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import math
import torch
import numpy as np
class BatchDataloader:
    """Iterate over row-aligned sequences in fixed-size, mask-filtered batches.

    The row range ``[start_idx, end_idx)`` spanned by the selected rows of
    ``mask`` is cut into consecutive windows of ``bs`` rows, starting at the
    first selected row. Windows that contain no selected row are skipped. From
    every other window only the selected rows are kept, and one
    ``torch.float32`` tensor is returned per input sequence.

    Args:
        *tensors: Sliceable sequences (NumPy arrays, lists, ...). They are
            sliced with the same row indices as ``mask``, so they are assumed
            to share its leading axis.
        bs: Window size in rows; must be >= 1.
        mask: 1-D row selector, interpreted by truthiness (booleans or a 0/1
            integer array both work). ``None`` selects every row of the first
            tensor.

    Attributes:
        start: Row cursor of the current iteration.
        sum: Number of selected rows returned since the last ``__iter__``.

    Raises:
        ValueError: If ``bs`` < 1, if ``mask`` is not 1-D, or if ``mask`` is
            ``None`` and no tensors were given.

    Note:
        ``__iter__`` returns ``self``, so nested or concurrent iterations over
        the same loader share (and reset) one cursor.
    """

    def __init__(self, *tensors, bs=1, mask=None):
        if bs < 1:
            # bs == 0 would never advance the cursor and loop forever.
            raise ValueError(f"bs must be a positive integer, got {bs!r}")
        if mask is None:
            if not tensors:
                raise ValueError("mask=None requires at least one tensor")
            # No mask given: use every row of the (shared) leading axis.
            mask = np.ones(len(tensors[0]), dtype=bool)
        # Normalise to bool. An integer 0/1 mask would otherwise be treated as
        # fancy *indices* by ``b[batch_mask]`` below and select the wrong rows.
        mask = np.asarray(mask, dtype=bool)
        if mask.ndim != 1:
            raise ValueError(f"mask must be 1-D, got shape {mask.shape}")

        self.tensors = tensors
        self.batch_size = bs
        self.mask = mask

        nonzero_idx = np.flatnonzero(mask)  # ascending indices of selected rows
        if nonzero_idx.size > 0:
            self.start_idx = int(nonzero_idx[0])
            self.end_idx = int(nonzero_idx[-1]) + 1
        else:
            self.start_idx = 0
            self.end_idx = 0
        # Initialise iteration state so next() also works without iter().
        self.start = self.start_idx
        self.sum = 0

    def __next__(self):
        """Return the next non-empty batch as a list of float32 tensors.

        Raises:
            StopIteration: Once the cursor reaches ``end_idx``.
        """
        if self.start >= self.end_idx:
            raise StopIteration
        end = min(self.start + self.batch_size, self.end_idx)
        batch_mask = self.mask[self.start:end]
        # Skip windows without a selected row. This terminates because row
        # end_idx - 1 is always selected and lies in the final window.
        while not batch_mask.any():
            self.start = end
            end = min(self.start + self.batch_size, self.end_idx)
            batch_mask = self.mask[self.start:end]
        batch = [np.asarray(t[self.start:end]) for t in self.tensors]
        self.start = end
        self.sum += int(np.count_nonzero(batch_mask))
        return [torch.tensor(b[batch_mask], dtype=torch.float32) for b in batch]

    def __iter__(self):
        """Rewind to the first selected row, reset ``sum`` and return ``self``."""
        self.start = self.start_idx
        self.sum = 0
        return self

    def __len__(self):
        """Return the number of batches a full iteration yields."""
        # Selected row i falls into window (i - start_idx) // batch_size; the
        # batch count is the number of distinct windows holding a selected row.
        rel_idx = np.flatnonzero(self.mask[self.start_idx:self.end_idx])
        return int(np.unique(rel_idx // self.batch_size).size)