Skip to content

Commit 8bae5f1

Browse files
committed
General cleanup in model
1 parent 30ac446 commit 8bae5f1

15 files changed

+80
-81
lines changed

model/constants.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
# number of epochs to wait before adding the melody loss
1717
MELODY_EPOCH_DELAY = 0
1818

19+
1920
# inverse sigmoid decay
2021
def sampling_rate_at_epoch(epoch):
2122
if epoch < 0:
2223
return START_SCHEDULED_SAMPLING_RATE
23-
return (SCHEDULED_SAMPLING_CONVERGENCE / (SCHEDULED_SAMPLING_CONVERGENCE + math.exp(epoch / SCHEDULED_SAMPLING_CONVERGENCE))) \
24-
* (START_SCHEDULED_SAMPLING_RATE - END_SCHEDULED_SAMPLING_RATE) + END_SCHEDULED_SAMPLING_RATE
24+
return (SCHEDULED_SAMPLING_CONVERGENCE / (
25+
SCHEDULED_SAMPLING_CONVERGENCE + math.exp(epoch / SCHEDULED_SAMPLING_CONVERGENCE))) * (
26+
START_SCHEDULED_SAMPLING_RATE - END_SCHEDULED_SAMPLING_RATE) + END_SCHEDULED_SAMPLING_RATE
27+
2528

2629
HIDDEN_SIZE = 100
2730
HIDDEN_SIZE2 = 32

model/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import collections
2+
23
from model.constants import *
34

45

@@ -44,7 +45,7 @@ def process_sample(json_file):
4445
num_chords = num_measures * CHORD_DISCRETIZATION_LENGTH
4546

4647
chords_list, note_list, num_chords = discretize_sample(json_chords, json_notes, octave_boundary_lower,
47-
num_chords, num_measures * beats_per_measure)
48+
num_chords, num_measures * beats_per_measure)
4849

4950
# pad chord and melodies to max measure length
5051
chords_list.append(CHORD_END_TOKEN)

model/embeddings.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import json
2+
import os
3+
14
import numpy as np
25
import torch
3-
from transformers import BertTokenizer, BertModel
46
from torch.nn.utils.rnn import pad_sequence
5-
import json
6-
import os
7+
from transformers import BertTokenizer, BertModel
78

89
tokenizer = None
910
model = None
@@ -53,4 +54,4 @@ def make_embedding(lyrics, custom_device=None):
5354
output = model(**encoded_input)
5455
embedding = output.last_hidden_state[0]
5556
length = output.last_hidden_state.shape[1]
56-
return embedding, length
57+
return embedding, length

model/lofi2lofi_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
2+
23
import torch
34
from torch.utils.data import Dataset
5+
46
from model.dataset import *
57

68

model/lofi2lofi_model.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from hashlib import md5
2+
13
import numpy as np
24
import torch
35
from torch import nn
4-
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5-
from hashlib import md5
6+
from torch.nn.utils.rnn import pack_padded_sequence
67

78
from model.constants import *
89

@@ -16,7 +17,8 @@ def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
1617
self.mean_linear = nn.Linear(in_features=HIDDEN_SIZE, out_features=HIDDEN_SIZE)
1718
self.variance_linear = nn.Linear(in_features=HIDDEN_SIZE, out_features=HIDDEN_SIZE)
1819

19-
def forward(self, gt_chords, gt_melodies, gt_tempo, gt_key, gt_mode, gt_valence, gt_energy, batch_num_chords, num_chords, sampling_rate_chords=0, sampling_rate_melodies=0):
20+
def forward(self, gt_chords, gt_melodies, gt_tempo, gt_key, gt_mode, gt_valence, gt_energy, batch_num_chords,
21+
num_chords, sampling_rate_chords=0, sampling_rate_melodies=0):
2022
# encode
2123
h = self.encoder(gt_chords, gt_melodies, gt_tempo, gt_key, gt_mode, gt_valence, gt_energy, batch_num_chords)
2224
# VAE
@@ -52,30 +54,29 @@ def __init__(self, device):
5254
super(Encoder, self).__init__()
5355
self.device = device
5456
self.chord_embeddings = nn.Embedding(num_embeddings=CHORD_PREDICTION_LENGTH, embedding_dim=HIDDEN_SIZE)
55-
self.chords_lstm = nn.LSTM(input_size=HIDDEN_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, bidirectional=True, batch_first=True)
57+
self.chords_lstm = nn.LSTM(input_size=HIDDEN_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS,
58+
bidirectional=True, batch_first=True)
5659

