6 changes: 1 addition & 5 deletions aria/datasets.py
@@ -1060,7 +1060,7 @@ def _get_onset_adjusted_msg(

return _temp_note_msg

- _note_msgs = midi_dict.note_msgs
+ _note_msgs = copy.deepcopy(midi_dict.note_msgs)

# Remove notes
if random.random() < config["remove_notes"]["activation_prob"]:
@@ -1270,10 +1270,6 @@ def _build_epoch(_save_path, _midi_dataset):
if _idx % 250 == 0:
logger.info(f"Finished processing {_idx}")

- # DEBUG
- if _idx == 1000:
-     break
-
logger = setup_logger()
assert max_seq_len > 0, "max_seq_len must be greater than 0"
assert num_epochs > 0, "num_epochs must be greater than 0"
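Note on the aria/datasets.py change: the noising routines mutate note messages, so copying midi_dict.note_msgs before corruption keeps the underlying dataset intact for later epochs. A minimal sketch of the idea, not taken from aria itself:

import copy
import random

def noise_notes(note_msgs: list[dict], remove_prob: float = 0.1) -> list[dict]:
    # Work on an independent copy so the caller's list and dicts stay untouched.
    msgs = copy.deepcopy(note_msgs)
    return [m for m in msgs if random.random() > remove_prob]

original = [{"pitch": 60, "vel": 80}, {"pitch": 64, "vel": 90}]
noised = noise_notes(original)
assert len(original) == 2  # the source data is preserved for the next epoch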
17 changes: 8 additions & 9 deletions aria/run.py
@@ -152,17 +152,20 @@ def sample(args):
guidance_midi_dict=guidance_midi_dict,
)

- if guidance_seq:
-     tokenizer.detokenize(guidance_seq).to_midi().save(
-         os.path.join(samples_dir, f"guidance.mid")
-     )
if len(prompt_seq) + args.l > model_config.max_seq_len:
print(
"WARNING: Required context exceeds max_seq_len supported by model"
)
prompts = [prompt_seq for _ in range(num_variations)]

- if args.cfg is not None:
+ samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples")
+ if os.path.isdir(samples_dir) is False:
+     os.mkdir(samples_dir)
+ if guidance_seq:
+     tokenizer.detokenize(guidance_seq).to_midi().save(
+         os.path.join(samples_dir, f"guidance.mid")
+     )
+ if args.cfg is not None and guidance_seq is not None:
results = sample_batch_cfg(
model=model,
tokenizer=tokenizer,
@@ -186,10 +189,6 @@ def sample(args):
compile=args.compile,
)

- samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples")
- if os.path.isdir(samples_dir) is False:
-     os.mkdir(samples_dir)
-
for idx, tokenized_seq in enumerate(results):
res_midi_dict = tokenizer.detokenize(tokenized_seq)
res_midi = res_midi_dict.to_midi()
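Note on the aria/run.py change: the samples directory is now created and the guidance MIDI exported before sampling, and the classifier-free-guidance path is only taken when a guidance sequence actually exists. A rough sketch of the reordered flow; sample_fn and cfg_sample_fn are hypothetical placeholders, not aria functions:

import os

def run_sampling(guidance_seq, cfg_scale, sample_fn, cfg_sample_fn, samples_dir="samples"):
    # Create the output directory once, up front (replaces the isdir()/mkdir() pair).
    os.makedirs(samples_dir, exist_ok=True)
    if guidance_seq:
        # In aria this is tokenizer.detokenize(guidance_seq).to_midi().save(...).
        print("guidance would be saved to", os.path.join(samples_dir, "guidance.mid"))
    # CFG sampling is only meaningful when a guidance sequence is present.
    if cfg_scale is not None and guidance_seq is not None:
        return cfg_sample_fn(guidance_seq, cfg_scale)
    return sample_fn()

results = run_sampling([1, 2, 3], 1.4, lambda: ["plain"], lambda g, c: ["cfg"])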
6 changes: 5 additions & 1 deletion aria/sample.py
@@ -303,6 +303,7 @@ def sample_batch_cfg(
logits_cfg[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float(
"-inf"
)
+ logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf")

if temperature > 0.0:
probs = torch.softmax(logits_cfg / temperature, dim=-1)
@@ -389,7 +390,10 @@ def get_inference_prompt(
if tokenizer.dim_tok in prompt_seq:
prompt_seq.remove(tokenizer.dim_tok)

- if guidance_midi_dict is not None:
+ if (
+     guidance_midi_dict is not None
+     and tokenizer.guidance_start_tok in prompt_seq
+ ):
guidance_seq = copy.deepcopy(prompt_seq)
guidance_seq = guidance_seq[
: guidance_seq.index(tokenizer.guidance_end_tok)
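Note on the aria/sample.py change: banning a token during CFG decoding just means setting its logit to -inf before the softmax, which forces its probability to zero. A small self-contained example of the pattern (the token ids are made up):

import torch

logits = torch.randn(2, 8)                # (batch, vocab) toy logits
banned_ids = [3, 5]                       # e.g. ids of prompt_start_tok and dim_tok
logits[:, banned_ids] = float("-inf")
probs = torch.softmax(logits / 0.95, dim=-1)       # banned ids now have probability 0
next_ids = torch.multinomial(probs, num_samples=1)
assert not any(i in banned_ids for i in next_ids.flatten().tolist())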
4 changes: 2 additions & 2 deletions aria/tokenizer.py
@@ -174,8 +174,8 @@ def detokenize(self, tokenized_seq: list, **kwargs):

def export_data_aug(self):
return [
- self.export_guidance_tempo_aug(max_tempo_aug=0.25, mixup=True),
- self.export_guidance_pitch_aug(4),
+ self.export_guidance_tempo_aug(max_tempo_aug=0.2, mixup=True),
+ self.export_guidance_pitch_aug(3),
self.export_guidance_velocity_aug(2),
]

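Note on the aria/tokenizer.py change: the guidance augmentations now use tighter ranges, i.e. pitch shifts of at most ±3 semitones and tempo scaling of at most ±20%. The sketch below only illustrates what augmentations of that shape might look like; it is not the aria implementation:

import random

def pitch_aug(pitches: list[int], max_shift: int = 3) -> list[int]:
    shift = random.randint(-max_shift, max_shift)
    return [min(127, max(0, p + shift)) for p in pitches]  # clamp to the MIDI range

def tempo_aug(onsets_ms: list[int], max_tempo_aug: float = 0.2) -> list[int]:
    scale = 1.0 + random.uniform(-max_tempo_aug, max_tempo_aug)
    return [round(t * scale) for t in onsets_ms]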
5 changes: 1 addition & 4 deletions aria/train.py
@@ -61,9 +61,6 @@
# -bs 32 \
# -workers 8

- # TODO:
- # - Test that everything works on a distributed setup
-

def setup_logger(project_dir: str):
# Get logger and reset all handlers
@@ -196,7 +193,7 @@ def get_optim(
num_epochs: int,
steps_per_epoch: int,
):
- LR = 3e-5
+ LR = 3e-4
END_RATIO = 0.1
WARMUP_STEPS = 200

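Note on the aria/train.py change: the peak learning rate moves from 3e-5 to 3e-4, with 200 warmup steps and a decay down to 10% of the peak (END_RATIO = 0.1). One way to build such a schedule in PyTorch is sketched below; the actual scheduler in aria/train.py may differ:

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1e-8, end_factor=1.0, total_iters=200
)
decay = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.1, total_iters=10_000
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, decay], milestones=[200]
)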
28 changes: 13 additions & 15 deletions config/config.json
@@ -182,31 +182,29 @@
"min_clean_interval_ms": 60000,
"max_clean_interval_ms": 200000,
"noising": {
"activation_prob": 0.8,
"activation_prob": 0.5,
"remove_notes": {
"activation_prob": 0.5,
"activation_prob": 0.25,
"min_ratio": 0.0,
"max_ratio": 0.3
"max_ratio": 0.15
},
"adjust_velocity": {
"activation_prob": 0.3,
"activation_prob": 0.25,
"min_adjust": 1,
"max_adjust": 30,
"max_ratio": 0.1,
"min_ratio": 0.30
"max_adjust": 20
},
"adjust_onsets": {
"activation_prob": 0.25,
"min_adjust_s": 0.01,
"max_adjust_s": 0.07,
"max_ratio": 0.15,
"min_ratio": 0.3
"min_adjust_s": 0.005,
"max_adjust_s": 0.05,
"max_ratio": 0.0,
"min_ratio": 0.2
},
"quantize_onsets": {
"activation_prob": 0.15,
"activation_prob": 0.05,
"min_quant_s": 0.05,
"max_quant_s": 0.15,
"max_vel_delta": 45
"max_quant_s": 0.1,
"max_vel_delta": 30
}
}
}
@@ -215,7 +213,7 @@
"inference_abs": {
"guidance": {
"min_ms": 5000,
"max_ms": 30000
"max_ms": 40000
}


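Note on the config/config.json change: each noising step is gated by its own activation_prob (the datasets.py diff shows the same pattern for remove_notes), and the new values make corruption both rarer and milder. A self-contained sketch of such a gate; the helper is a hypothetical stand-in for aria's corruption routines:

import random

noising_cfg = {
    "remove_notes": {"activation_prob": 0.25, "min_ratio": 0.0, "max_ratio": 0.15},
}

def maybe_remove_notes(notes: list, cfg: dict) -> list:
    if random.random() >= cfg["activation_prob"]:
        return notes  # the step is skipped most of the time
    ratio = random.uniform(cfg["min_ratio"], cfg["max_ratio"])
    keep = max(1, round(len(notes) * (1.0 - ratio)))
    return random.sample(notes, keep)

noised = maybe_remove_notes(list(range(100)), noising_cfg["remove_notes"])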