Skip to content

Commit af82e47

Browse files
committed
refactor
1 parent 4745e7c commit af82e47

File tree

3 files changed

+82
-97
lines changed

3 files changed

+82
-97
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ dataset = VideoFrameDataset(
4040
frames_per_segment=1,
4141
imagefile_template='img_{:05d}.jpg',
4242
transform=None,
43-
random_shift=True,
4443
test_mode=False
4544
)
4645

demo.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99
Ignore this function and look at "main" below.
1010
"""
11-
def plot_video(rows, cols, frame_list, plot_width, plot_height):
11+
def plot_video(rows, cols, frame_list, plot_width, plot_height, title: str):
1212
fig = plt.figure(figsize=(plot_width, plot_height))
1313
grid = ImageGrid(fig, 111, # similar to subplot(111)
1414
nrows_ncols=(rows, cols), # creates 2x2 grid of axes
@@ -19,6 +19,7 @@ def plot_video(rows, cols, frame_list, plot_width, plot_height):
1919
# Iterating over the grid returns the Axes.
2020
ax.imshow(im)
2121
ax.set_title(index)
22+
plt.suptitle(title)
2223
plt.show()
2324

2425
if __name__ == '__main__':
@@ -38,7 +39,6 @@ def plot_video(rows, cols, frame_list, plot_width, plot_height):
3839
annotation_file = os.path.join(videos_root, 'annotations.txt')
3940

4041

41-
4242
""" DEMO 1 WITHOUT IMAGE TRANSFORMS """
4343
dataset = VideoFrameDataset(
4444
root_path=videos_root,
@@ -47,15 +47,15 @@ def plot_video(rows, cols, frame_list, plot_width, plot_height):
4747
frames_per_segment=1,
4848
imagefile_template='img_{:05d}.jpg',
4949
transform=None,
50-
random_shift=True,
5150
test_mode=False
5251
)
5352

5453
sample = dataset[0]
5554
frames = sample[0] # list of PIL images
5655
label = sample[1] # integer label
5756

58-
plot_video(rows=1, cols=5, frame_list=frames, plot_width=15., plot_height=3.)
57+
plot_video(rows=1, cols=5, frame_list=frames, plot_width=15., plot_height=3.,
58+
title='Evenly Sampled Frames, No Video Transform')
5959

6060

6161

@@ -72,15 +72,15 @@ def plot_video(rows, cols, frame_list, plot_width, plot_height):
7272
frames_per_segment=9,
7373
imagefile_template='img_{:05d}.jpg',
7474
transform=None,
75-
random_shift=True,
7675
test_mode=False
7776
)
7877

7978
sample = dataset[3]
8079
frames = sample[0] # list of PIL images
8180
label = sample[1] # integer label
8281

83-
plot_video(rows=3, cols=3, frame_list=frames, plot_width=10., plot_height=5.)
82+
plot_video(rows=3, cols=3, frame_list=frames, plot_width=10., plot_height=5.,
83+
title='Continuous Sampled Frame Clip, No Video Transform')
8484

8585

8686

@@ -103,7 +103,6 @@ def plot_video(rows, cols, frame_list, plot_width, plot_height):
103103
frames_per_segment=1,
104104
imagefile_template='img_{:05d}.jpg',
105105
transform=preprocess,
106-
random_shift=True,
107106
test_mode=False
108107
)
109108

@@ -128,7 +127,8 @@ def denormalize(video_tensor):
128127

129128

130129
frame_tensor = denormalize(frame_tensor)
131-
plot_video(rows=1, cols=5, frame_list=frame_tensor, plot_width=15., plot_height=3.)
130+
plot_video(rows=1, cols=5, frame_list=frame_tensor, plot_width=15., plot_height=3.,
131+
title='Evenly Sampled Frames, + Video Transform')
132132

133133

134134

@@ -179,7 +179,6 @@ def denormalize(video_tensor):
179179
frames_per_segment=1,
180180
imagefile_template='img_{:05d}.jpg',
181181
transform=preprocess,
182-
random_shift=True,
183182
test_mode=False
184183
)
185184

video_dataset.py

Lines changed: 74 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from PIL import Image
55
from torchvision import transforms
66
import torch
7+
from typing import List, Union, Tuple, Any
8+
79

810
class VideoRecord(object):
911
"""
@@ -26,22 +28,22 @@ def __init__(self, row, root_datapath):
2628

2729

2830
@property
29-
def path(self):
31+
def path(self) -> str:
3032
return self._path
3133

3234
@property
33-
def num_frames(self):
35+
def num_frames(self) -> int:
3436
return self.end_frame - self.start_frame + 1 # +1 because end frame is inclusive
3537
@property
36-
def start_frame(self):
38+
def start_frame(self) -> int:
3739
return int(self._data[1])
3840