5760
self.melody_embeddings = nn.Embedding(num_embeddings=MELODY_PREDICTION_LENGTH, embedding_dim=HIDDEN_SIZE)
58-
self.melody_lstm = nn.LSTM(input_size=HIDDEN_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, bidirectional=True, batch_first=True)
61+
self.melody_lstm = nn.LSTM(input_size=HIDDEN_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS,
62+
bidirectional=True, batch_first=True)
5963

6064
self.tempo_embedding = nn.Linear(in_features=1, out_features=HIDDEN_SIZE2)
6165
self.key_embedding = nn.Embedding(num_embeddings=NUMBER_OF_KEYS, embedding_dim=HIDDEN_SIZE2)
6266
self.mode_embedding = nn.Embedding(num_embeddings=NUMBER_OF_MODES, embedding_dim=HIDDEN_SIZE2)
6367
self.valence_embedding = nn.Linear(in_features=1, out_features=HIDDEN_SIZE2)
6468
self.energy_embedding = nn.Linear(in_features=1, out_features=HIDDEN_SIZE2)
6569

66-
self.downsample = nn.Linear(in_features=4*HIDDEN_SIZE + 5*HIDDEN_SIZE2, out_features=HIDDEN_SIZE)
67-
70+
self.downsample = nn.Linear(in_features=4 * HIDDEN_SIZE + 5 * HIDDEN_SIZE2, out_features=HIDDEN_SIZE)
6871

6972
def forward(self, chords, melodies, tempo, key, mode, valence, energy, batch_num_chords):
7073
chord_embeddings = self.chord_embeddings(chords)
7174
chords_input = pack_padded_sequence(chord_embeddings, batch_num_chords, batch_first=True, enforce_sorted=False)
7275
chords_out, (h_chords, _) = self.chords_lstm(chords_input)
73-
# chords_out_repeated = pad_packed_sequence(chords_out, batch_first=True)[0].repeat_interleave( NOTES_PER_CHORD, 1)
74-
# chords_out_repeated = chords_out_repeated[:,:,:HIDDEN_SIZE] + chords_out_repeated[:,:,HIDDEN_SIZE:]
7576

76-
# add two directions together
77-
melody_embeddings = self.melody_embeddings(melodies)# + chords_out_repeated
78-
melody_input = pack_padded_sequence(melody_embeddings, batch_num_chords * NOTES_PER_CHORD, batch_first=True, enforce_sorted=False)
77+
melody_embeddings = self.melody_embeddings(melodies)
78+
melody_input = pack_padded_sequence(melody_embeddings, batch_num_chords * NOTES_PER_CHORD, batch_first=True,
79+
enforce_sorted=False)
7980
_, (h_melodies, _) = self.melody_lstm(melody_input)
8081

8182
tempo_embedding = self.tempo_embedding(tempo.unsqueeze(1).float())
@@ -85,7 +86,9 @@ def forward(self, chords, melodies, tempo, key, mode, valence, energy, batch_num
8586
energy_embedding = self.energy_embedding(energy.unsqueeze(1).float())
8687

8788
h_concatenated = torch.cat((h_chords[-1], h_chords[-2], h_melodies[-1], h_melodies[-2]), dim=1)
88-
return self.downsample(torch.cat((h_concatenated, tempo_embedding, key_embedding, mode_embedding, valence_embedding, energy_embedding), dim=1))
89+
return self.downsample(torch.cat(
90+
(h_concatenated, tempo_embedding, key_embedding, mode_embedding, valence_embedding, energy_embedding),
91+
dim=1))
8992

9093

9194
class Decoder(nn.Module):
@@ -100,7 +103,7 @@ def __init__(self, device):
100103
nn.ReLU(),
101104
nn.Linear(in_features=HIDDEN_SIZE, out_features=CHORD_PREDICTION_LENGTH)
102105
)
103-
self.chord_embedding_downsample = nn.Linear(in_features=2*HIDDEN_SIZE, out_features=HIDDEN_SIZE)
106+
self.chord_embedding_downsample = nn.Linear(in_features=2 * HIDDEN_SIZE, out_features=HIDDEN_SIZE)
104107

