diff --git a/aria/datasets.py b/aria/datasets.py
index 210c9da..b1d3b51 100644
--- a/aria/datasets.py
+++ b/aria/datasets.py
@@ -1060,7 +1060,7 @@ def _get_onset_adjusted_msg(
                 return _temp_note_msg
 
-    _note_msgs = midi_dict.note_msgs
+    _note_msgs = copy.deepcopy(midi_dict.note_msgs)
 
     # Remove notes
     if random.random() < config["remove_notes"]["activation_prob"]:
@@ -1270,10 +1270,6 @@ def _build_epoch(_save_path, _midi_dataset):
             if _idx % 250 == 0:
                 logger.info(f"Finished processing {_idx}")
 
-            # DEBUG
-            if _idx == 1000:
-                break
-
     logger = setup_logger()
     assert max_seq_len > 0, "max_seq_len must be greater than 0"
    assert num_epochs > 0, "num_epochs must be greater than 0"
diff --git a/aria/run.py b/aria/run.py
index cfb4bc5..182d1f5 100644
--- a/aria/run.py
+++ b/aria/run.py
@@ -152,17 +152,20 @@ def sample(args):
         guidance_midi_dict=guidance_midi_dict,
     )
 
-    if guidance_seq:
-        tokenizer.detokenize(guidance_seq).to_midi().save(
-            os.path.join(samples_dir, f"guidance.mid")
-        )
     if len(prompt_seq) + args.l > model_config.max_seq_len:
         print(
             "WARNING: Required context exceeds max_seq_len supported by model"
         )
     prompts = [prompt_seq for _ in range(num_variations)]
 
-    if args.cfg is not None:
+    samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples")
+    if os.path.isdir(samples_dir) is False:
+        os.mkdir(samples_dir)
+    if guidance_seq:
+        tokenizer.detokenize(guidance_seq).to_midi().save(
+            os.path.join(samples_dir, f"guidance.mid")
+        )
+    if args.cfg is not None and guidance_seq is not None:
         results = sample_batch_cfg(
             model=model,
             tokenizer=tokenizer,
@@ -186,10 +189,6 @@
             compile=args.compile,
         )
 
-    samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples")
-    if os.path.isdir(samples_dir) is False:
-        os.mkdir(samples_dir)
-
     for idx, tokenized_seq in enumerate(results):
         res_midi_dict = tokenizer.detokenize(tokenized_seq)
         res_midi = res_midi_dict.to_midi()
diff --git a/aria/sample.py b/aria/sample.py
index 4546283..01f9abe 100644
--- a/aria/sample.py
+++ b/aria/sample.py
@@ -303,6 +303,7 @@ def sample_batch_cfg(
         logits_cfg[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float(
             "-inf"
         )
+        logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf")
 
         if temperature > 0.0:
             probs = torch.softmax(logits_cfg / temperature, dim=-1)
@@ -389,7 +390,10 @@ def get_inference_prompt(
     if tokenizer.dim_tok in prompt_seq:
         prompt_seq.remove(tokenizer.dim_tok)
 
-    if guidance_midi_dict is not None:
+    if (
+        guidance_midi_dict is not None
+        and tokenizer.guidance_start_tok in prompt_seq
+    ):
         guidance_seq = copy.deepcopy(prompt_seq)
         guidance_seq = guidance_seq[
             : guidance_seq.index(tokenizer.guidance_end_tok)
diff --git a/aria/tokenizer.py b/aria/tokenizer.py
index 818fb27..c142405 100644
--- a/aria/tokenizer.py
+++ b/aria/tokenizer.py
@@ -174,8 +174,8 @@ def detokenize(self, tokenized_seq: list, **kwargs):
 
     def export_data_aug(self):
         return [
-            self.export_guidance_tempo_aug(max_tempo_aug=0.25, mixup=True),
-            self.export_guidance_pitch_aug(4),
+            self.export_guidance_tempo_aug(max_tempo_aug=0.2, mixup=True),
+            self.export_guidance_pitch_aug(3),
             self.export_guidance_velocity_aug(2),
         ]
diff --git a/aria/train.py b/aria/train.py
index 6a2c74b..5733e49 100644
--- a/aria/train.py
+++ b/aria/train.py
@@ -61,9 +61,6 @@
 # -bs 32 \
 # -workers 8
 
-# TODO:
-# - Test that everything works on a distributed setup
-
 
 def setup_logger(project_dir: str):
     # Get logger and reset all handlers
@@ -196,7 +193,7 @@ def get_optim(
     num_epochs: int,
     steps_per_epoch: int,
 ):
-    LR = 3e-5
+    LR = 3e-4
     END_RATIO = 0.1
     WARMUP_STEPS = 200
diff --git a/config/config.json b/config/config.json
index 7e0a736..0d32671 100644
--- a/config/config.json
+++ b/config/config.json
@@ -182,31 +182,29 @@
             "min_clean_interval_ms": 60000,
             "max_clean_interval_ms": 200000,
             "noising": {
-                "activation_prob": 0.8,
+                "activation_prob": 0.5,
                 "remove_notes": {
-                    "activation_prob": 0.5,
+                    "activation_prob": 0.25,
                     "min_ratio": 0.0,
-                    "max_ratio": 0.3
+                    "max_ratio": 0.15
                 },
                 "adjust_velocity": {
-                    "activation_prob": 0.3,
+                    "activation_prob": 0.25,
                     "min_adjust": 1,
-                    "max_adjust": 30,
-                    "max_ratio": 0.1,
-                    "min_ratio": 0.30
+                    "max_adjust": 20
                 },
                 "adjust_onsets": {
                     "activation_prob": 0.25,
-                    "min_adjust_s": 0.01,
-                    "max_adjust_s": 0.07,
-                    "max_ratio": 0.15,
-                    "min_ratio": 0.3
+                    "min_adjust_s": 0.005,
+                    "max_adjust_s": 0.05,
+                    "max_ratio": 0.0,
+                    "min_ratio": 0.2
                 },
                 "quantize_onsets": {
-                    "activation_prob": 0.15,
+                    "activation_prob": 0.05,
                     "min_quant_s": 0.05,
-                    "max_quant_s": 0.15,
-                    "max_vel_delta": 45
+                    "max_quant_s": 0.1,
+                    "max_vel_delta": 30
                 }
            }
         }
@@ -215,7 +213,7 @@
         "inference_abs": {
             "guidance": {
                 "min_ms": 5000,
-                "max_ms": 30000
+                "max_ms": 40000
             }
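
A note on the learning-rate change in aria/train.py: the peak LR in get_optim moves from 3e-5 to 3e-4 while END_RATIO = 0.1 and WARMUP_STEPS = 200 are unchanged, so (assuming END_RATIO is the final-to-peak ratio) the schedule now ends near 3e-5, i.e. the old peak. Below is a minimal sketch of the shape these constants describe, assuming a linear warmup into a linear decay; the scheduler actually constructed inside get_optim is outside this diff and may differ:

    # Sketch only: the optimizer/scheduler classes here are assumptions,
    # not necessarily what aria/train.py uses.
    import torch

    LR = 3e-4
    END_RATIO = 0.1
    WARMUP_STEPS = 200
    TOTAL_STEPS = 10_000  # placeholder for num_epochs * steps_per_epoch

    model = torch.nn.Linear(8, 8)  # stand-in model
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    # Ramp from ~0 up to LR over the first WARMUP_STEPS steps...
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS
    )
    # ...then decay linearly from LR down to END_RATIO * LR over the rest.
    decay = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=END_RATIO,
        total_iters=TOTAL_STEPS - WARMUP_STEPS,
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, decay], milestones=[WARMUP_STEPS]
    )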