3941
@property
40-
def end_frame(self):
42+
def end_frame(self) -> int:
4143
return int(self._data[2])
4244

4345
@property
44-
def label(self):
46+
def label(self) -> Union[int, List[int]]:
4547
# just one label_id
4648
if len(self._data) == 4:
4749
return int(self._data[3])
@@ -119,7 +121,6 @@ def __init__(self,
119121
frames_per_segment: int = 1,
120122
imagefile_template: str='img_{:05d}.jpg',
121123
transform = None,
122-
random_shift: bool = True,
123124
test_mode: bool = False):
124125
super(VideoFrameDataset, self).__init__()
125126

@@ -129,26 +130,32 @@ def __init__(self,
129130
self.frames_per_segment = frames_per_segment
130131
self.imagefile_template = imagefile_template
131132
self.transform = transform
132-
self.random_shift = random_shift
133133
self.test_mode = test_mode
134134

135-
self._parse_list()
135+
self._parse_annotationfile()
136136
self._sanity_check_samples()
137137

138-
def _load_image(self, directory, idx):
139-
return [Image.open(os.path.join(directory, self.imagefile_template.format(idx))).convert('RGB')]
138+
def _load_image(self, directory: str, idx: int) -> Image.Image:
139+
return Image.open(os.path.join(directory, self.imagefile_template.format(idx))).convert('RGB')
140140

141-
def _parse_list(self):
141+
def _parse_annotationfile(self):
142142
self.video_list = [VideoRecord(x.strip().split(), self.root_path) for x in open(self.annotationfile_path)]
143143

144144
def _sanity_check_samples(self):
145145
for record in self.video_list:
146146
if record.num_frames <= 0 or record.start_frame == record.end_frame:
147-
print(f"\nDataset Warning: data sample {record.path} seems to have zero RGB frames on disk!\n")
147+
print(f"\nDataset Warning: video {record.path} seems to have zero RGB frames on disk!\n")
148+
149+
elif record.num_frames < (self.num_segments * self.frames_per_segment):
150+
print(f"\nDataset Warning: video {record.path} has {record.num_frames} frames "
151+
f"but the dataloader is set up to load "
152+
f"(num_segments={self.num_segments})*(frames_per_segment={self.frames_per_segment})"
153+
f"={self.num_segments * self.frames_per_segment} frames. Dataloader will throw an "
154+
f"error when trying to load this video.\n")
148155

149-
def _sample_indices(self, record):
156+
def _get_start_indices(self, record: VideoRecord) -> 'np.ndarray[int]':
150157
"""
151-
For each segment, chooses an index from where frames
158+
For each segment, choose a start index from where frames
152159
are to be loaded from.
153160
154161
Args:
@@ -157,104 +164,83 @@ def _sample_indices(self, record):
157164
List of indices of where the frames of each
158165
segment are to be loaded from.
159166
"""
167+
# choose start indices that are perfectly evenly spread across the video frames.
168+
if self.test_mode:
169+
distance_between_indices = (record.num_frames - self.frames_per_segment + 1) / float(self.num_segments)
160170

161-
segment_duration = (record.num_frames - self.frames_per_segment + 1) // self.num_segments
162-
if segment_duration > 0:
163-
offsets = np.multiply(list(range(self.num_segments)), segment_duration) + np.random.randint(segment_duration, size=self.num_segments)
164-
165-
# edge cases for when a video has approximately less than (num_frames*frames_per_segment) frames.
166-
# random sampling in that case, which will lead to repeated frames.
167-
else:
168-
offsets = np.sort(np.random.randint(record.num_frames, size=self.num_segments))
169-
170-
return offsets
171-
172-
def _get_val_indices(self, record):
173-
"""
174-
For each segment, finds the center frame index.
175-
176-
Args:
177-
record: VideoRecord denoting a video sample.
178-
Returns:
179-
List of indices of segment center frames.
180-
"""
181-
if record.num_frames > self.num_segments + self.frames_per_segment - 1:
182-
offsets = self._get_test_indices(record)
183-
184-
# edge case for when a video does not have enough frames
171+
start_indices = np.array([int(distance_between_indices / 2.0 + distance_between_indices * x)
172+
for x in range(self.num_segments)])
173+
# randomly sample start indices that are approximately evenly spread across the video frames.
185174
else:
186-
offsets = np.sort(np.random.randint(record.num_frames, size=self.num_segments))
175+
max_valid_start_index = (record.num_frames - self.frames_per_segment + 1) // self.num_segments
187176

188-
return offsets
177+
start_indices = np.multiply(list(range(self.num_segments)), max_valid_start_index) + \
178+
np.random.randint(max_valid_start_index, size=self.num_segments)
189179

190-
def _get_test_indices(self, record):
191-
"""
192-
For each segment, finds the center frame index.
180+
return start_indices
193181

194-
Args:
195-
record: VideoRecord denoting a video sample
196-
Returns:
197-
List of indices of segment center frames.
182+
def __getitem__(self, idx: int) -> Union[
183+
Tuple[List[Image.Image], Union[int, List[int]]],
184+
Tuple['torch.Tensor[num_frames, channels, height, width]', Union[int, List[int]]],
185+
Tuple[Any, Union[int, List[int]]],
186+
]:
198187
"""
199-
200-
tick = (record.num_frames - self.frames_per_segment + 1) / float(self.num_segments)
201-
202-
offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
203-
204-
return offsets
205-
206-
def __getitem__(self, index):
207-
"""
208-
For video with id index, loads self.NUM_SEGMENTS * self.FRAMES_PER_SEGMENT
209-
frames from evenly chosen locations.
188+
For video with id idx, loads self.NUM_SEGMENTS * self.FRAMES_PER_SEGMENT
189+
frames from evenly chosen locations across the video.
210190
211191
Args:
212-
index: Video sample index.
192+
idx: Video sample index.
213193
Returns:
214-
a list of PIL images or the result
215-
of applying self.transform on this list if
216-
self.transform is not None.
194+
A tuple of (video, label). Label is either a single
195+
integer or a list of integers in the case of multiple labels.
196+
Video is either 1) a list of PIL images if no transform is used
197+
2) a batch of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1]
198+
if the transform "ImglistToTensor" is used
199+
3) or anything else if a custom transform is used.
217200
"""
218-
record = self.video_list[index]
201+
record: VideoRecord = self.video_list[idx]
219202

220-
if not self.test_mode:
221-
segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
222-
else:
223-
segment_indices = self._get_test_indices(record)
203+
frame_start_indices: 'np.ndarray[int]' = self._get_start_indices(record)
224204

225-
return self._get(record, segment_indices)
205+
return self._get(record, frame_start_indices)
226206

227-
def _get(self, record, indices):
207+
def _get(self, record: VideoRecord, frame_start_indices: 'np.ndarray[int]') -> Union[
208+
Tuple[List[Image.Image], Union[int, List[int]]],
209+
Tuple['torch.Tensor[num_frames, channels, height, width]', Union[int, List[int]]],
210+
Tuple[Any, Union[int, List[int]]],
211+
]:
228212
"""
229213
Loads the frames of a video at the corresponding
230214
indices.
231215
232216
Args:
233217
record: VideoRecord denoting a video sample.
234-
indices: Indices at which to load video frames from.
218+
frame_start_indices: Indices from which to load consecutive frames from.
235219
Returns:
236-
1) A list of PIL images or the result
237-
of applying self.transform on this list if
238-
self.transform is not None.
239-
2) An integer denoting the video label.
220+
A tuple of (video, label). Label is either a single
221+
integer or a list of integers in the case of multiple labels.
222+
Video is either 1) a list of PIL images if no transform is used
223+
2) a batch of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1]
224+
if the transform "ImglistToTensor" is used
225+
3) or anything else if a custom transform is used.
240226
"""
241227

242-
indices = indices + record.start_frame
228+
frame_start_indices = frame_start_indices + record.start_frame
243229
images = list()
244-
image_indices = list()
245-
for seg_ind in indices:
246-
frame_index = int(seg_ind)
247-
for i in range(self.frames_per_segment):
248-
seg_img = self._load_image(record.path, frame_index)
249-
images.extend(seg_img)
250-
image_indices.append(frame_index)
230+
231+
# from each start_index, load self.frames_per_segment
232+
# consecutive frames
233+
for start_index in frame_start_indices:
234+
frame_index = int(start_index)
235+
236+
# load self.frames_per_segment consecutive frames
237+
for _ in range(self.frames_per_segment):
238+
image = self._load_image(record.path, frame_index)
239+
images.append(image)
240+
251241
if frame_index < record.end_frame:
252242
frame_index += 1
253243

254-
# sort images by index in case of edge cases where segments overlap each other because the overall
255-
# video is too short for num_segments*frames_per_segment indices.
256-
# _, images = (list(sorted_list) for sorted_list in zip(*sorted(zip(image_indices, images))))
257-
258244
if self.transform is not None:
259245
images = self.transform(images)
260246

@@ -269,7 +255,8 @@ class ImglistToTensor(torch.nn.Module):
269255
of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1].
270256
Can be used as first transform for ``VideoFrameDataset``.
271257
"""
272-
def forward(self, img_list):
258+
@staticmethod
259+
def forward(img_list: List[Image.Image]) -> 'torch.Tensor[NUM_IMAGES, CHANNELS, HEIGHT, WIDTH]':
273260
"""
274261
Converts each PIL image in a list to
275262
a torch Tensor and stacks them into

0 commit comments

Comments
 (0)