105108
self.melody_embeddings = nn.Embedding(num_embeddings=MELODY_PREDICTION_LENGTH, embedding_dim=HIDDEN_SIZE)
106109
self.melody_lstm = nn.LSTMCell(input_size=HIDDEN_SIZE * 1, hidden_size=HIDDEN_SIZE * 1)
@@ -109,7 +112,7 @@ def __init__(self, device):
109112
nn.ReLU(),
110113
nn.Linear(in_features=HIDDEN_SIZE, out_features=MELODY_PREDICTION_LENGTH)
111114
)
112-
self.melody_embedding_downsample = nn.Linear(in_features=3*HIDDEN_SIZE, out_features=HIDDEN_SIZE)
115+
self.melody_embedding_downsample = nn.Linear(in_features=3 * HIDDEN_SIZE, out_features=HIDDEN_SIZE)
113116

114117
self.key_linear = nn.Sequential(
115118
nn.Linear(in_features=HIDDEN_SIZE, out_features=HIDDEN_SIZE2),
@@ -137,30 +140,27 @@ def __init__(self, device):
137140
nn.Linear(in_features=HIDDEN_SIZE2, out_features=1),
138141
)
139142

140-
def generate(self):
141-
mu = torch.randn(1, HIDDEN_SIZE)
142-
return self(mu)
143-
144143
def decode(self, mu):
145144
# create a hash for vector mu
146145
hash = ""
147146
# first 20 characters are each sampled from 5 entries
148147
for i in range(0, 100, 5):
149-
hash += str((mu[0][i:i+1].abs().sum() * 587).int().item())[-1]
148+
hash += str((mu[0][i:i + 1].abs().sum() * 587).int().item())[-1]
150149
# last 4 characters are the beginning of the MD5 hash of the whole vector
151150
hash2 = int(md5(mu.numpy()).hexdigest(), 16)
152151
hash = f"#{hash}{hash2}"[:25]
153152
return hash, self(mu, MAX_CHORD_LENGTH)
154153

155-
def forward(self, z, num_chords=MAX_CHORD_LENGTH, sampling_rate_chords=0, sampling_rate_melodies=0, gt_chords=None, gt_melody=None):
154+
def forward(self, z, num_chords=MAX_CHORD_LENGTH, sampling_rate_chords=0, sampling_rate_melodies=0, gt_chords=None,
155+
gt_melody=None):
156156
tempo_output = self.tempo_linear(z)
157157
key_output = self.key_linear(z)
158158
mode_output = self.mode_linear(z)
159159
valence_output = self.valence_linear(z)
160160
energy_output = self.energy_linear(z)
161161

