-
Notifications
You must be signed in to change notification settings - Fork 1
/
OF.py
73 lines (60 loc) · 2.4 KB
/
OF.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
import torch
import cv2
from raft import RAFTOpticalFlow
import argparse
from collections import OrderedDict
from tqdm import tqdm
import os
class RaftOF:
    """
    This class handles the RAFT model to quickly compute the optical flow on videos.

    All model work is delegated to ``RAFTOpticalFlow`` (imported from ``raft``); this
    class adds video decoding and on-disk export of the frames and flow fields.
    """

    def __init__(self):
        """
        Create the underlying ``RAFTOpticalFlow`` wrapper.

        NOTE(review): device placement and evaluation mode are presumably handled inside
        ``RAFTOpticalFlow.__init__`` (not visible here) - confirm there.
        """
        self.raft_of = RAFTOpticalFlow()

    def get_optical_flow(self, frame1, frame2):
        """
        Estimate the optical flow from ``frame1`` to ``frame2`` using RAFT.

        :param frame1: np.ndarray(n, m, 3), first image
        :param frame2: np.ndarray(n, m, 3), second image
        :return: np.ndarray(n, m, 2), optical flow between the input images estimated using raft
        """
        return self.raft_of.calc(frame1, frame2)

    def compute_optical_flow_on_video(self, video_path, saving_path=None):
        """
        Read every frame of a video, estimate optical flow and save frames and flows as .npy files.

        For frame i, ``img_{i}.npy`` holds the raw frame and ``flow_{i}.npy`` holds
        ``flow(frame_0, frame_i) - flow(frame_0, frame_{i-1})``; for i = 0 it holds the
        raw ``flow(frame_0, frame_0)``.

        :param video_path: str, path to read video from
        :param saving_path: str or None, directory to write frames and optical flow arrays to; if None,
            a directory named after the video file (without its extension) is created in the current
            working directory
        :return: ((int, int), int, str), the frame size (height, width), the number of frames read,
            and the directory where all arrays were saved
        :raises ValueError: if no frame could be read from ``video_path``
        """
        if saving_path is None:
            saving_path = os.path.join('./', self._video_stem(video_path))
        # makedirs also creates missing parent directories and tolerates an existing one.
        os.makedirs(saving_path, exist_ok=True)

        print('Reading Video..')
        frame_list = self._read_frames(video_path)
        n = len(frame_list)
        if n == 0:
            raise ValueError('No frames could be read from video: {}'.format(video_path))

        first_frame = frame_list[0]
        prev_flow = None
        # One update per frame, so the bar's total (n) matches the work actually done.
        with tqdm(total=n) as prog_bar:
            for i, frame in enumerate(frame_list):
                # Every flow is estimated against the first frame and differenced with the
                # previous one; only the previous cumulative flow is kept in memory.
                # NOTE(review): flow(0->i) - flow(0->i-1) is the change in displacement of
                # frame-0 pixels, sampled on frame 0's grid; it only approximates the direct
                # frame (i-1)->i flow under large motion - confirm this is intended.
                cum_flow = self.get_optical_flow(first_frame, frame)
                flow = cum_flow if prev_flow is None else cum_flow - prev_flow
                np.save(os.path.join(saving_path, 'img_{}'.format(i)), frame)
                np.save(os.path.join(saving_path, 'flow_{}'.format(i)), flow)
                prev_flow = cum_flow
                prog_bar.update()
        return first_frame.shape[:2], n, saving_path

    @staticmethod
    def _video_stem(video_path):
        """
        Return the file name of ``video_path`` without its directory or its final extension.

        Both forward slashes and back-slashes are treated as separators regardless of the OS,
        and only the last extension is dropped ('clip.v2.mp4' -> 'clip.v2').
        """
        base_name = video_path.replace('\\', '/').rsplit('/', 1)[-1]
        return os.path.splitext(base_name)[0]

    @staticmethod
    def _read_frames(video_path):
        """Decode all frames of ``video_path`` into a list; the capture is always released."""
        cap = cv2.VideoCapture(video_path)
        frame_list = []
        try:
            while True:
                readable, frame = cap.read()
                if not readable:
                    break
                frame_list.append(frame.copy())
        finally:
            cap.release()
        return frame_list