1
+ from hashlib import md5
2
+
1
3
import numpy as np
2
4
import torch
3
5
from torch import nn
4
- from torch .nn .utils .rnn import pack_padded_sequence , pad_packed_sequence
5
- from hashlib import md5
6
+ from torch .nn .utils .rnn import pack_padded_sequence
6
7
7
8
from model .constants import *
8
9
@@ -16,7 +17,8 @@ def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
16
17
self .mean_linear = nn .Linear (in_features = HIDDEN_SIZE , out_features = HIDDEN_SIZE )
17
18
self .variance_linear = nn .Linear (in_features = HIDDEN_SIZE , out_features = HIDDEN_SIZE )
18
19
19
- def forward (self , gt_chords , gt_melodies , gt_tempo , gt_key , gt_mode , gt_valence , gt_energy , batch_num_chords , num_chords , sampling_rate_chords = 0 , sampling_rate_melodies = 0 ):
20
+ def forward (self , gt_chords , gt_melodies , gt_tempo , gt_key , gt_mode , gt_valence , gt_energy , batch_num_chords ,
21
+ num_chords , sampling_rate_chords = 0 , sampling_rate_melodies = 0 ):
20
22
# encode
21
23
h = self .encoder (gt_chords , gt_melodies , gt_tempo , gt_key , gt_mode , gt_valence , gt_energy , batch_num_chords )
22
24
# VAE
@@ -52,30 +54,29 @@ def __init__(self, device):
52
54
super (Encoder , self ).__init__ ()
53
55
self .device = device
54
56
self .chord_embeddings = nn .Embedding (num_embeddings = CHORD_PREDICTION_LENGTH , embedding_dim = HIDDEN_SIZE )
55
- self .chords_lstm = nn .LSTM (input_size = HIDDEN_SIZE , hidden_size = HIDDEN_SIZE , num_layers = NUM_LAYERS , bidirectional = True , batch_first = True )
57
+ self .chords_lstm = nn .LSTM (input_size = HIDDEN_SIZE , hidden_size = HIDDEN_SIZE , num_layers = NUM_LAYERS ,
58
+ bidirectional = True , batch_first = True )
56
59
57
60
self .melody_embeddings = nn .Embedding (num_embeddings = MELODY_PREDICTION_LENGTH , embedding_dim = HIDDEN_SIZE )
58
- self .melody_lstm = nn .LSTM (input_size = HIDDEN_SIZE , hidden_size = HIDDEN_SIZE , num_layers = NUM_LAYERS , bidirectional = True , batch_first = True )
61
+ self .melody_lstm = nn .LSTM (input_size = HIDDEN_SIZE , hidden_size = HIDDEN_SIZE , num_layers = NUM_LAYERS ,
62
+ bidirectional = True , batch_first = True )
59
63
60
64
self .tempo_embedding = nn .Linear (in_features = 1 , out_features = HIDDEN_SIZE2 )
61
65
self .key_embedding = nn .Embedding (num_embeddings = NUMBER_OF_KEYS , embedding_dim = HIDDEN_SIZE2 )
62
66
self .mode_embedding = nn .Embedding (num_embeddings = NUMBER_OF_MODES , embedding_dim = HIDDEN_SIZE2 )
63
67
self .valence_embedding = nn .Linear (in_features = 1 , out_features = HIDDEN_SIZE2 )
64
68
self .energy_embedding = nn .Linear (in_features = 1 , out_features = HIDDEN_SIZE2 )
65
69
66
- self .downsample = nn .Linear (in_features = 4 * HIDDEN_SIZE + 5 * HIDDEN_SIZE2 , out_features = HIDDEN_SIZE )
67
-
70
+ self .downsample = nn .Linear (in_features = 4 * HIDDEN_SIZE + 5 * HIDDEN_SIZE2 , out_features = HIDDEN_SIZE )
68
71
69
72
def forward (self , chords , melodies , tempo , key , mode , valence , energy , batch_num_chords ):
70
73
chord_embeddings = self .chord_embeddings (chords )
71
74
chords_input = pack_padded_sequence (chord_embeddings , batch_num_chords , batch_first = True , enforce_sorted = False )
72
75
chords_out , (h_chords , _ ) = self .chords_lstm (chords_input )
73
- # chords_out_repeated = pad_packed_sequence(chords_out, batch_first=True)[0].repeat_interleave( NOTES_PER_CHORD, 1)
74
- # chords_out_repeated = chords_out_repeated[:,:,:HIDDEN_SIZE] + chords_out_repeated[:,:,HIDDEN_SIZE:]
75
76
76
- # add two directions together
77
- melody_embeddings = self . melody_embeddings ( melodies ) # + chords_out_repeated
78
- melody_input = pack_padded_sequence ( melody_embeddings , batch_num_chords * NOTES_PER_CHORD , batch_first = True , enforce_sorted = False )
77
+ melody_embeddings = self . melody_embeddings ( melodies )
78
+ melody_input = pack_padded_sequence ( melody_embeddings , batch_num_chords * NOTES_PER_CHORD , batch_first = True ,
79
+ enforce_sorted = False )
79
80
_ , (h_melodies , _ ) = self .melody_lstm (melody_input )
80
81
81
82
tempo_embedding = self .tempo_embedding (tempo .unsqueeze (1 ).float ())
@@ -85,7 +86,9 @@ def forward(self, chords, melodies, tempo, key, mode, valence, energy, batch_num
85
86
energy_embedding = self .energy_embedding (energy .unsqueeze (1 ).float ())
86
87
87
88
h_concatenated = torch .cat ((h_chords [- 1 ], h_chords [- 2 ], h_melodies [- 1 ], h_melodies [- 2 ]), dim = 1 )
88
- return self .downsample (torch .cat ((h_concatenated , tempo_embedding , key_embedding , mode_embedding , valence_embedding , energy_embedding ), dim = 1 ))
89
+ return self .downsample (torch .cat (
90
+ (h_concatenated , tempo_embedding , key_embedding , mode_embedding , valence_embedding , energy_embedding ),
91
+ dim = 1 ))
89
92
90
93
91
94
class Decoder (nn .Module ):
@@ -100,7 +103,7 @@ def __init__(self, device):
100
103
nn .ReLU (),
101
104
nn .Linear (in_features = HIDDEN_SIZE , out_features = CHORD_PREDICTION_LENGTH )
102
105
)
103
- self .chord_embedding_downsample = nn .Linear (in_features = 2 * HIDDEN_SIZE , out_features = HIDDEN_SIZE )
106
+ self .chord_embedding_downsample = nn .Linear (in_features = 2 * HIDDEN_SIZE , out_features = HIDDEN_SIZE )
104
107
105
108
self .melody_embeddings = nn .Embedding (num_embeddings = MELODY_PREDICTION_LENGTH , embedding_dim = HIDDEN_SIZE )
106
109
self .melody_lstm = nn .LSTMCell (input_size = HIDDEN_SIZE * 1 , hidden_size = HIDDEN_SIZE * 1 )
@@ -109,7 +112,7 @@ def __init__(self, device):
109
112
nn .ReLU (),
110
113
nn .Linear (in_features = HIDDEN_SIZE , out_features = MELODY_PREDICTION_LENGTH )
111
114
)
112
- self .melody_embedding_downsample = nn .Linear (in_features = 3 * HIDDEN_SIZE , out_features = HIDDEN_SIZE )
115
+ self .melody_embedding_downsample = nn .Linear (in_features = 3 * HIDDEN_SIZE , out_features = HIDDEN_SIZE )
113
116
114
117
self .key_linear = nn .Sequential (
115
118
nn .Linear (in_features = HIDDEN_SIZE , out_features = HIDDEN_SIZE2 ),
@@ -137,30 +140,27 @@ def __init__(self, device):
137
140
nn .Linear (in_features = HIDDEN_SIZE2 , out_features = 1 ),
138
141
)
139
142
140
- def generate (self ):
141
- mu = torch .randn (1 , HIDDEN_SIZE )
142
- return self (mu )
143
-
144
143
def decode(self, mu):
    """Decode a latent vector and derive a short identifier for it.

    Args:
        mu: latent tensor. Only its first row is sampled for the leading
            digits of the identifier; the whole tensor is hashed and is
            passed on to the decoder. Presumably shaped (1, HIDDEN_SIZE);
            confirm against callers.

    Returns:
        A tuple ``(identifier, outputs)``. ``identifier`` is a string of at
        most 25 characters that starts with ``#``; ``outputs`` is whatever
        ``self(mu, MAX_CHORD_LENGTH)`` returns.
    """
    first_row = mu[0]
    # 20 digits: for every 5th index in 0..95, take the last decimal digit of
    # |value| * 587 truncated to an integer.
    # NOTE(review): the slice i:i + 1 reads a single entry per step, not a
    # window of 5; kept as-is so previously issued identifiers do not change.
    digits = "".join(
        str((first_row[i:i + 1].abs().sum() * 587).int().item())[-1]
        for i in range(0, 100, 5)
    )
    # The tail of the identifier is the leading decimal digits of the MD5
    # digest of the whole vector, read as a base-16 integer.
    # md5 needs a host-side buffer: detach and move to CPU so that tensors on
    # CUDA or tensors that require grad do not make .numpy() raise.
    digest = int(md5(mu.detach().cpu().numpy()).hexdigest(), 16)
    # "#" + 20 digits + digest digits, truncated to 25 characters in total.
    identifier = f"#{digits}{digest}"[:25]
    return identifier, self(mu, MAX_CHORD_LENGTH)
