44from PIL import Image
55from torchvision import transforms
66import torch
7+ from typing import List , Union , Tuple , Any
8+
79
810class VideoRecord (object ):
911 """
@@ -26,22 +28,22 @@ def __init__(self, row, root_datapath):
2628
2729
2830 @property
29- def path (self ):
31+ def path (self ) -> str :
3032 return self ._path
3133
3234 @property
33- def num_frames (self ):
35+ def num_frames (self ) -> int :
3436 return self .end_frame - self .start_frame + 1 # +1 because end frame is inclusive
3537 @property
36- def start_frame (self ):
38+ def start_frame (self ) -> int :
3739 return int (self ._data [1 ])
3840
3941 @property
40- def end_frame (self ):
42+ def end_frame (self ) -> int :
4143 return int (self ._data [2 ])
4244
4345 @property
44- def label (self ):
46+ def label (self ) -> Union [ int , List [ int ]] :
4547 # just one label_id
4648 if len (self ._data ) == 4 :
4749 return int (self ._data [3 ])
@@ -119,7 +121,6 @@ def __init__(self,
119121 frames_per_segment : int = 1 ,
120122 imagefile_template : str = 'img_{:05d}.jpg' ,
121123 transform = None ,
122- random_shift : bool = True ,
123124 test_mode : bool = False ):
124125 super (VideoFrameDataset , self ).__init__ ()
125126
@@ -129,26 +130,32 @@ def __init__(self,
129130 self .frames_per_segment = frames_per_segment
130131 self .imagefile_template = imagefile_template
131132 self .transform = transform
132- self .random_shift = random_shift
133133 self .test_mode = test_mode
134134
135- self ._parse_list ()
135+ self ._parse_annotationfile ()
136136 self ._sanity_check_samples ()
137137
138- def _load_image (self , directory , idx ) :
139- return [ Image .open (os .path .join (directory , self .imagefile_template .format (idx ))).convert ('RGB' )]
138+ def _load_image (self , directory : str , idx : int ) -> Image . Image :
139+ return Image .open (os .path .join (directory , self .imagefile_template .format (idx ))).convert ('RGB' )
140140
141- def _parse_list (self ):
141+ def _parse_annotationfile (self ):
142142 self .video_list = [VideoRecord (x .strip ().split (), self .root_path ) for x in open (self .annotationfile_path )]
143143
144144 def _sanity_check_samples (self ):
145145 for record in self .video_list :
146146 if record .num_frames <= 0 or record .start_frame == record .end_frame :
147- print (f"\n Dataset Warning: data sample { record .path } seems to have zero RGB frames on disk!\n " )
147+ print (f"\n Dataset Warning: video { record .path } seems to have zero RGB frames on disk!\n " )
148+
149+ elif record .num_frames < (self .num_segments * self .frames_per_segment ):
150+ print (f"\n Dataset Warning: video { record .path } has { record .num_frames } frames "
151+ f"but the dataloader is set up to load "
152+ f"(num_segments={ self .num_segments } )*(frames_per_segment={ self .frames_per_segment } )"
153+ f"={ self .num_segments * self .frames_per_segment } frames. Dataloader will throw an "
154+ f"error when trying to load this video.\n " )
148155
149- def _sample_indices (self , record ) :
156+ def _get_start_indices (self , record : VideoRecord ) -> 'np.ndarray[int]' :
150157 """
151- For each segment, chooses an index from where frames
158+ For each segment, choose a start index from where frames
152159 are to be loaded from.
153160
154161 Args:
@@ -157,104 +164,83 @@ def _sample_indices(self, record):
157164 List of indices of where the frames of each
158165 segment are to be loaded from.
159166 """
167+ # choose start indices that are perfectly evenly spread across the video frames.
168+ if self .test_mode :
169+ distance_between_indices = (record .num_frames - self .frames_per_segment + 1 ) / float (self .num_segments )
160170
161- segment_duration = (record .num_frames - self .frames_per_segment + 1 ) // self .num_segments
162- if segment_duration > 0 :
163- offsets = np .multiply (list (range (self .num_segments )), segment_duration ) + np .random .randint (segment_duration , size = self .num_segments )
164-
165- # edge cases for when a video has approximately less than (num_frames*frames_per_segment) frames.
166- # random sampling in that case, which will lead to repeated frames.
167- else :
168- offsets = np .sort (np .random .randint (record .num_frames , size = self .num_segments ))
169-
170- return offsets
171-
172- def _get_val_indices (self , record ):
173- """
174- For each segment, finds the center frame index.
175-
176- Args:
177- record: VideoRecord denoting a video sample.
178- Returns:
179- List of indices of segment center frames.
180- """
181- if record .num_frames > self .num_segments + self .frames_per_segment - 1 :
182- offsets = self ._get_test_indices (record )
183-
184- # edge case for when a video does not have enough frames
171+ start_indices = np .array ([int (distance_between_indices / 2.0 + distance_between_indices * x )
172+ for x in range (self .num_segments )])
173+ # randomly sample start indices that are approximately evenly spread across the video frames.
185174 else :
186- offsets = np . sort ( np . random . randint ( record .num_frames , size = self .num_segments ))
175+ max_valid_start_index = ( record .num_frames - self .frames_per_segment + 1 ) // self . num_segments
187176
188- return offsets
177+ start_indices = np .multiply (list (range (self .num_segments )), max_valid_start_index ) + \
178+ np .random .randint (max_valid_start_index , size = self .num_segments )
189179
190- def _get_test_indices (self , record ):
191- """
192- For each segment, finds the center frame index.
180+ return start_indices
193181
194- Args:
195- record: VideoRecord denoting a video sample
196- Returns:
197- List of indices of segment center frames.
182+ def __getitem__ (self , idx : int ) -> Union [
183+ Tuple [List [Image .Image ], Union [int , List [int ]]],
184+ Tuple ['torch.Tensor[num_frames, channels, height, width]' , Union [int , List [int ]]],
185+ Tuple [Any , Union [int , List [int ]]],
186+ ]:
198187 """
199-
200- tick = (record .num_frames - self .frames_per_segment + 1 ) / float (self .num_segments )
201-
202- offsets = np .array ([int (tick / 2.0 + tick * x ) for x in range (self .num_segments )])
203-
204- return offsets
205-
206- def __getitem__ (self , index ):
207- """
208- For video with id index, loads self.NUM_SEGMENTS * self.FRAMES_PER_SEGMENT
209- frames from evenly chosen locations.
188+ For video with id idx, loads self.NUM_SEGMENTS * self.FRAMES_PER_SEGMENT
189+ frames from evenly chosen locations across the video.
210190
211191 Args:
212- index : Video sample index.
192+ idx : Video sample index.
213193 Returns:
214- a list of PIL images or the result
215- of applying self.transform on this list if
216- self.transform is not None.
194+ A tuple of (video, label). Label is either a single
195+ integer or a list of integers in the case of multiple labels.
196+ Video is either 1) a list of PIL images if no transform is used
197+ 2) a batch of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1]
198+ if the transform "ImglistToTensor" is used
199+ 3) or anything else if a custom transform is used.
217200 """
218- record = self .video_list [index ]
201+ record : VideoRecord = self .video_list [idx ]
219202
220- if not self .test_mode :
221- segment_indices = self ._sample_indices (record ) if self .random_shift else self ._get_val_indices (record )
222- else :
223- segment_indices = self ._get_test_indices (record )
203+ frame_start_indices : 'np.ndarray[int]' = self ._get_start_indices (record )
224204
225- return self ._get (record , segment_indices )
205+ return self ._get (record , frame_start_indices )
226206
227- def _get (self , record , indices ):
207+ def _get (self , record : VideoRecord , frame_start_indices : 'np.ndarray[int]' ) -> Union [
208+ Tuple [List [Image .Image ], Union [int , List [int ]]],
209+ Tuple ['torch.Tensor[num_frames, channels, height, width]' , Union [int , List [int ]]],
210+ Tuple [Any , Union [int , List [int ]]],
211+ ]:
228212 """
229213 Loads the frames of a video at the corresponding
230214 indices.
231215
232216 Args:
233217 record: VideoRecord denoting a video sample.
234- indices : Indices at which to load video frames from.
218+ frame_start_indices : Indices from which to load consecutive frames from.
235219 Returns:
236- 1) A list of PIL images or the result
237- of applying self.transform on this list if
238- self.transform is not None.
239- 2) An integer denoting the video label.
220+ A tuple of (video, label). Label is either a single
221+ integer or a list of integers in the case of multiple labels.
222+ Video is either 1) a list of PIL images if no transform is used
223+ 2) a batch of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1]
224+ if the transform "ImglistToTensor" is used
225+ 3) or anything else if a custom transform is used.
240226 """
241227
242- indices = indices + record .start_frame
228+ frame_start_indices = frame_start_indices + record .start_frame
243229 images = list ()
244- image_indices = list ()
245- for seg_ind in indices :
246- frame_index = int (seg_ind )
247- for i in range (self .frames_per_segment ):
248- seg_img = self ._load_image (record .path , frame_index )
249- images .extend (seg_img )
250- image_indices .append (frame_index )
230+
231+ # from each start_index, load self.frames_per_segment
232+ # consecutive frames
233+ for start_index in frame_start_indices :
234+ frame_index = int (start_index )
235+
236+ # load self.frames_per_segment consecutive frames
237+ for _ in range (self .frames_per_segment ):
238+ image = self ._load_image (record .path , frame_index )
239+ images .append (image )
240+
251241 if frame_index < record .end_frame :
252242 frame_index += 1
253243
254- # sort images by index in case of edge cases where segments overlap each other because the overall
255- # video is too short for num_segments*frames_per_segment indices.
256- # _, images = (list(sorted_list) for sorted_list in zip(*sorted(zip(image_indices, images))))
257-
258244 if self .transform is not None :
259245 images = self .transform (images )
260246
@@ -269,7 +255,8 @@ class ImglistToTensor(torch.nn.Module):
269255 of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1].
270256 Can be used as first transform for ``VideoFrameDataset``.
271257 """
272- def forward (self , img_list ):
258+ @staticmethod
259+ def forward (img_list : List [Image .Image ]) -> 'torch.Tensor[NUM_IMAGES, CHANNELS, HEIGHT, WIDTH]' :
273260 """
274261 Converts each PIL image in a list to
275262 a torch Tensor and stacks them into
0 commit comments