Skip to content

Commit 0eaa9dc

Browse files
jmvalinTheGrumpySnail01
authored andcommitted
Adds --chunks-per-offset option to train_rdovae.py
1 parent 499a509 commit 0eaa9dc

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

dnn/torch/rdovae/rdovae/rdovae.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,7 @@ def __init__(self,
551551
cond_size2,
552552
state_dim=24,
553553
split_mode='split',
554+
chunks_per_offset=4,
554555
clip_weights=False,
555556
pvq_num_pulses=82,
556557
state_dropout_rate=0,
@@ -564,6 +565,7 @@ def __init__(self,
564565
self.cond_size = cond_size
565566
self.cond_size2 = cond_size2
566567
self.split_mode = split_mode
568+
self.chunks_per_offset = chunks_per_offset
567569
self.state_dim = state_dim
568570
self.pvq_num_pulses = pvq_num_pulses
569571
self.state_dropout_rate = state_dropout_rate
@@ -670,7 +672,7 @@ def forward(self, features, q_id):
670672
states_q = states_q * mask
671673

672674
# decoder
673-
chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)
675+
chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode, chunks_per_offset=self.chunks_per_offset)
674676

675677
outputs_hq = []
676678
outputs_sq = []

dnn/torch/rdovae/train_rdovae.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@
6060
training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
6161
training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4)
6262
training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100)
63-
training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256)
63+
training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by chunks_per_offset, default: 400', default=400)
64+
training_group.add_argument('--chunks-per-offset', type=int, help='chunks per offset', default=4)
6465
training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
6566
training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split')
6667
training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
@@ -120,7 +121,7 @@
120121

121122
# model
122123
checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
123-
checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant}
124+
checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant, 'chunks_per_offset': args.chunks_per_offset}
124125
model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
125126

126127
if type(args.initial_checkpoint) != type(None):

0 commit comments

Comments
 (0)