154
153
155
- def forward (self , z , num_chords = MAX_CHORD_LENGTH , sampling_rate_chords = 0 , sampling_rate_melodies = 0 , gt_chords = None , gt_melody = None ):
154
+ def forward (self , z , num_chords = MAX_CHORD_LENGTH , sampling_rate_chords = 0 , sampling_rate_melodies = 0 , gt_chords = None ,
155
+ gt_melody = None ):
156
156
tempo_output = self .tempo_linear (z )
157
157
key_output = self .key_linear (z )
158
158
mode_output = self .mode_linear (z )
159
159
valence_output = self .valence_linear (z )
160
160
energy_output = self .energy_linear (z )
161
161
162
162
batch_size = z .shape [0 ]
163
- # initialize hidden states and cell states randomly
163
+ # initialize hidden states and cell states
164
164
hx_chords = torch .zeros (batch_size , HIDDEN_SIZE , device = self .device )
165
165
cx_chords = torch .zeros (batch_size , HIDDEN_SIZE , device = self .device )
166
166
hx_melody = torch .zeros (batch_size , HIDDEN_SIZE , device = self .device )
@@ -205,7 +205,8 @@ def forward(self, z, num_chords=MAX_CHORD_LENGTH, sampling_rate_chords=0, sampli
205
205
melody_embeddings = self .melody_embeddings (gt_melody [:, i * NOTES_PER_CHORD + j ])
206
206
else :
207
207
melody_embeddings = self .melody_embeddings (melody_prediction .argmax (dim = 1 ))
208
- melody_embeddings = self .melody_embedding_downsample (torch .cat ((melody_embeddings , chord_embeddings , z ), dim = 1 ))
208
+ melody_embeddings = self .melody_embedding_downsample (
209
+ torch .cat ((melody_embeddings , chord_embeddings , z ), dim = 1 ))
209
210
210
211
chord_outputs = torch .stack (chord_outputs , dim = 1 )
211
212
melody_outputs = torch .stack (melody_outputs , dim = 1 )
0 commit comments