forked from boostcampaitech5/level3_cv_finalproject-cv-08
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
114 lines (83 loc) · 4.14 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import torch
import streamlit as st
from torch.autograd import Variable
from models.video_to_roll import resnet18
from models.roll_to_midi import Generator
from models.make_wav import MIDISynth
@st.cache_resource
def video_to_roll_load_model(device):
    """Load the video-to-roll resnet18 checkpoint onto *device*, in eval mode.

    Cached by Streamlit so the weights are read from disk only once per
    session/device.
    """
    checkpoint_path = "./data/model/video_to_roll_best_f1_2.pth"
    net = resnet18().to(device)
    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()  # inference only
    return net
@st.cache_resource
def roll_to_midi_load_model(device, input_shape):
    """Load the roll-to-MIDI Generator checkpoint onto *device*.

    Args:
        device: torch.device the model should run on (CPU or CUDA).
        input_shape: shape tuple forwarded to the Generator constructor.

    Returns:
        Generator with pretrained weights loaded, switched to eval mode.
    """
    model_path = "./data/model/roll_to_midi.tar"
    weights = torch.load(model_path, map_location=device)
    # Bug fix: the original called .cuda() here, which crashes on CPU-only
    # hosts even though the checkpoint itself is mapped to `device`.
    model = Generator(input_shape).to(device)
    model.load_state_dict(weights['state_dict_G'])
    # Consistent with video_to_roll_load_model: inference-only usage.
    model.eval()
    return model
def _expand_key_range(pred, n_frames, min_key, max_key):
    """Pad/crop stage-one predictions to (n_frames, 88) full-piano layout.

    The model only predicts keys [min_key, max_key]; the remaining columns
    of the 88-key roll are zero-filled. If the prediction has a different
    number of frames than expected, it is truncated or zero-padded to
    n_frames rows.
    """
    pred = np.asarray(pred).squeeze()
    if pred.shape[0] != n_frames:
        fitted = np.zeros((n_frames, max_key - min_key + 1))
        fitted[:pred.shape[0], :] = pred[:n_frames, :]
        pred = fitted
    full = np.zeros((n_frames, 88))
    full[:, min_key:max_key + 1] = pred
    return full


def video_to_roll_inference(video_info, frames_with5, instrument):
    """Run the video->piano-roll model over 5-frame stacks and synthesize audio.

    Args:
        video_info: dict carrying 'video_select_frame' (expected frame count).
        frames_with5: sequence of per-frame image stacks fed to the model.
        instrument: instrument id forwarded to MIDISynth.

    Returns:
        (roll, logit, wav, pm): binarized 88-key roll, raw logits expanded to
        88 keys, synthesized waveform, and the MIDI object from MIDISynth.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = video_to_roll_load_model(device)

    # Stage-one model predicts keys in [3, 83] of the 88-key piano.
    min_key, max_key = 3, 83
    threshold = 0.6  # sigmoid cutoff for note-on
    batch_size = 32

    preds_roll, preds_logit = [], []
    # Inference only: no_grad avoids accumulating autograd graphs per batch.
    with torch.no_grad():
        for start in range(0, len(frames_with5), batch_size):
            stop = min(len(frames_with5), start + batch_size)
            batch = torch.stack([
                torch.Tensor(np.asarray(frames_with5[i])).float()
                for i in range(start, stop)
            ])
            # Bug fix: send inputs to the selected device instead of the
            # hard-coded .cuda(), so CPU-only machines work too.
            pred_logits = model(batch.to(device))
            pred_roll = torch.sigmoid(pred_logits) >= threshold
            preds_roll.extend(pred_roll.cpu().numpy().astype(np.int_))
            preds_logit.extend(pred_logits.cpu().numpy())

    n_frames = video_info['video_select_frame']
    roll = _expand_key_range(preds_roll, n_frames, min_key, max_key)
    logit = _expand_key_range(preds_logit, n_frames, min_key, max_key)

    wav, pm = MIDISynth(roll=roll, midi=None, frame=n_frames,
                        ins=instrument, is_midi=False).process_roll()
    return roll, logit, wav, pm
def roll_to_midi_inference(video_info, logit, instrument):
    """Refine stage-one logits into a clean MIDI roll with the GAN generator.

    The logits are split into half-length chunks; consecutive chunk pairs are
    concatenated into windows of 2*frame frames (the generator's temporal
    width), denoised by the generator, and reassembled in order.

    Args:
        video_info: dict carrying 'video_select_frame'.
        logit: (n_frames, 88) array of raw key logits from stage one.
        instrument: instrument id forwarded to MIDISynth.

    Returns:
        (midi, wav, pm): binarized MIDI roll, synthesized waveform, and the
        MIDI object from MIDISynth.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    min_key, max_key = 15, 65
    frame = video_info['video_select_frame'] // 2
    # NOTE(review): input_shape spans max_key-min_key+1 = 51 keys, but the
    # tensors fed below have 88 rows — confirm Generator handles this.
    input_shape = (1, max_key - min_key + 1, 2 * frame)
    model = roll_to_midi_load_model(device, input_shape)
    model.eval()  # hoisted out of the loop: mode only needs setting once

    chunks = [torch.from_numpy(logit[i:i + frame])
              for i in range(0, len(logit), frame)]
    # Pair consecutive chunks; a trailing unpaired chunk is dropped, as in
    # the original pipeline.
    windows = [torch.cat([chunks[i], chunks[i + 1]], dim=0)
               for i in range(0, len(chunks) - 1, 2)]

    results = []
    with torch.no_grad():
        for window in windows:
            # (2*frame, 88) -> (1, 1, 88, 2*frame) probability map.
            # Bug fix: use .to(device) instead of the hard-coded .cuda() /
            # torch.cuda.FloatTensor so CPU-only hosts work; the deprecated
            # Variable wrapper is also dropped (a no-op in modern torch).
            prob = torch.sigmoid(window.T.float()).unsqueeze(0).unsqueeze(0)
            gen_img = model(prob.to(device)) >= 0.5
            pred = gen_img.cpu().numpy().astype(np.int_)
            pred = np.transpose(pred.squeeze(), (1, 0))
            # Each window holds two consecutive `frame`-length segments.
            results.append(pred[:frame, :])
            results.append(pred[frame:, :])

    midi = np.concatenate(results, axis=0)
    wav, pm = MIDISynth(roll=None, midi=midi, frame=midi.shape[0],
                        ins=instrument, is_midi=True).process_midi()
    return midi, wav, pm