162162
batch_size = z.shape[0]
163-
# initialize hidden states and cell states randomly
163+
# initialize hidden states and cell states
164164
hx_chords = torch.zeros(batch_size, HIDDEN_SIZE, device=self.device)
165165
cx_chords = torch.zeros(batch_size, HIDDEN_SIZE, device=self.device)
166166
hx_melody = torch.zeros(batch_size, HIDDEN_SIZE, device=self.device)
@@ -205,7 +205,8 @@ def forward(self, z, num_chords=MAX_CHORD_LENGTH, sampling_rate_chords=0, sampli
205205
melody_embeddings = self.melody_embeddings(gt_melody[:, i * NOTES_PER_CHORD + j])
206206
else:
207207
melody_embeddings = self.melody_embeddings(melody_prediction.argmax(dim=1))
208-
melody_embeddings = self.melody_embedding_downsample(torch.cat((melody_embeddings, chord_embeddings, z), dim=1))
208+
melody_embeddings = self.melody_embedding_downsample(
209+
torch.cat((melody_embeddings, chord_embeddings, z), dim=1))
209210

210211
chord_outputs = torch.stack(chord_outputs, dim=1)
211212
melody_outputs = torch.stack(melody_outputs, dim=1)

model/lofi2lofi_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
dataset = Lofi2LofiDataset(dataset_folder, dataset_files)
1212
model = Lofi2LofiModel()
1313

14-
train(dataset, model, "lofi2lofi")
14+
train(dataset, model, "lofi2lofi")

model/lyrics2lofi_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import json
2+
23
import numpy as np
34
import torch
45
from torch.utils.data import Dataset
6+
57
from model.dataset import *
68

79

model/lyrics2lofi_model.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
1414
self.mean_linear = nn.Linear(in_features=HIDDEN_SIZE, out_features=HIDDEN_SIZE)
1515
self.variance_linear = nn.Linear(in_features=HIDDEN_SIZE, out_features=HIDDEN_SIZE)
1616

17-
def forward(self, input, num_chords, sampling_rate_chords=0, sampling_rate_melodies=0, gt_chords=None, gt_melody=None):
17+
def forward(self, input, num_chords=MAX_CHORD_LENGTH, sampling_rate_chords=0, sampling_rate_melodies=0,
18+
gt_chords=None, gt_melody=None):
1819
# encode
1920
h = self.encoder(input)
2021

@@ -50,8 +51,9 @@ class Encoder(nn.Module):
5051
def __init__(self, device):
5152
super(Encoder, self).__init__()
5253
self.device = device
53-
self.encoder_lstm = nn.LSTM(input_size=BERT_EMBEDDING_LENGTH, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, bidirectional=True, batch_first=True)
54-
self.downsample = nn.Linear(in_features=2*HIDDEN_SIZE, out_features=HIDDEN_SIZE)
54+
self.encoder_lstm = nn.LSTM(input_size=BERT_EMBEDDING_LENGTH, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS,
55+
bidirectional=True, batch_first=True)
56+
self.downsample = nn.Linear(in_features=2 * HIDDEN_SIZE, out_features=HIDDEN_SIZE)
5557

5658
def forward(self, x):
5759
_, (h, _) = self.encoder_lstm(x)
@@ -92,7 +94,7 @@ def __init__(self, device):
9294
self.mode_embedding = nn.Linear(in_features=NUMBER_OF_MODES, out_features=HIDDEN_SIZE)
9395
self.valence_embedding = nn.Linear(in_features=1, out_features=HIDDEN_SIZE)
9496
self.energy_embedding = nn.Linear(in_features=1, out_features=HIDDEN_SIZE)
95-
self.downsample = nn.Linear(in_features=5*HIDDEN_SIZE, out_features=HIDDEN_SIZE)
97+
self.downsample = nn.Linear(in_features=5 * HIDDEN_SIZE, out_features=HIDDEN_SIZE)
9698

9799
self.chords_lstm = nn.LSTMCell(input_size=HIDDEN_SIZE * 1, hidden_size=HIDDEN_SIZE * 1)
98100
self.chord_embeddings = nn.Embedding(num_embeddings=CHORD_PREDICTION_LENGTH, embedding_dim=HIDDEN_SIZE)
@@ -101,7 +103,7 @@ def __init__(self, device):
101103
nn.ReLU(),
102104
nn.Linear(in_features=HIDDEN_SIZE, out_features=CHORD_PREDICTION_LENGTH)
103105
)
104-
self.chord_embedding_downsample = nn.Linear(in_features=2*HIDDEN_SIZE, out_features=HIDDEN_SIZE)
106+
self.chord_embedding_downsample = nn.Linear(in_features=2 * HIDDEN_SIZE, out_features=HIDDEN_SIZE)
105107

106108
self.melody_embeddings = nn.Embedding(num_embeddings=MELODY_PREDICTION_LENGTH, embedding_dim=HIDDEN_SIZE)
107109
self.melody_lstm = nn.LSTMCell(input_size=HIDDEN_SIZE * 1, hidden_size=HIDDEN_SIZE * 1)
@@ -110,7 +112,7 @@ def __init__(self, device):
110112
nn.ReLU(),
111113
nn.Linear(in_features=HIDDEN_SIZE, out_features=MELODY_PREDICTION_LENGTH)
112114
)
113-
self.melody_embedding_downsample = nn.Linear(in_features=3*HIDDEN_SIZE, out_features=HIDDEN_SIZE)
115+
self.melody_embedding_downsample = nn.Linear(in_features=3 * HIDDEN_SIZE, out_features=HIDDEN_SIZE)
114116

