-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
executable file
·47 lines (38 loc) · 1.64 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pickle

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from autoaugment import TrivialAugmentWide
class MOVES_Dataset(Dataset):
    """Dataset of paired (now, future) frames with optical flow and optional people masks.

    Each entry of ``pfc`` is unpacked as a ``(now_frame, future_frame)`` pair of
    file paths. Auxiliary files are found by replacing the substring
    ``'frames'`` in a frame path with a sibling directory name:

    * ``fwd_flow/...`` + ``'.npz'`` and ``bck_flow/...`` + ``'.npz'``: arrays
      stored under key ``'arr_0'``. Both are looked up from the *now* frame's
      path.
    * ``people/...``: an image read as the people mask of that frame (only
      when ``people`` is True).

    Args:
        pfc: Indexable sequence of ``(now_frame, future_frame)`` path pairs.
        people: If True, load people masks from disk. Otherwise return an
            all-zeros placeholder of shape ``mask_shape``.
        train: If True, prepend ``TrivialAugmentWide`` to the RGB transform.
        mask_shape: Shape of the zero placeholder used when ``people`` is
            False. Defaults to ``(512, 512)``.
    """

    def __init__(self, pfc, people=False, train=False, mask_shape=(512, 512)):
        self.pfc = pfc
        self.people = people
        self.mask_shape = mask_shape
        # NOTE(review): TrivialAugmentWide comes from the local `autoaugment`
        # module. If, like torchvision's version, it samples a random op on
        # each call, the now and future frames get independent augmentations.
        # That may break their correspondence with the flow fields; confirm
        # this is intended.
        augment = [TrivialAugmentWide()] if train else []
        self.transform = transforms.Compose(augment + [
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean/std (as used by torchvision).
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.pfc)

    @staticmethod
    def _sibling_path(frame, subdir):
        """Map a frame path onto the matching path under directory ``subdir``.

        Uses ``str.replace`` semantics: EVERY occurrence of ``'frames'`` is
        replaced. A root directory whose name contains ``'frames'`` is
        therefore rewritten too.
        """
        # str() guards against pathlib.Path entries, whose .replace() would
        # rename the file on disk instead of rewriting the string.
        return str(frame).replace('frames', subdir)

    @classmethod
    def _load_flow(cls, frame, subdir):
        """Return array ``'arr_0'`` from ``<frame path in subdir>.npz``."""
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code if the .npz files are untrusted.
        # It is kept for compatibility with existing files; drop it if the
        # flow files only hold numeric arrays.
        # mmap_mode is not passed because np.load ignores it for .npz archives.
        path = cls._sibling_path(frame, subdir) + '.npz'
        with np.load(path, allow_pickle=True) as archive:
            # Indexing reads the member fully into memory, so the result
            # stays valid after the archive is closed.
            return archive['arr_0']

    def _load_rgb(self, frame):
        """Load ``frame`` as a 3-channel RGB image and apply ``self.transform``."""
        with Image.open(frame) as img:
            # convert() forces a full decode, so the file can be closed on
            # exit. It also maps grayscale/RGBA/palette images onto the
            # 3 channels that Normalize expects.
            return self.transform(img.convert('RGB'))

    def _load_people(self, frame):
        """Return the people mask for ``frame``, or zeros when masks are disabled."""
        if not self.people:
            return np.zeros(self.mask_shape)
        with Image.open(self._sibling_path(frame, 'people')) as mask:
            return np.array(mask)

    def __getitem__(self, idx):
        """Load one (now, future) sample.

        Returns:
            A dict with keys ``'now'`` and ``'future'``. Each maps to a dict
            with:

            * ``'frame'``: the path exactly as stored in ``pfc``;
            * ``'rgb'``: the transformed image;
            * ``'people'``: the mask array or zero placeholder;
            * a flow array: ``'flow_n_f'`` under ``'now'`` and ``'flow_f_n'``
              under ``'future'``.
        """
        now_frame, future_frame = self.pfc[idx]
        return {
            'now': {
                'frame': now_frame,
                'rgb': self._load_rgb(now_frame),
                'flow_n_f': self._load_flow(now_frame, 'fwd_flow'),
                'people': self._load_people(now_frame),
            },
            'future': {
                'frame': future_frame,
                'rgb': self._load_rgb(future_frame),
                # NOTE(review): backward flow is read from the *now* frame's
                # path, unlike the future frame's rgb/people. This is correct
                # if the pair's flow is saved under the now frame's name;
                # confirm against the flow-generation script.
                'flow_f_n': self._load_flow(now_frame, 'bck_flow'),
                'people': self._load_people(future_frame),
            },
        }
if __name__ == "__main__":
    # Smoke check: build the dataset from the cached list of frame pairs.
    # NOTE(review): pickle.load can execute arbitrary code; only load a cache
    # file produced by this project.
    with open('paired_frame_cache.pkl', 'rb') as cache_file:
        pfc = pickle.load(cache_file)
    moves_ds = MOVES_Dataset(pfc=pfc, people=False, train=False)