115117
def forward(self, z, num_chords, sampling_rate_chords=0, sampling_rate_melodies=0, gt_chords=None, gt_melody=None):
116118
tempo_output = self.tempo_linear(z)
@@ -125,7 +127,7 @@ def forward(self, z, num_chords, sampling_rate_chords=0, sampling_rate_melodies=
125127
z = self.downsample(torch.cat((z, tempo_embedding, mode_embedding, valence_embedding, energy_embedding), dim=1))
126128

127129
batch_size = z.shape[0]
128-
# initialize hidden states and cell states randomly
130+
# initialize hidden states and cell states
129131
hx_chords = torch.zeros(batch_size, HIDDEN_SIZE, device=self.device)
130132
cx_chords = torch.zeros(batch_size, HIDDEN_SIZE, device=self.device)
131133
hx_melody = torch.zeros(batch_size, HIDDEN_SIZE, device=self.device)

model/lyrics2lofi_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@
1313
dataset = Lyrics2LofiDataset(dataset_folder, dataset_files, embeddings_file, embedding_lengths_file)
1414
model = Lyrics2LofiModel()
1515

16-
train(dataset, model, "lyrics2lofi")
16+
train(dataset, model, "lyrics2lofi")

model/train.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ def compute_loss(data):
5353
model(input, max_num_chords, sampling_rate_chords, sampling_rate_melodies, chords_gt, notes_gt)
5454
else:
5555
pred_chords, pred_notes, pred_tempo, pred_key, pred_mode, pred_valence, pred_energy, kl = \
56-
model(chords_gt, notes_gt, tempo_gt, key_gt, mode_gt, valence_gt, energy_gt, num_chords, max_num_chords, sampling_rate_chords, sampling_rate_melodies)
56+
model(chords_gt, notes_gt, tempo_gt, key_gt, mode_gt, valence_gt, energy_gt, num_chords, max_num_chords,
57+
sampling_rate_chords, sampling_rate_melodies)
5758

59+
# compute a boolean mask to select entries up to a specific index
5860
def compute_mask(max_length, curr_length):
5961
arange = torch.arange(max_length, device=device).repeat((chords_gt.shape[0], 1)).permute(0, 1)
6062
lengths_stacked = curr_length.repeat((max_length, 1)).permute(1, 0)
@@ -70,7 +72,7 @@ def compute_mask(max_length, curr_length):
7072
mask_melody = compute_mask(max_num_notes, num_notes)
7173
loss_melody = torch.masked_select(loss_melody_notes, mask_melody).mean()
7274

73-
if (epoch < MELODY_EPOCH_DELAY):
75+
if epoch < MELODY_EPOCH_DELAY:
7476
loss_melody = 0
7577

7678
loss_kl = kl
@@ -107,9 +109,9 @@ def compute_mask(max_length, curr_length):
107109
# TRAINING
108110
model.train()
109111
for batch, data in enumerate(train_dataloader):
110-
loss, loss_chords, kl_loss, loss_melody,\
111-
loss_tempo, loss_key, loss_mode, loss_valence, loss_energy,\
112-
batch_tp_chords, batch_tp_melodies = compute_loss(data)
112+
loss, loss_chords, kl_loss, loss_melody, \
113+
loss_tempo, loss_key, loss_mode, loss_valence, loss_energy, \
114+
batch_tp_chords, batch_tp_melodies = compute_loss(data)
113115

114116
ep_train_losses_chords.append(loss_chords)
115117
ep_train_losses_melodies.append(loss_melody)
@@ -130,9 +132,9 @@ def compute_mask(max_length, curr_length):
130132
model.eval()
131133
for batch, data in enumerate(val_dataloader):
132134
with torch.no_grad():
133-
loss, loss_chords, kl_loss, loss_melody,\
134-
loss_tempo, loss_key, loss_mode, loss_valence, loss_energy,\
135-
batch_tp_chords, batch_tp_melodies = compute_loss(data)
135+
loss, loss_chords, kl_loss, loss_melody, \
136+
loss_tempo, loss_key, loss_mode, loss_valence, loss_energy, \
137+
batch_tp_chords, batch_tp_melodies = compute_loss(data)
136138

137139
ep_val_losses_chords.append(loss_chords)
138140
ep_val_losses_melodies.append(loss_melody)

0 commit comments

Comments
 (0)