diff --git a/.gitignore b/.gitignore index dbc64648..946afa19 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,6 @@ fluidsynth/ tests/test_results lightning_logs/ .vscode/ +paper +hf +_scripts diff --git a/Makefile b/Makefile deleted file mode 100644 index 774c3fc5..00000000 --- a/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -.PHONY: test -test: - python -m unittest tests/test_*.py - - -.PHONY: format -format: - black --line-length 80 ./aria - black --line-length 80 ./tests diff --git a/README.md b/README.md index 7b11af7e..7ed175ab 100644 --- a/README.md +++ b/README.md @@ -1,64 +1,131 @@ -# gpt-aria +# Aria -[Discord](https://discord.com/invite/zBGx3azzUn) +This repository contains training, inference, and evaluation code for the paper [*Scaling Self-Supervised Representation Learning for Symbolic Piano Performance (ISMIR 2025)*](https://example.com/), as well as implementations of our real-time piano continuation demo. *Aria* is a pretrained autoregressive generative model for symbolic music, based on the LLaMA 3.2 (1B) architecture, which was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. Alongside the base model, we are releasing a checkpoint finetuned to improve generative quality, as well as a checkpoint finetuned to produce general-purpose piano MIDI embeddings using a SimCSE-style contrastive training objective. -A repository containing resources for pre-training, fine-tuning, and evaluating musical (MIDI) transformer models. +📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/) +🤗 Access our models via the [HuggingFace page](https://huggingface.co/loubb/aria-medium-base) +📊 Get access to our training dataset [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) and train your own models -***Note that this project is under active development*** +## Installation -## Description +Installation requires Python 3.11+. To install the package and all dependencies with pip: -The main goal of the gpt-aria project is to create a suite of powerful pre-trained generative (symbolic) music models. We want to investigate how modern training (pre-training & fine-tuning) techniques can be used to improve the quality/usefulness of such models. Alongside this we are building various data (MIDI) preprocessing tools, allowing **you** to easily fine-tune our models on your own data. 
+```bash
+git clone https://github.com/EleutherAI/aria
+cd aria
+pip install -e ".[all]"
+```
-If you are new to symbolic music models, a good place to start are the following projects/blogposts by Google Magenta and OpenAI:
+## Quickstart
-- [Music Transformer](https://magenta.tensorflow.org/music-transformer)
-- [MuseNet](https://openai.com/research/musenet)
+Download model weights from the official HuggingFace page for our pretrained model, as well as checkpoints finetuned for piano continuation and for generating MIDI embeddings:
- Long story short: Transformer + MIDI + GPUs = 🎵 x ∞
+- `aria-medium-base` ([huggingface](https://huggingface.co/loubb/aria-medium-base), [direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model.safetensors?download=true))
+- `aria-medium-gen` ([huggingface](https://huggingface.co/loubb/aria-medium-gen), [direct-download](https://huggingface.co/loubb/aria-medium-gen/resolve/main/model.safetensors?download=true))
+- `aria-medium-embedding` ([huggingface](https://huggingface.co/loubb/aria-medium-embedding), [direct-download](https://huggingface.co/loubb/aria-medium-embedding/resolve/main/model.safetensors?download=true))
-## Installation
+### Inference (Prompt Continuation)
-Make sure you are using Python 3.10+. Note that I haven't explicitly developed this project for anything other than Linux. If you are using Windows, things might not work properly. In this case I suggest installing using WSL.
+We provide optimized model implementations for PyTorch (CUDA) and MLX (Apple Silicon). You can generate continuations of a MIDI file using the CLI, e.g., using CUDA (Linux):
-```
-git clone https://github.com/eleutherai/aria
-cd aria
-pip install -e .
+```bash
+aria generate \
+    --backend torch_cuda \
+    --checkpoint_path <checkpoint-path> \
+    --prompt_midi_path <prompt-midi-path> \
+    --prompt_duration <prompt-duration-seconds> \
+    --variations <num-variations> \
+    --temp 0.98 \
+    --min_p 0.035 \
+    --length 2048 \
+    --save_dir <save-dir>
+```
-## Inference
+Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get prompt MIDI files, see the `example-prompts/` directory, explore the [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) dataset, or transcribe your own files using our [piano-transcription model](https://github.com/EleutherAI/aria-amt). For a full list of sampling options, run `aria generate -h`. If you wish to run inference on the CPU, please see the platform-agnostic implementation on our HuggingFace page [link].
-You can find preliminary checkpoints at the following locations
+### Intended Use and Limitations
-Finetuned piano-only checkpoints (improved robustness):
+Aria performs best when **continuing existing piano MIDI files** rather than generating music from scratch. While multi-track tokenization and generation are supported, the model was trained primarily on **single-track expressive piano performances**, and we recommend using single-track inputs for optimal results.
-```
-large - https://storage.googleapis.com/aria-checkpoints/large-abs-inst.safetensors
-```
+Due to the high representation of popular classical works (e.g., Chopin) in the training data and the difficulty of complete deduplication, the model may **memorize or closely reproduce** such pieces. For more original outputs, we suggest prompting Aria with **lesser-known works or your own compositions**.
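To act on the intended-use guidance above, it can help to trim a prompt MIDI file down to its strongest opening material before sampling. The sketch below is illustrative only (it is not part of this PR or the `aria` CLI): it assumes the `ariautils.midi.MidiDict` helpers used elsewhere in this diff (`from_midi`, `note_msgs`, `tick_to_ms`, `to_midi`), and the `trim_prompt` helper name is ours.

```python
from ariautils.midi import MidiDict


def trim_prompt(midi_path: str, out_path: str, prompt_duration_s: float = 20.0) -> None:
    """Hypothetical helper: keep only the notes that start within the first
    `prompt_duration_s` seconds of a MIDI file, for use as a generation prompt."""
    midi_dict = MidiDict.from_midi(mid_path=midi_path)

    # The model is trained on solo piano, so refuse anything else.
    instruments = {
        midi_dict.program_to_instrument[msg["data"]]
        for msg in midi_dict.instrument_msgs
    }
    assert instruments == {"piano"}, f"Expected piano-only MIDI, got: {instruments}"

    # Keep note messages whose onset falls inside the prompt window.
    cutoff_ms = prompt_duration_s * 1000
    midi_dict.note_msgs = [
        msg
        for msg in midi_dict.note_msgs
        if midi_dict.tick_to_ms(msg["data"]["start"]) < cutoff_ms
    ]

    # Write the trimmed prompt back out as a standard MIDI file.
    midi_dict.to_midi().save(out_path)


if __name__ == "__main__":
    trim_prompt("example-prompts/pokey_jazz.mid", "prompt_20s.mid")
```

The resulting file can then be passed to `aria generate` via `--prompt_midi_path`.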
-Pretrained checkpoints:
+### Inference (MIDI Embeddings)
+You can generate embeddings from MIDI files using the `aria.embedding` module. This is primarily exposed via the `get_global_embedding_from_midi` function, for example:
+
+```python
+from safetensors.torch import load_file
+
+from aria.embedding import get_global_embedding_from_midi
+from aria.model import TransformerEMB, ModelConfig
+from aria.config import load_model_config
+from ariautils.tokenizer import AbsTokenizer
+
+# Load model
+model_config = ModelConfig(**load_model_config(name="medium-emb"))
+model_config.set_vocab_size(AbsTokenizer().vocab_size)
+model = TransformerEMB(model_config)
+state_dict = load_file(filename=CHECKPOINT_PATH)
+model.load_state_dict(state_dict=state_dict, strict=True)
+
+# Generate embedding
+embedding = get_global_embedding_from_midi(
+    model=model,
+    midi_path=MIDI_PATH,
+    device="cpu",
+)
+```
-large - https://storage.googleapis.com/aria-checkpoints/large-abs-pt.bin
-medium - https://storage.googleapis.com/aria-checkpoints/medium-abs-pt.bin
-small - https://storage.googleapis.com/aria-checkpoints/small-abs-pt.bin
-```
-You can then sample using the cli:
+Our embedding model was trained to capture composition-level and performance-level attributes, and therefore might not be appropriate for every use case.
+
+## Real-time demo
+
+In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface.
+
+❗**NOTE**: Responsiveness of the real-time demo depends on your system configuration, e.g., GPU FLOPS and memory bandwidth.
+A MIDI input device is not strictly required to play around with the demo: by using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI, ALSA) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output.
+
+Example usage (MLX):
+
+```bash
+MIDI_PATH="example-prompts/pokey_jazz.mid"
+
+python demo/demo_mlx.py \
+    --checkpoint <checkpoint-path> \
+    --midi_path ${MIDI_PATH} \
+    --midi_through <midi-through-port> \
+    --midi_out <midi-out-port> \
+    --save_path <save-path> \
+    --temp 0.98 \
+    --min_p 0.035
+```
-aria sample \
- -m large \
- -c \
- -p \
- -var \
- -trunc \
- -l \
- -temp 0.95 \
- -e
+
+## Evaluation
+
+We provide the specific files/splits we used for the Aria-MIDI-derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/eval-splits.tar.gz?download=true)). Class labels are provided in `metadata.json` with the schema:
+
+```json
+{
+  "<category>": {
+    "<split>": {
+      "<filename>": "<class-label>",
+      ...
+    },
+    ...
+  },
+  ...
+}
+```
-You can use `aria sample -h` to see a full list of options. If you wish to sample from a pretrained checkpoint, please use the `-pt` flag.
+## License and Attribution
+The Aria project has been kindly supported by EleutherAI, Stability AI, as well as by a compute grant from the Ministry of Science and ICT of Korea. Our models and MIDI tooling are released under the Apache-2.0 license.
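As a brief usage note for the embedding checkpoint described above, the sketch below (our own illustration, not an official API example) compares two performances by cosine similarity of their global embeddings. It assumes the `aria.embedding` module added in this PR and a `model` loaded as in the earlier README snippet; the two MIDI paths are placeholders.

```python
import torch.nn.functional as F

from aria.embedding import get_global_embedding_from_midi

# `model` is the TransformerEMB instance loaded as shown in the README snippet above.
emb_a = get_global_embedding_from_midi(
    model=model, midi_path="performance_a.mid", device="cpu"
)
emb_b = get_global_embedding_from_midi(
    model=model, midi_path="performance_b.mid", device="cpu"
)

# Global embeddings are 1-D tensors, so they can be compared directly.
similarity = F.cosine_similarity(emb_a, emb_b, dim=0)
print(f"Cosine similarity: {similarity.item():.3f}")
```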
If you use the models or tooling for follow-up work, please cite the paper in which they were introduced: +```bibtex +@inproceedings{bradshawscaling, + title={Scaling Self-Supervised Representation Learning for Symbolic Piano Performance}, + author={Bradshaw, Louis and Fan, Honglu and Spangher, Alex and Biderman, Stella and Colton, Simon}, + booktitle={arXiv preprint}, + year={2025}, + url={https://arxiv.org/abs/2504.15071} +} +``` \ No newline at end of file diff --git a/aria/datasets.py b/aria/datasets.py index b1d3b510..1f79c6cf 100644 --- a/aria/datasets.py +++ b/aria/datasets.py @@ -16,17 +16,14 @@ from mido.midifiles.units import second2tick from pathlib import Path from typing import List -from copy import deepcopy from typing import Callable, Iterable from collections import defaultdict from aria.config import load_config -from aria.tokenizer import InferenceAbsTokenizer from ariautils.tokenizer import Tokenizer from ariautils.midi import ( MidiDict, get_test_fn, - get_duration_ms, get_metadata_fn, ) @@ -330,7 +327,8 @@ def _preprocess_mididict(_mid_dict: MidiDict): def build_mididict_dataset( - dir: str, + dir: str | None = None, + mid_paths: list[str] = [], recur: bool = False, stream_save_path: str = None, overwrite: bool = False, @@ -398,13 +396,16 @@ def _get_mididicts_mp(_paths): "will slow down dataset building" ) - paths = [] - if recur is True: - paths += Path(dir).rglob(f"*.mid") - paths += Path(dir).rglob(f"*.midi") - else: - paths += Path(dir).glob(f"*.mid") - paths += Path(dir).glob(f"*.midi") + assert mid_paths or dir, "Must provider paths or a directory to glob files" + + paths = mid_paths if mid_paths else [] + if dir is not None: + if recur is True: + paths += Path(dir).rglob(f"*.mid") + paths += Path(dir).rglob(f"*.midi") + else: + paths += Path(dir).glob(f"*.mid") + paths += Path(dir).glob(f"*.midi") num_paths = len(paths) if num_paths == 0: @@ -473,7 +474,7 @@ def __init__(self, tokenizer: Tokenizer): def build(**kwargs): raise NotImplementedError - def get_loss_mask(self, src_seq: list, tgt_seq: list): + def get_loss_mask(self, src_seq: list, tgt_seq: list, offset: int = 0): # Should returns a bool Tensor with False indicating a masked loss raise NotImplementedError @@ -594,8 +595,10 @@ def _format(tok): mmap_obj = self.file_mmaps[file_idx] mmap_obj.seek(pos) - _debug = mmap_obj.readline() - seq = json.loads(_debug) # Load raw seq + entry_raw = mmap_obj.readline() + entry_dict = json.loads(entry_raw) + + seq = entry_dict["seq"] # Load raw seq seq = [_format(tok) for tok in seq] # Format into hashable if self._transform: seq = self._transform(seq) # Data augmentation @@ -603,11 +606,13 @@ def _format(tok): src = seq tgt = seq[1:] + [self.tokenizer.pad_tok] mask = self.get_loss_mask(src_seq=src, tgt_seq=tgt) + emb = entry_dict.get("emb", None) return ( torch.tensor(self.tokenizer.encode(src)), torch.tensor(self.tokenizer.encode(tgt)), mask, + torch.tensor(emb) if emb is not None else torch.empty(0), ) def check_config(self, epoch_load_path: str): @@ -708,6 +713,8 @@ def _get_seqs( else: raise Exception + _file_path = _midi_dict.metadata["abs_load_path"] + try: if _tokenize_fn is not None: _tokenized_seq = _tokenize_fn(_midi_dict) @@ -716,11 +723,13 @@ def _get_seqs( except Exception as e: print(e) logger.info(f"Skipping midi_dict: {e}") + return else: if _tokenizer.unk_tok in _tokenized_seq: logger.warning("Unknown token seen while tokenizing midi_dict") - return _tokenized_seq + + return _tokenized_seq, _file_path def get_seqs( @@ -728,7 +737,7 @@ def get_seqs( 
midi_dict_iter: Iterable, tokenize_fn: Callable | None = None, ): - # Can't pickle geneator object when start method is spawn + # Can't pickle generator object when start method is spawn if multiprocessing.get_start_method() == "spawn": logging.info( "Converting generator to list due to multiprocessing start method" @@ -777,6 +786,7 @@ def random_selection_itt(iterables: list[Iterable]): pass +# GOAL: Modify this and then rename it, removing ft-dataset from codebase class PretrainingDataset(TrainingDataset): """Torch dataset object yielding sequences formatted for pre-training""" @@ -808,12 +818,14 @@ def build( num_epochs: int, midi_dataset: MidiDataset = None, midi_dataset_path: str = None, + separate_sequences: bool = False, + file_embeddings: dict | None = None, ): """Builds and returns PretrainingDataset.""" - def _build_epoch(_save_path, _midi_dataset): + def _build_concat_epoch(_save_path: str, _midi_dataset: Iterable): + # Sequences are concatenated and sliced with jsonlines.open(_save_path, mode="w") as writer: - # Write tokenizer info into json on first line writer.write( { "tokenizer_config": tokenizer.config, @@ -821,22 +833,75 @@ def _build_epoch(_save_path, _midi_dataset): "max_seq_len": max_seq_len, } ) - - buffer = [] + seq_buffer = [] _idx = 0 - for entry in reservoir(get_seqs(tokenizer, _midi_dataset), 10): - if entry is not None: - buffer += entry - while len(buffer) >= max_seq_len: - writer.write(buffer[:max_seq_len]) - buffer = buffer[max_seq_len:] + for entry, file_path in reservoir( + get_seqs(tokenizer, _midi_dataset), 10 + ): + seq_buffer += entry + + while len(seq_buffer) >= max_seq_len: + writer.write({"seq": seq_buffer[:max_seq_len]}) + seq_buffer = seq_buffer[max_seq_len:] _idx += 1 if _idx % 250 == 0: logger.info(f"Finished processing {_idx}") - buffer += [tokenizer.pad_tok] * (max_seq_len - len(buffer)) - writer.write(buffer[:max_seq_len]) + if seq_buffer: + seq_buffer += [tokenizer.pad_tok] * ( + max_seq_len - len(seq_buffer) + ) + writer.write({"seq": seq_buffer[:max_seq_len]}) + + def _build_epoch_separated( + _save_path: str, + _midi_dataset: Iterable, + _file_embeddings: dict | None, + ): + # Sequences always start with a new entry (requires padding) + with jsonlines.open(_save_path, mode="w") as writer: + writer.write( + { + "tokenizer_config": tokenizer.config, + "tokenizer_name": tokenizer.name, + "max_seq_len": max_seq_len, + } + ) + _idx = 0 + for entry, file_path in reservoir( + get_seqs(tokenizer, _midi_dataset), 10 + ): + seq_buffer = entry + embedding_data = ( + {"emb": _file_embeddings[file_path]} + if _file_embeddings + else {} + ) + + while len(seq_buffer) >= max_seq_len: + writer.write( + { + "seq": seq_buffer[:max_seq_len], + **embedding_data, + } + ) + seq_buffer = seq_buffer[max_seq_len:] + + if seq_buffer: + seq_buffer += [tokenizer.pad_tok] * ( + max_seq_len - len(seq_buffer) + ) + writer.write( + { + "seq": seq_buffer[:max_seq_len], + **embedding_data, + } + ) + + _idx += 1 + if _idx % 250 == 0: + logger.info(f"Finished processing {_idx}") logger = setup_logger() assert max_seq_len > 0, "max_seq_len must be greater than 0" @@ -878,114 +943,24 @@ def _build_epoch(_save_path, _midi_dataset): if midi_dataset_path: midi_dataset = MidiDataset.get_generator(midi_dataset_path) - _build_epoch( - _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), - _midi_dataset=midi_dataset, - ) + if separate_sequences is True: + _build_epoch_separated( + _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), + _midi_dataset=midi_dataset, + 
_file_embeddings=file_embeddings, + ) + else: + _build_concat_epoch( + _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), + _midi_dataset=midi_dataset, + ) logger.info( f"Finished building, saved PretrainingDataset to {save_dir}" ) -# TODO: Refactor for readability -def _get_combined_mididict( - clean_midi_dict: MidiDict, - noisy_midi_dict: MidiDict, - min_noisy_ms: int, - max_noisy_ms: int, - min_clean_ms: int, - max_clean_ms: int, -) -> MidiDict: - # NOTE: We adopt the tempo/ticks_per_beat of the clean_midi_dict, and - # adjust the noisy note messages accordingly. - assert len(clean_midi_dict.tempo_msgs) == 1, "Unsupported tempo msgs" - assert len(noisy_midi_dict.tempo_msgs) == 1, "Unsupported tempo msgs" - - total_length_ms = get_duration_ms( - start_tick=0, - end_tick=clean_midi_dict.note_msgs[-1]["data"]["start"], - tempo_msgs=clean_midi_dict.tempo_msgs, - ticks_per_beat=clean_midi_dict.ticks_per_beat, - ) - - # Create intervals - noisy_intervals = [] - clean_intervals = [] - prev_ms = -1 - add_noisy_next = random.choice([True, False]) - while True: - if add_noisy_next is True: - # Add noisy interval - noisy_end_ms = random.randint( - prev_ms + min_noisy_ms, prev_ms + max_noisy_ms - ) - noisy_intervals.append([prev_ms + 1, noisy_end_ms]) - prev_ms = noisy_end_ms - if prev_ms > total_length_ms: - break - else: - add_noisy_next = False - else: - # Add clean interval - clean_end_ms = random.randint( - prev_ms + min_clean_ms, prev_ms + max_clean_ms - ) - clean_intervals.append([prev_ms + 1, clean_end_ms]) - prev_ms = clean_end_ms - if prev_ms > total_length_ms: - break - else: - add_noisy_next = True - - # Merge note_msgs - clean_ms_to_tick = (clean_midi_dict.ticks_per_beat * 1e3) / ( - clean_midi_dict.tempo_msgs[0]["data"] - ) - - comb_note_msgs = [] - for _note_msg in noisy_midi_dict.note_msgs: - onset_time_ms = noisy_midi_dict.tick_to_ms(_note_msg["data"]["start"]) - - for _interval_start_ms, _interval_end_ms in noisy_intervals: - if _interval_start_ms < onset_time_ms < _interval_end_ms: - offset_time_ms = noisy_midi_dict.tick_to_ms( - _note_msg["data"]["end"] - ) - _adj_note_msg = copy.deepcopy(_note_msg) - _adj_onset_tick = int(onset_time_ms * clean_ms_to_tick) - _adj_offset_tick = int(offset_time_ms * clean_ms_to_tick) - _adj_note_msg["tick"] = _adj_onset_tick - _adj_note_msg["data"]["start"] = _adj_onset_tick - _adj_note_msg["data"]["end"] = _adj_offset_tick - - comb_note_msgs.append(_adj_note_msg) - break - - for _note_msg in clean_midi_dict.note_msgs: - onset_time_ms = clean_midi_dict.tick_to_ms(_note_msg["data"]["start"]) - - for _interval_start_ms, _interval_end_ms in clean_intervals: - if _interval_start_ms < onset_time_ms < _interval_end_ms: - comb_note_msgs.append(_note_msg) - break - - comb_metadata = deepcopy(clean_midi_dict.metadata) - comb_metadata["noisy_intervals"] = noisy_intervals - - # Maybe using clean pedal msgs here is bad? 
- return MidiDict( - meta_msgs=clean_midi_dict.meta_msgs, - tempo_msgs=clean_midi_dict.tempo_msgs, - pedal_msgs=clean_midi_dict.pedal_msgs, - instrument_msgs=clean_midi_dict.instrument_msgs, - note_msgs=comb_note_msgs, - ticks_per_beat=clean_midi_dict.ticks_per_beat, - metadata=comb_metadata, - ) - - -# TODO: Refactor this function for readability +# Unused but potentially useful in the future def _noise_midi_dict(midi_dict: MidiDict, config: dict): def _get_velocity_adjusted_msg( __note_msg: dict, @@ -1146,168 +1121,3 @@ def _get_onset_adjusted_msg( ticks_per_beat=midi_dict.ticks_per_beat, metadata=midi_dict.metadata, ) - - -def export_inference_abs_build_tokenize_fn( - midi_dict: MidiDict, tokenizer: InferenceAbsTokenizer -): - finetuning_config = load_config()["data"]["finetuning"] - GUIDANCE_PROB = finetuning_config["guidance_prob"] - NOISING_PROB = finetuning_config["noising"]["activation_prob"] - MIN_NOISY_MS = finetuning_config["min_noisy_interval_ms"] - MAX_NOISY_MS = finetuning_config["max_noisy_interval_ms"] - MIN_CLEAN_MS = finetuning_config["min_clean_interval_ms"] - MAX_CLEAN_MS = finetuning_config["max_clean_interval_ms"] - - if random.random() <= NOISING_PROB: - noisy_midi_dict = _noise_midi_dict( - midi_dict, config=finetuning_config["noising"] - ) - midi_dict_for_tokenization = _get_combined_mididict( - clean_midi_dict=midi_dict, - noisy_midi_dict=noisy_midi_dict, - min_noisy_ms=MIN_NOISY_MS, - max_noisy_ms=MAX_NOISY_MS, - min_clean_ms=MIN_CLEAN_MS, - max_clean_ms=MAX_CLEAN_MS, - ) - else: - midi_dict_for_tokenization = midi_dict - - if random.random() <= GUIDANCE_PROB: - return tokenizer.tokenize( - midi_dict=midi_dict_for_tokenization, - prompt_intervals_ms=midi_dict_for_tokenization.metadata.get( - "noisy_intervals", [] - ), - guidance_midi_dict=midi_dict, - ) - else: - return tokenizer.tokenize( - midi_dict=midi_dict_for_tokenization, - prompt_intervals_ms=midi_dict_for_tokenization.metadata.get( - "noisy_intervals", [] - ), - ) - - -class FinetuningDataset(TrainingDataset): - """Torch dataset object yielding sequences formatted for fine-tuning.""" - - def __init__( - self, dir_paths: List[str] | str, tokenizer: InferenceAbsTokenizer - ): - super().__init__(tokenizer=tokenizer) - - assert tokenizer.name == "inference_abs", "invalid tokenizer" - - if isinstance(dir_paths, str): - dir_paths = [dir_paths] - - self.dir_paths = dir_paths - self.get_epoch_files_by_dir(dir_paths) - self.init_epoch(0) - - def __len__(self): - return len(self.index) - - def get_loss_mask(self, src_seq: list, tgt_seq: list): - mask = [False] * len(tgt_seq) - inside_target = True - - for idx, (src_tok, tgt_tok) in enumerate(zip(src_seq, tgt_seq)): - if src_tok == self.tokenizer.guidance_start_tok: - inside_target = False - elif src_tok == self.tokenizer.guidance_end_tok: - inside_target = True - elif tgt_tok == self.tokenizer.prompt_start_tok: - inside_target = False - elif src_tok == self.tokenizer.prompt_end_tok: - inside_target = True - - if inside_target is True and tgt_tok != self.tokenizer.pad_tok: - mask[idx] = True - - return torch.tensor(mask, dtype=torch.bool) - - @classmethod - def build( - cls, - tokenizer: InferenceAbsTokenizer, - save_dir: str, - max_seq_len: int, - num_epochs: int, - midi_dataset_path: str, - ): - - def _build_epoch(_save_path, _midi_dataset): - with jsonlines.open(_save_path, mode="w") as writer: - # Write tokenizer info into json on first line - writer.write( - { - "tokenizer_config": tokenizer.config, - "tokenizer_name": tokenizer.name, - "max_seq_len": 
max_seq_len, - } - ) - - _idx = 0 - for entry in reservoir( - get_seqs( - tokenizer, - _midi_dataset, - tokenize_fn=functools.partial( - export_inference_abs_build_tokenize_fn, - tokenizer=tokenizer, - ), - ), - 10, - ): - for _entry in tokenizer.split(entry, max_seq_len): - writer.write(_entry) - - _idx += 1 - if _idx % 250 == 0: - logger.info(f"Finished processing {_idx}") - - logger = setup_logger() - assert max_seq_len > 0, "max_seq_len must be greater than 0" - assert num_epochs > 0, "num_epochs must be greater than 0" - assert os.path.isfile(midi_dataset_path), "file not found" - if multiprocessing.get_start_method() == "spawn": - logger.warning( - 'The current multiprocessing start method is "spawn", this ' - "will slow down dataset building" - ) - - if os.path.isdir(save_dir) and os.listdir(save_dir): - print( - f"The directory at {save_dir} in non-empty, type [Y/y] to " - "remove and continue:" - ) - if input() not in {"Y", "y"}: - print("Aborting") - return - else: - shutil.rmtree(save_dir) - - if not os.path.exists(save_dir): - os.mkdir(save_dir) - - logger.info( - f"Building FinetuningDataset with config: " - f"max_seq_len={max_seq_len}, " - f"tokenizer_name={tokenizer.name}" - ) - - for idx in range(num_epochs): - logger.info(f"Building epoch {idx}/{num_epochs - 1}...") - - # Reload the combined dataset for each epoch - midi_dataset = MidiDataset.get_generator(midi_dataset_path) - _build_epoch( - _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), - _midi_dataset=midi_dataset, - ) - - logger.info(f"Finished building, saved FinetuningDataset to {save_dir}") diff --git a/aria/embedding.py b/aria/embedding.py new file mode 100644 index 00000000..030bcb19 --- /dev/null +++ b/aria/embedding.py @@ -0,0 +1,91 @@ +import torch +import copy + +from ariautils.midi import MidiDict +from ariautils.tokenizer import AbsTokenizer +from ariautils.tokenizer._base import Token + +from aria.model import TransformerEMB + +MAX_EMBEDDING_SEQ_LEN = 2048 + + +def _validate_midi_for_emb(midi_dict: MidiDict): + present_instruments = { + midi_dict.program_to_instrument[msg["data"]] + for msg in midi_dict.instrument_msgs + } + assert present_instruments == {"piano"}, "Only piano MIDIs supported" + assert len(midi_dict.note_msgs) > 0 + + +def _get_chunks(midi_dict: MidiDict, notes_per_chunk: int): + res = [] + + for note_msg_chunk in [ + midi_dict.note_msgs[idx : idx + notes_per_chunk] + for idx in range(0, len(midi_dict.note_msgs), notes_per_chunk) + ]: + if len(note_msg_chunk) == 0: + break + + chunked_midi_dict = copy.deepcopy(midi_dict) + chunked_midi_dict.note_msgs = note_msg_chunk + chunked_midi_dict.metadata = {} + res.append(chunked_midi_dict) + + return res + + +@torch.inference_mode() +def get_embedding_from_seq( + model: TransformerEMB, seq: list[Token], device="cuda" +): + tokenizer = AbsTokenizer() + + assert len(seq) <= MAX_EMBEDDING_SEQ_LEN, f"Sequence lengths above {MAX_EMBEDDING_SEQ_LEN} not supported" # fmt: skip + _validate_midi_for_emb(tokenizer.detokenize(seq)) + + model.eval() + eos_pos = seq.index(tokenizer.eos_tok) + seq_enc = torch.tensor(tokenizer.encode(seq), device=device) + emb = model.forward(seq_enc.view(1, -1))[0, eos_pos] + + return emb + + +# TODO: Make sure this is bug free +def get_global_embedding_from_midi( + model: TransformerEMB, + midi_dict: MidiDict | None = None, + midi_path: str | None = None, + notes_per_chunk: int = 300, + device="cuda", +): + """Calculates global contrastive embedding by calculating an unweighted + average of chunk embeddings of 
notes_per_chunk notes.""" + + assert midi_dict or midi_path + + if midi_path: + midi_dict = MidiDict.from_midi(mid_path=midi_path) + + tokenizer = AbsTokenizer() + _validate_midi_for_emb(midi_dict) + + chunks = _get_chunks(midi_dict=midi_dict, notes_per_chunk=notes_per_chunk) + seqs = [ + tokenizer.tokenize(c, add_dim_tok=False)[:MAX_EMBEDDING_SEQ_LEN] + for c in chunks + ] + + # Add back eos_tok if truncated by MAX_EMBEDDING_SEQ_LEN + for seq in seqs: + if seq[-1] != tokenizer.eos_tok: + seq[-1] = tokenizer.eos_tok + + embs = [ + get_embedding_from_seq(model=model, seq=s, device=device) for s in seqs + ] + + return torch.mean(torch.stack(embs), dim=0) diff --git a/tests/__init__.py b/aria/eval/__init__.py similarity index 100% rename from tests/__init__.py rename to aria/eval/__init__.py diff --git a/aria/eval/linear_probe.py b/aria/eval/linear_probe.py new file mode 100644 index 00000000..0766a74a --- /dev/null +++ b/aria/eval/linear_probe.py @@ -0,0 +1,721 @@ +import torch +import accelerate +import os +import mmap +import json +import time +import functools +import multiprocessing +import queue +import copy +import jsonlines +import torch.nn as nn +import torch.nn.functional as F + +from tqdm import tqdm +from typing import Callable +from concurrent.futures import ThreadPoolExecutor + +from ariautils.midi import MidiDict +from ariautils.tokenizer import AbsTokenizer + +CATEGORY_TAGS = { + "genre": { + "classical": 0, + "jazz": 1, + }, + "music_period": { + "baroque": 0, + "classical": 1, + "romantic": 2, + "impressionist": 3, + }, + "composer": { + "beethoven": 0, + "debussy": 1, + "brahms": 2, + "rachmaninoff": 3, + "schumann": 4, + "mozart": 5, + "liszt": 6, + "bach": 7, + "chopin": 8, + "schubert": 9, + }, + "form": { + "nocturne": 0, + "sonata": 1, + "improvisation": 2, + "etude": 3, + "fugue": 4, + "waltz": 5, + }, + "pianist": { + "hisaishi": 0, + "hancock": 1, + "bethel": 2, + "einaudi": 3, + "clayderman": 4, + "ryuichi": 5, + "yiruma": 6, + "hillsong": 7, + }, + "emotion": { + "happy": 0, + "sad": 1, + "calm": 2, + "tense": 3, + }, +} +LEARNING_RATE = 3e-4 + + +def model_forward( + model: nn.Module, + idxs: torch.Tensor, +): + return model(idxs) + + +def write_entries(writer, entries): + for entry in entries: + writer.write(entry) + + +def get_chunks(note_msgs: list, chunk_len: int): + return [ + note_msgs[i : i + chunk_len] + for i in range(0, len(note_msgs), chunk_len) + ] + + +def process_entry( + entry: MidiDict | dict, + slice_len_notes: int, + max_seq_len: int, + tokenizer: AbsTokenizer, +): + if isinstance(entry, dict): + midi_dict = MidiDict.from_msg_dict(entry) + else: + midi_dict = entry + + outputs = [] + for slice_note_msgs in get_chunks( + note_msgs=midi_dict.note_msgs, chunk_len=slice_len_notes + ): + if len(slice_note_msgs) == 0: + break + + slice_midi_dict = copy.deepcopy(midi_dict) + slice_midi_dict.note_msgs = slice_note_msgs + slice_midi_dict.metadata = {} + tokenized_slice = tokenizer.tokenize(slice_midi_dict) + if tokenizer.dim_tok in tokenized_slice: + tokenized_slice.remove(tokenizer.dim_tok) + + tokenized_slice = tokenized_slice[:max_seq_len] + + outputs.append({"seq": tokenized_slice, "metadata": midi_dict.metadata}) + + return outputs + + +def _pad_seq(seq: list, tokenizer: AbsTokenizer, max_seq_len: int): + seq = seq[:max_seq_len] + seq += [tokenizer.pad_tok] * (max_seq_len - len(seq)) + + if tokenizer.eos_tok not in seq: + seq[-1] = tokenizer.eos_tok + + return seq + + +@torch.autocast( + "cuda", + dtype=torch.bfloat16 if 
torch.cuda.is_bf16_supported() else torch.float16, +) +@torch.inference_mode() +def get_aria_contrastive_embedding( + seqs: list, + hook_model: nn.Module, + hook_max_seq_len: int, + hook_tokenizer: AbsTokenizer, + hook_model_forward: Callable, + hook_max_batch_size: int = 64, +): + all_emb = [] + + for i in range(0, len(seqs), hook_max_batch_size): + batch_seqs = seqs[i : i + hook_max_batch_size] + padded_seqs = [ + _pad_seq( + seq=seq, tokenizer=hook_tokenizer, max_seq_len=hook_max_seq_len + ) + for seq in batch_seqs + ] + eos_positions = [ + seq.index(hook_tokenizer.eos_tok) for seq in padded_seqs + ] + enc_seqs = torch.tensor( + [hook_tokenizer.encode(seq) for seq in padded_seqs], device="cuda" + ) + hidden_states = hook_model_forward(model=hook_model, idxs=enc_seqs) + idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + batch_emb = hidden_states[idx, eos_positions].tolist() + all_emb.extend(batch_emb) + + return all_emb + + +def get_mert_embedding( + seqs: list, + hook_model: nn.Module, + hook_processor, + hook_tokenizer: AbsTokenizer, + hook_pianoteq_exec_path: str, + hook_pianoteq_num_procs: int, +): + from aria.eval.mert.emb import ( + seq_to_audio_path, + compute_audio_embedding, + ) + + with multiprocessing.Pool(hook_pianoteq_num_procs) as pool: + audio_paths = pool.imap( + functools.partial( + seq_to_audio_path, + tokenizer=hook_tokenizer, + pianoteq_exec_path=hook_pianoteq_exec_path, + ), + seqs, + ) + + emb = [ + compute_audio_embedding( + audio_path=path, + model=hook_model, + processor=hook_processor, + delete_audio=True, + ).tolist() + for path in audio_paths + ] + + return emb + + +def get_clamp3_embedding( + seqs: list, + hook_model: nn.Module, + hook_patchilizer, + hook_tokenizer: AbsTokenizer, +): + from aria.eval.m3.emb import get_midi_embedding + + emb = [ + get_midi_embedding( + mid=hook_tokenizer.detokenize(seq).to_midi(), + model=hook_model, + patchilizer=hook_patchilizer, + get_global=True, + ).tolist() + for seq in seqs + ] + + return emb + + +@torch.autocast("cuda", dtype=torch.bfloat16) +@torch.inference_mode() +def get_baseline_embedding( + seqs: list, + hook_model: nn.Module, + hook_max_seq_len: int, + hook_tokenizer: AbsTokenizer, + pool_mode: str = "last", # "last" or "mean" +): + for seq in seqs: + if hook_tokenizer.eos_tok in seq: + seq.remove(hook_tokenizer.eos_tok) + + orig_lengths = [len(seq) for seq in seqs] + last_tok_positions = [length - 1 for length in orig_lengths] + seqs = [ + seq + ([hook_tokenizer.pad_tok] * (hook_max_seq_len - len(seq))) + for seq in seqs + ] + + enc_seqs = torch.tensor( + [hook_tokenizer.encode(seq) for seq in seqs], device="cuda" + ) + hidden_states = hook_model(enc_seqs) + + if pool_mode == "last": + idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + emb = hidden_states[idx, last_tok_positions].tolist() + elif pool_mode == "mean": + pad_id = hook_tokenizer.pad_id + # Create a mask by comparing enc_seqs to pad_id + mask = (enc_seqs != pad_id).unsqueeze(-1).to(hidden_states.dtype) + # Sum over valid tokens and average + sum_hidden = (hidden_states * mask).sum(dim=1) + valid_counts = mask.sum(dim=1) + mean_hidden = sum_hidden / valid_counts + emb = mean_hidden.tolist() + else: + raise ValueError(f"Unsupported pool_mode: {pool_mode}") + + return emb + + +class EvaluationDataset(torch.utils.data.Dataset): + def __init__(self, load_path: str, tag_to_id: dict, metadata_category: str): + self.load_path = load_path + self.tag_to_id = tag_to_id + self.metadata_category = metadata_category + 
self.tokenizer = AbsTokenizer() + self.index = [] + + self.file_buff = open(self.load_path, "rb") + self.mmap_obj = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + while True: + pos = self.mmap_obj.tell() + line = self.mmap_obj.readline() + if not line: + break + self.index.append(pos) + + def __getitem__(self, idx: int): + pos = self.index[idx] + self.mmap_obj.seek(pos) + + raw_data = self.mmap_obj.readline().decode("utf-8") + json_data = json.loads(raw_data) + + emb = json_data["emb"] + metadata = json_data["metadata"] + tag = metadata.get(self.metadata_category, "other") + tag = tag if tag in self.tag_to_id.keys() else "other" + + assert tag in self.tag_to_id, metadata + tag_tensor = torch.tensor(self.tag_to_id[tag]) + emb_tensor = torch.tensor(emb) + + return emb_tensor, tag_tensor + + def __len__(self): + return len(self.index) + + @classmethod + def export_worker_init_fn(cls): + def worker_init_fn(worker_id: int): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + + if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: + dataset.mmap_obj.close() + + f = open(dataset.load_path, "rb") + dataset.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + return worker_init_fn + + @classmethod + def build( + cls, + midi_dataset_load_path: str, + save_path: str, + slice_len_notes: int, + max_seq_len: int, + batch_size: int, + embedding_hook: Callable, + per_file_embeddings: bool = False, + **embedding_hook_kwargs, + ): + def batch_producer( + results_queue: queue.Queue, + batch_queue: queue.Queue, + batch_size: int, + total_workers: int, + per_file: bool = False, + ): + buffer = [] + termination_signals = 0 + + while termination_signals < total_workers: + if batch_queue.qsize() > 10: + time.sleep(0.25) + + try: + result = results_queue.get(timeout=0.01) + except queue.Empty: + continue + if result is None: + termination_signals += 1 + continue + + if per_file: + assert all( + "abs_load_path" in r["metadata"].keys() for r in result + ) + buffer.extend(result) + if len(buffer) > 2 * batch_size: + print( + f"WARNING: Generated batch of size {len(buffer)} (batch_size={batch_size})" + ) + if len(buffer) >= batch_size: + batch_queue.put(buffer) + buffer = [] + else: + buffer.extend(result) + while len(buffer) >= batch_size: + batch_queue.put(buffer[:batch_size]) + buffer = buffer[batch_size:] + + if buffer: + batch_queue.put(buffer) + + def producer( + midi_dataset_load_path: str, + midi_dict_queue: queue.Queue, + num_workers: int, + ): + cnt = 0 + with jsonlines.open(midi_dataset_load_path, "r") as midi_dataset: + for midi_dict in midi_dataset: + while midi_dict_queue.qsize() >= 1000: + time.sleep(0.1) + midi_dict_queue.put(midi_dict) + cnt += 1 + + if cnt % 500 == 0: + print(f"Finished {cnt}") + + for _ in range(num_workers): + midi_dict_queue.put(None) + + def worker( + midi_dict_queue: queue.Queue, + results_queue: queue.Queue, + slice_len_notes: int, + max_seq_len: int, + ): + tokenizer = AbsTokenizer() + + while True: + midi_dict = midi_dict_queue.get() + if midi_dict is None: + results_queue.put(None) + break + + while results_queue.qsize() > 250: + time.sleep(0.5) + + _result = process_entry( + entry=midi_dict, + slice_len_notes=slice_len_notes, + max_seq_len=max_seq_len, + tokenizer=tokenizer, + ) + results_queue.put(_result) + + assert os.path.isfile(midi_dataset_load_path) + assert os.path.isfile(save_path) is False + + TOTAL_WORKERS = 8 + write_executor = ThreadPoolExecutor(max_workers=1) + results_queue = 
multiprocessing.Queue() + midi_dict_queue = multiprocessing.Queue() + batch_queue = multiprocessing.Queue() + producer_process = multiprocessing.Process( + target=producer, + args=(midi_dataset_load_path, midi_dict_queue, TOTAL_WORKERS), + ) + batch_producer_process = multiprocessing.Process( + target=batch_producer, + args=( + results_queue, + batch_queue, + batch_size, + TOTAL_WORKERS, + per_file_embeddings, + ), + ) + worker_processes = [ + multiprocessing.Process( + target=worker, + args=( + midi_dict_queue, + results_queue, + slice_len_notes, + max_seq_len, + ), + ) + for _ in range(TOTAL_WORKERS) + ] + + producer_process.start() + batch_producer_process.start() + for p in worker_processes: + p.start() + + with jsonlines.open(save_path, "w") as writer: + while batch_producer_process.is_alive() or not batch_queue.empty(): + try: + batch = batch_queue.get(timeout=0.01) + + _seqs = [item["seq"] for item in batch] + _metadata = [item["metadata"] for item in batch] + _embs = embedding_hook(seqs=_seqs, **embedding_hook_kwargs) + + if not per_file_embeddings: + write_objs = [ + {"seq": s, "emb": e, "metadata": m} + for s, e, m in zip(_seqs, _embs, _metadata) + ] + else: + # Calculate per-file emb by averaging over abs_load_path embs + groups = {} + for seq, emb, meta in zip(_seqs, _embs, _metadata): + file_path = meta["abs_load_path"] + if file_path not in groups: + groups[file_path] = { + "seqs": [], + "embs": [], + "metadata": meta, + } + groups[file_path]["seqs"].append(seq) + groups[file_path]["embs"].append(emb) + + write_objs = [] + for file_path, data in groups.items(): + avg_emb = ( + torch.tensor(data["embs"]).mean(dim=0).tolist() + ) + write_objs.append( + { + "seqs": data["seqs"], + "emb": avg_emb, + "metadata": data["metadata"], + } + ) + + write_executor.submit(write_entries, writer, write_objs) + + except queue.Empty: + continue + + write_executor.shutdown(wait=True) + + +def _get_optim( + model: nn.Module, + total_steps: int, + warmup: int = 100, + end_ratio: int = 0.1, +): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LEARNING_RATE, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-5, + ) + + warmup_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, + end_factor=1, + total_iters=warmup, + ) + linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=total_steps - warmup, + ) + + lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_lrs, linear_decay_lrs], + milestones=[warmup], + ) + + return optimizer, lr_scheduler + + +class ClassifierHead(nn.Module): + def __init__(self, d_emb: int, num_class: int): + super().__init__() + self.linear = nn.Linear(d_emb, num_class) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +def _train( + accelerator: accelerate.Accelerator, + model: nn.Module, + train_dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + num_epochs: int = 1, +): + TRAILING_LOSS_STEPS = 100 + loss = torch.tensor([0.0]) + trailing_loss = 0 + lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) + loss_buffer = [] + + model.train() + loss_fn = nn.CrossEntropyLoss() + + for _epoch in range(num_epochs): + for __step, batch in ( + pbar := tqdm(enumerate(train_dataloader), leave=False) + ): + pbar.set_postfix_str( + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing={round(trailing_loss, 4)}" + ) + + emb, 
tag_ids = batch + tag_ids = tag_ids.view(-1) + + logits = model(emb) + loss = loss_fn(logits, tag_ids) + + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + return model + + +def train_classifier( + embedding_dimension: int, + train_dataset_path: str, + metadata_category: str, + tag_to_id: dict, + batch_size: int, + num_epochs: int = 1, +): + train_dataset = EvaluationDataset( + load_path=train_dataset_path, + tag_to_id=tag_to_id, + metadata_category=metadata_category, + ) + train_dataloader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=24, + worker_init_fn=EvaluationDataset.export_worker_init_fn(), + ) + + model = ClassifierHead( + d_emb=embedding_dimension, + num_class=len(tag_to_id.keys()), + ) + optimizer, scheduler = _get_optim( + model=model, + total_steps=num_epochs * len(train_dataloader), + ) + accelerator = accelerate.Accelerator(cpu=True) + + model, train_dataloader, optimizer, scheduler = accelerator.prepare( + model, + train_dataloader, + optimizer, + scheduler, + ) + + return _train( + accelerator=accelerator, + model=model, + train_dataloader=train_dataloader, + optimizer=optimizer, + scheduler=scheduler, + num_epochs=num_epochs, + ) + + +def evaluate_classifier( + model: nn.Module, + evaluation_dataset_path: str, + metadata_category: str, + tag_to_id: dict, +): + id_to_tag = {v: k for k, v in tag_to_id.items()} + val_dataset = EvaluationDataset( + load_path=evaluation_dataset_path, + tag_to_id=tag_to_id, + metadata_category=metadata_category, + ) + model = model.cpu().eval() + + dist = {k: {"correct": 0, "total": 0} for k in tag_to_id.keys()} + pred_dist = {k: 0 for k in tag_to_id.keys()} + + for midi_emb, tag_id in val_dataset: + with torch.no_grad(): + logits = model(torch.tensor(midi_emb.view(1, -1))) + probs = F.softmax(logits, dim=-1) + pred_tag_id = probs.argmax(dim=-1).item() + + true_tag = id_to_tag[tag_id.item()] + pred_tag = id_to_tag[pred_tag_id] + + dist[true_tag]["total"] += 1 + pred_dist[pred_tag] += 1 + + if pred_tag_id == tag_id.item(): + dist[true_tag]["correct"] += 1 + + total_correct = sum(v["correct"] for v in dist.values()) + total_samples = sum(v["total"] for v in dist.values()) + overall_accuracy = total_correct / total_samples + + class_metrics = {} + f1_scores = [] + for tag in tag_to_id.keys(): + TP = dist[tag]["correct"] + FN = dist[tag]["total"] - TP + FP = pred_dist[tag] - TP + precision = TP / (TP + FP) if (TP + FP) > 0 else 0 + recall = TP / (TP + FN) if (TP + FN) > 0 else 0 + f1 = ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + tag_accuracy = TP / dist[tag]["total"] if dist[tag]["total"] > 0 else 0 + class_metrics[tag] = { + "accuracy": tag_accuracy, + "precision": precision, + "recall": recall, + "F1": f1, + } + f1_scores.append(f1) + + macro_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0 + + results = { + "accuracy": overall_accuracy, + "F1-macro": macro_f1, + "class_wise": class_metrics, + } + + return results diff --git a/models/placeholder.txt b/aria/eval/m3/__init__.py similarity index 100% rename from models/placeholder.txt rename to aria/eval/m3/__init__.py diff --git a/aria/eval/m3/config.py b/aria/eval/m3/config.py new file 
mode 100644 index 00000000..ff53938c --- /dev/null +++ b/aria/eval/m3/config.py @@ -0,0 +1,106 @@ +EVAL_SPLIT = 0.01 # Fraction of training data used for evaluation +WANDB_KEY = "" # Weights and Biases API key + +# -------------------- Configuration for M3 Training -------------------- +M3_TRAIN_FOLDERS = [ + "" # Directory containing training data for M3 +] + +M3_EVAL_FOLDERS = [ + "" # Directory containing evaluation data for M3 (optional) +] + +PATCH_SIZE = 64 # Size of each patch +PATCH_LENGTH = 512 # Length of the patches +PATCH_NUM_LAYERS = 12 # Number of layers in the encoder +TOKEN_NUM_LAYERS = 3 # Number of layers in the decoder +M3_HIDDEN_SIZE = 768 # Size of the hidden layer + +M3_NUM_EPOCH = 100 # Maximum number of epochs for training +M3_LEARNING_RATE = 1e-4 # Learning rate for the optimizer +M3_BATCH_SIZE = 16 # Batch size per GPU (single card) during training +M3_MASK_RATIO = 0.45 # Ratio of masked elements during training +M3_DETERMINISTIC = True # Ensures deterministic results with random seeds +M3_WANDB_LOG = True # Enable logging to Weights and Biases +M3_LOAD_CKPT = True # Load model weights from a checkpoint if available + +M3_WEIGHTS_PATH = ( + "weights_m3" + + "_h_size_" + + str(M3_HIDDEN_SIZE) + + "_t_layers_" + + str(TOKEN_NUM_LAYERS) + + "_p_layers_" + + str(PATCH_NUM_LAYERS) + + "_p_size_" + + str(PATCH_SIZE) + + "_p_length_" + + str(PATCH_LENGTH) + + "_lr_" + + str(M3_LEARNING_RATE) + + "_batch_" + + str(M3_BATCH_SIZE) + + "_mask_" + + str(M3_MASK_RATIO) + + ".pth" +) # Path to store the model weights +M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace( + "pth", "txt" +) # Path to save training logs + +# -------------------- Configuration for CLaMP3 Training ---------------- +CLAMP3_TRAIN_JSONL = "" # Path to the JSONL file with training data for CLaMP3 +CLAMP3_EVAL_JSONL = "" # Path to the JSONL file with evaluation data for CLaMP3 (optional) + +CLAMP3_HIDDEN_SIZE = 768 # Size of the hidden layer +TEXT_MODEL_NAME = ( + "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model +) +MAX_TEXT_LENGTH = 128 # Maximum allowed length for text input + +AUDIO_HIDDEN_SIZE = 768 # Size of the hidden layer for audio features +AUDIO_NUM_LAYERS = 12 # Number of layers in the audio encoder +MAX_AUDIO_LENGTH = 128 # Maximum allowed length for audio input + +CLAMP3_NUM_EPOCH = 100 # Maximum number of epochs for training +CLAMP3_LEARNING_RATE = 1e-5 # Learning rate for the optimizer +CLAMP3_BATCH_SIZE = 256 # Batch size per GPU (single card) during training +LOGIT_SCALE = 1 # Scaling factor for contrastive loss + +FREEZE_TEXT = ( + False # Freeze the weights of the text model and text projection layer +) +TEXT_DROPOUT = True # Whether to apply dropout during text processing +CLAMP3_DETERMINISTIC = True # Ensures deterministic results with random seeds +CLAMP3_LOAD_M3 = True # Load weights from the M3 model +CLAMP3_WANDB_LOG = True # Enable logging to Weights and Biases +CLAMP3_LOAD_CKPT = True # Load weights from a checkpoint if available +SAVE_EVERY = 5 # Save model weights every SAVE_EVERY epochs + +CLAMP3_WEIGHTS_PATH = ( + "weights_clamp3_saas" + + "_h_size_" + + str(CLAMP3_HIDDEN_SIZE) + + "_t_model_" + + TEXT_MODEL_NAME.replace("/", "_") + + "_t_length_" + + str(MAX_TEXT_LENGTH) + + "_a_size_" + + str(AUDIO_HIDDEN_SIZE) + + "_a_layers_" + + str(AUDIO_NUM_LAYERS) + + "_a_length_" + + str(MAX_AUDIO_LENGTH) + + "_s_size_" + + str(M3_HIDDEN_SIZE) + + "_s_layers_" + + str(PATCH_NUM_LAYERS) + + "_p_size_" + + str(PATCH_SIZE) + + "_p_length_" + + 
str(PATCH_LENGTH) + + ".pth" +) # Path to store CLaMP3 model weights +CLAMP3_LOGS_PATH = CLAMP3_WEIGHTS_PATH.replace("weights", "logs").replace( + "pth", "txt" +) # Path to save training logs diff --git a/aria/eval/m3/emb.py b/aria/eval/m3/emb.py new file mode 100644 index 00000000..e3af9682 --- /dev/null +++ b/aria/eval/m3/emb.py @@ -0,0 +1,194 @@ +import os +import torch +import mido +from transformers import BertConfig, GPT2Config + +from aria.eval.m3.config import ( + AUDIO_HIDDEN_SIZE, + AUDIO_NUM_LAYERS, + MAX_AUDIO_LENGTH, + M3_HIDDEN_SIZE, + PATCH_NUM_LAYERS, + PATCH_LENGTH, + PATCH_SIZE, + CLAMP3_HIDDEN_SIZE, + TEXT_MODEL_NAME, + TOKEN_NUM_LAYERS, +) + +from aria.eval.m3.utils import CLaMP3Model, M3Patchilizer, M3Model + + +def msg_to_str(msg): + str_msg = "" + for key, value in msg.dict().items(): + str_msg += " " + str(value) + return str_msg.strip().encode("unicode_escape").decode("utf-8") + + +def load_midi( + filename: str | None = None, + mid: mido.MidiFile | None = None, + m3_compatible: bool = True, +): + """ + Load a MIDI file and convert it to MTF format. + """ + + if mid is None: + assert os.path.isfile(filename) + mid = mido.MidiFile(filename) + + msg_list = ["ticks_per_beat " + str(mid.ticks_per_beat)] + + # Merge tracks manually using mido.merge_tracks() + merged = mido.merge_tracks(mid.tracks) + + for msg in merged: + if m3_compatible and msg.is_meta: + if msg.type in [ + "text", + "copyright", + "track_name", + "instrument_name", + "lyrics", + "marker", + "cue_marker", + "device_name", + ]: + continue + str_msg = msg_to_str(msg) + msg_list.append(str_msg) + + return "\n".join(msg_list) + + +def load_clamp3_model(checkpoint_path: str, m3_only: bool = False): + # Create audio and symbolic configurations. + audio_config = BertConfig( + vocab_size=1, + hidden_size=AUDIO_HIDDEN_SIZE, + num_hidden_layers=AUDIO_NUM_LAYERS, + num_attention_heads=AUDIO_HIDDEN_SIZE // 64, + intermediate_size=AUDIO_HIDDEN_SIZE * 4, + max_position_embeddings=MAX_AUDIO_LENGTH, + ) + symbolic_config = BertConfig( + vocab_size=1, + hidden_size=M3_HIDDEN_SIZE, + num_hidden_layers=PATCH_NUM_LAYERS, + num_attention_heads=M3_HIDDEN_SIZE // 64, + intermediate_size=M3_HIDDEN_SIZE * 4, + max_position_embeddings=PATCH_LENGTH, + ) + decoder_config = GPT2Config( + vocab_size=128, + n_positions=PATCH_SIZE, + n_embd=M3_HIDDEN_SIZE, + n_layer=TOKEN_NUM_LAYERS, + n_head=M3_HIDDEN_SIZE // 64, + n_inner=M3_HIDDEN_SIZE * 4, + ) + + model = CLaMP3Model( + audio_config=audio_config, + symbolic_config=symbolic_config, + text_model_name=TEXT_MODEL_NAME, + hidden_size=CLAMP3_HIDDEN_SIZE, + load_m3=True, + ) + model = model.to("cuda") + model.eval() + + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + checkpoint = torch.load( + checkpoint_path, map_location="cuda", weights_only=True + ) + + if m3_only is False: + model.load_state_dict(checkpoint["model"]) + else: + temp_m3_model = M3Model(symbolic_config, decoder_config) + temp_m3_model.load_state_dict(checkpoint["model"]) + model.symbolic_model.load_state_dict(temp_m3_model.encoder.state_dict()) + + patchilizer = M3Patchilizer() + + return model, patchilizer + + +def get_midi_embedding( + mid: mido.MidiFile, + model: CLaMP3Model, + patchilizer: M3Patchilizer, + get_global=True, +): + device = "cuda" + mtf_str = load_midi(mid=mid, m3_compatible=True) + patches = patchilizer.encode(mtf_str, add_special_patches=True) + + token_tensor = torch.tensor(patches, dtype=torch.long).to(device) + + num_tokens = 
token_tensor.size(0) + segments = [] + seg_weights = [] + for i in range(0, num_tokens, PATCH_LENGTH): + seg = token_tensor[i : i + PATCH_LENGTH] + cur_len = seg.size(0) + segments.append(seg) + seg_weights.append(cur_len) + + if num_tokens > PATCH_LENGTH: + segments[-1] = token_tensor[-PATCH_LENGTH:] + seg_weights[-1] = segments[-1].size(0) + + processed_feats = [] + for seg in segments: + cur_len = seg.size(0) + # Pad the segment if it's shorter than PATCH_LENGTH. + if cur_len < PATCH_LENGTH: + pad = torch.full( + ( + PATCH_LENGTH - cur_len, + token_tensor.size(1), + ), # include PATCH_SIZE dimension + patchilizer.pad_token_id, + dtype=torch.long, + device=device, + ) + seg = torch.cat([seg, pad], dim=0) + seg = seg.unsqueeze(0) # Add batch dimension. + + mask = torch.cat( + [ + torch.ones(cur_len, device=device), + torch.zeros(PATCH_LENGTH - cur_len, device=device), + ], + dim=0, + ).unsqueeze(0) + with torch.no_grad(): + feat = model.get_symbolic_features( + symbolic_inputs=seg, symbolic_masks=mask, get_global=get_global + ) + + if not get_global: + feat = feat[:, : int(mask.sum().item()), :] + processed_feats.append(feat) + + if not get_global: + embedding = torch.cat( + [feat.squeeze(0) for feat in processed_feats], dim=0 + ) + else: + # For a global embedding, compute a weighted average of segment features. + feats = torch.stack( + [feat.squeeze(0) for feat in processed_feats], dim=0 + ) + weights = torch.tensor( + seg_weights, dtype=torch.float, device=device + ).view(-1, 1) + embedding = (feats * weights).sum(dim=0) / weights.sum() + + return embedding.view(-1) diff --git a/aria/eval/m3/utils.py b/aria/eval/m3/utils.py new file mode 100644 index 00000000..0b849644 --- /dev/null +++ b/aria/eval/m3/utils.py @@ -0,0 +1,702 @@ +import re +import os +import math +import torch +import random +from aria.eval.m3.config import * +from unidecode import unidecode +from torch.nn import functional as F +from transformers import ( + AutoModel, + BertModel, + GPT2LMHeadModel, + PreTrainedModel, + GPT2Config, +) + +try: + import torch.distributed.nn + from torch import distributed as dist + + has_distributed = True +except ImportError: + has_distributed = False + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +class ClipLoss(torch.nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def gather_features( + self, + image_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, + ): + assert ( + has_distributed + ), "torch.distributed did not import correctly, please use a PyTorch version with support." 
+ if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list( + all_image_features.chunk(world_size, dim=0) + ) + gathered_text_features = list( + all_text_features.chunk(world_size, dim=0) + ) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat( + gathered_image_features, dim=0 + ) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat( + torch.distributed.nn.all_gather(image_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + else: + gathered_image_features = [ + torch.zeros_like(image_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + def get_ground_truth(self, device, num_logits) -> torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = self.gather_features( + image_features, + text_features, + self.local_loss, + self.gather_with_grad, + self.rank, + self.world_size, + self.use_horovod, + ) + + if self.local_loss: + logits_per_image = ( + logit_scale * image_features @ all_text_features.T + ) + logits_per_text = ( + logit_scale * text_features @ all_image_features.T + ) + else: + logits_per_image = ( + logit_scale * all_image_features @ all_text_features.T + ) + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward( + self, image_features, text_features, logit_scale, output_dict=False + ): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits( + image_features, text_features, logit_scale + ) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class 
M3Patchilizer: + def __init__(self): + self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"] + self.regexPattern = ( + "(" + "|".join(map(re.escape, self.delimiters)) + ")" + ) + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + self.mask_token_id = 3 + + def split_bars(self, body): + bars = re.split(self.regexPattern, "".join(body)) + bars = list(filter(None, bars)) # remove empty strings + if bars[0] in self.delimiters: + bars[1] = bars[0] + bars[1] + bars = bars[1:] + bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)] + return bars + + def bar2patch(self, bar, patch_size=PATCH_SIZE): + patch = ( + [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id] + ) + patch = patch[:patch_size] + patch += [self.pad_token_id] * (patch_size - len(patch)) + return patch + + def patch2bar(self, patch): + return "".join( + chr(idx) if idx > self.mask_token_id else "" for idx in patch + ) + + def encode( + self, + item, + patch_size=PATCH_SIZE, + add_special_patches=False, + truncate=False, + random_truncate=False, + ): + item = item.replace("L:1/8\n", "") + item = unidecode(item) + lines = re.findall(r".*?\n|.*$", item) + lines = list(filter(None, lines)) # remove empty lines + + patches = [] + + if lines[0].split(" ")[0] == "ticks_per_beat": + patch = "" + for line in lines: + if patch.startswith(line.split(" ")[0]) and ( + len(patch) + len(" ".join(line.split(" ")[1:])) + <= patch_size - 2 + ): + patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:]) + else: + if patch: + patches.append(patch) + patch = line + if patch != "": + patches.append(patch) + else: + for line in lines: + if len(line) > 1 and ( + (line[0].isalpha() and line[1] == ":") + or line.startswith("%%") + ): + patches.append(line) + else: + bars = self.split_bars(line) + if bars: + bars[-1] += "\n" + patches.extend(bars) + + if add_special_patches: + bos_patch = chr(self.bos_token_id) * patch_size + eos_patch = chr(self.eos_token_id) * patch_size + patches = [bos_patch] + patches + [eos_patch] + + if len(patches) > PATCH_LENGTH and truncate: + choices = ["head", "tail", "middle"] + choice = random.choice(choices) + if choice == "head" or random_truncate == False: + patches = patches[:PATCH_LENGTH] + elif choice == "tail": + patches = patches[-PATCH_LENGTH:] + else: + start = random.randint(1, len(patches) - PATCH_LENGTH) + patches = patches[start : start + PATCH_LENGTH] + + patches = [self.bar2patch(patch) for patch in patches] + + return patches + + def decode(self, patches): + return "".join(self.patch2bar(patch) for patch in patches) + + +class M3PatchEncoder(PreTrainedModel): + def __init__(self, config): + super(M3PatchEncoder, self).__init__(config) + self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, M3_HIDDEN_SIZE) + torch.nn.init.normal_(self.patch_embedding.weight, std=0.02) + self.base = BertModel(config=config) + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + self.mask_token_id = 3 + + def forward( + self, + input_patches, # [batch_size, seq_length, hidden_size] + input_masks, + ): # [batch_size, seq_length] + # Transform input_patches into embeddings + input_patches = torch.nn.functional.one_hot( + input_patches, num_classes=128 + ) + input_patches = input_patches.reshape( + len(input_patches), -1, PATCH_SIZE * 128 + ).type(torch.FloatTensor) + input_patches = self.patch_embedding(input_patches.to(self.device)) + + # Apply BERT model to input_patches and input_masks + return self.base( + inputs_embeds=input_patches, 
attention_mask=input_masks + ) + + +class M3TokenDecoder(PreTrainedModel): + def __init__(self, config): + super(M3TokenDecoder, self).__init__(config) + self.base = GPT2LMHeadModel(config=config) + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + self.mask_token_id = 3 + + def forward( + self, patch_features, target_patches # [batch_size, hidden_size] + ): # [batch_size, seq_length] + # get input embeddings + inputs_embeds = torch.nn.functional.embedding( + target_patches, self.base.transformer.wte.weight + ) + + # concatenate the encoded patches with the input embeddings + inputs_embeds = torch.cat( + (patch_features.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1 + ) + + # preparing the labels for model training + target_masks = target_patches == self.pad_token_id + target_patches = target_patches.clone().masked_fill_(target_masks, -100) + + # get the attention mask + target_masks = ~target_masks + target_masks = target_masks.type(torch.int) + + return self.base( + inputs_embeds=inputs_embeds, + attention_mask=target_masks, + labels=target_patches, + ) + + def generate(self, patch_feature, tokens): + # reshape the patch_feature and tokens + patch_feature = patch_feature.reshape(1, 1, -1) + tokens = tokens.reshape(1, -1) + + # get input embeddings + tokens = torch.nn.functional.embedding( + tokens, self.base.transformer.wte.weight + ) + + # concatenate the encoded patches with the input embeddings + tokens = torch.cat((patch_feature, tokens[:, 1:, :]), dim=1) + + # get the outputs from the model + outputs = self.base(inputs_embeds=tokens) + + # get the probabilities of the next token + probs = torch.nn.functional.softmax( + outputs.logits.squeeze(0)[-1], dim=-1 + ) + + return probs.detach().cpu().numpy() + + +class M3Model(PreTrainedModel): + def __init__(self, encoder_config, decoder_config): + super(M3Model, self).__init__(encoder_config) + self.encoder = M3PatchEncoder(encoder_config) + self.decoder = M3TokenDecoder(decoder_config) + self.pad_token_id = 0 + self.bos_token_id = 1 + self.eos_token_id = 2 + self.mask_token_id = 3 + + def forward( + self, + input_patches, # [batch_size, seq_length, hidden_size] + input_masks, # [batch_size, seq_length] + selected_indices, # [batch_size, seq_length] + target_patches, + ): # [batch_size, seq_length, hidden_size] + input_patches = input_patches.reshape( + len(input_patches), -1, PATCH_SIZE + ).to(self.device) + input_masks = input_masks.to(self.device) + selected_indices = selected_indices.to(self.device) + target_patches = target_patches.reshape( + len(target_patches), -1, PATCH_SIZE + ).to(self.device) + + # Pass the input_patches and input_masks through the encoder + outputs = self.encoder(input_patches, input_masks)["last_hidden_state"] + + # Use selected_indices to form target_patches + target_patches = target_patches[selected_indices.bool()] + patch_features = outputs[selected_indices.bool()] + + # Pass patch_features and target_patches through the decoder + return self.decoder(patch_features, target_patches) + + +class CLaMP3Model(PreTrainedModel): + def __init__( + self, + audio_config, + symbolic_config, + global_rank=None, + world_size=None, + text_model_name=TEXT_MODEL_NAME, + hidden_size=CLAMP3_HIDDEN_SIZE, + load_m3=CLAMP3_LOAD_M3, + ): + super(CLaMP3Model, self).__init__(symbolic_config) + + self.text_model = AutoModel.from_pretrained( + text_model_name + ) # Load the text model + self.text_proj = torch.nn.Linear( + self.text_model.config.hidden_size, hidden_size + ) # Linear layer for text projections + 
torch.nn.init.normal_( + self.text_proj.weight, std=0.02 + ) # Initialize weights with normal distribution + + self.symbolic_model = M3PatchEncoder( + symbolic_config + ) # Initialize the symbolic model + self.symbolic_proj = torch.nn.Linear( + M3_HIDDEN_SIZE, hidden_size + ) # Linear layer for symbolic projections + torch.nn.init.normal_( + self.symbolic_proj.weight, std=0.02 + ) # Initialize weights with normal distribution + + self.audio_model = BertModel(audio_config) # Initialize the audio model + self.audio_proj = torch.nn.Linear( + audio_config.hidden_size, hidden_size + ) # Linear layer for audio projections + torch.nn.init.normal_( + self.audio_proj.weight, std=0.02 + ) # Initialize weights with normal distribution + + if global_rank == None or world_size == None: + global_rank = 0 + world_size = 1 + + self.loss_fn = ClipLoss( + local_loss=False, + gather_with_grad=True, + cache_labels=False, + rank=global_rank, + world_size=world_size, + use_horovod=False, + ) + + if load_m3 and os.path.exists(M3_WEIGHTS_PATH): + checkpoint = torch.load( + M3_WEIGHTS_PATH, map_location="cpu", weights_only=True + ) + decoder_config = GPT2Config( + vocab_size=128, + n_positions=PATCH_SIZE, + n_embd=M3_HIDDEN_SIZE, + n_layer=TOKEN_NUM_LAYERS, + n_head=M3_HIDDEN_SIZE // 64, + n_inner=M3_HIDDEN_SIZE * 4, + ) + model = M3Model(symbolic_config, decoder_config) + model.load_state_dict(checkpoint["model"]) + self.symbolic_model = model.encoder + model = None + print( + f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}" + ) + + def set_trainable(self, freeze_list): + if "text_model" in freeze_list: + self.text_model.eval() + for param in self.text_model.parameters(): + param.requires_grad = False + print("Text Model Frozen") + else: + self.text_model.train() + for param in self.text_model.parameters(): + param.requires_grad = True + print("Text Model Training") + + if "text_proj" in freeze_list: + self.text_proj.eval() + for param in self.text_proj.parameters(): + param.requires_grad = False + print("Text Projection Layer Frozen") + else: + self.text_proj.train() + for param in self.text_proj.parameters(): + param.requires_grad = True + print("Text Projection Layer Training") + + if "symbolic_model" in freeze_list: + self.symbolic_model.eval() + for param in self.symbolic_model.parameters(): + param.requires_grad = False + print("Symbolic Model Frozen") + else: + self.symbolic_model.train() + for param in self.symbolic_model.parameters(): + param.requires_grad = True + print("Symbolic Model Training") + + if "symbolic_proj" in freeze_list: + self.symbolic_proj.eval() + for param in self.symbolic_proj.parameters(): + param.requires_grad = False + print("Symbolic Projection Layer Frozen") + else: + self.symbolic_proj.train() + for param in self.symbolic_proj.parameters(): + param.requires_grad = True + print("Symbolic Projection Layer Training") + + if "audio_model" in freeze_list: + self.audio_model.eval() + for param in self.audio_model.parameters(): + param.requires_grad = False + print("Audio Model Frozen") + else: + self.audio_model.train() + for param in self.audio_model.parameters(): + param.requires_grad = True + print("Audio Model Training") + + if "audio_proj" in freeze_list: + self.audio_proj.eval() + for param in self.audio_proj.parameters(): + param.requires_grad = False + print("Audio Projection Layer Frozen") + else: + self.audio_proj.train() + for param in self.audio_proj.parameters(): + param.requires_grad = True + print("Audio 
Projection Layer Training") + + def avg_pooling(self, input_features, input_masks): + input_masks = input_masks.unsqueeze(-1).to( + self.device + ) # add a dimension to match the feature dimension + input_features = ( + input_features * input_masks + ) # apply mask to input_features + avg_pool = input_features.sum(dim=1) / input_masks.sum( + dim=1 + ) # calculate average pooling + + return avg_pool + + def get_text_features(self, text_inputs, text_masks, get_global=False): + text_features = self.text_model( + text_inputs.to(self.device), + attention_mask=text_masks.to(self.device), + )["last_hidden_state"] + + if get_global: + text_features = self.avg_pooling(text_features, text_masks) + text_features = self.text_proj(text_features) + + return text_features + + def get_symbolic_features( + self, symbolic_inputs, symbolic_masks, get_global=False + ): + symbolic_features = self.symbolic_model( + symbolic_inputs.to(self.device), symbolic_masks.to(self.device) + )["last_hidden_state"] + + if get_global: + symbolic_features = self.avg_pooling( + symbolic_features, symbolic_masks + ) + # symbolic_features = self.symbolic_proj(symbolic_features) + + return symbolic_features + + def get_audio_features(self, audio_inputs, audio_masks, get_global=False): + audio_features = self.audio_model( + inputs_embeds=audio_inputs.to(self.device), + attention_mask=audio_masks.to(self.device), + )["last_hidden_state"] + + if get_global: + audio_features = self.avg_pooling(audio_features, audio_masks) + audio_features = self.audio_proj(audio_features) + + return audio_features + + def forward( + self, + text_inputs, # [batch_size, seq_length] + text_masks, # [batch_size, seq_length] + music_inputs, # [batch_size, seq_length, hidden_size] + music_masks, # [batch_size, seq_length] + music_modality, + ): # "symbolic" or "audio" + # Compute the text features + text_features = self.get_text_features( + text_inputs, text_masks, get_global=True + ) + + # Compute the music features + if music_modality == "symbolic": + music_features = self.get_symbolic_features( + music_inputs, music_masks, get_global=True + ) + elif music_modality == "audio": + music_features = self.get_audio_features( + music_inputs, music_masks, get_global=True + ) + else: + raise ValueError( + "music_modality must be either 'symbolic' or 'audio'" + ) + + return self.loss_fn( + text_features, music_features, LOGIT_SCALE, output_dict=False + ) + + +def split_data(data, eval_ratio=EVAL_SPLIT): + random.shuffle(data) + split_idx = int(len(data) * eval_ratio) + eval_set = data[:split_idx] + train_set = data[split_idx:] + return train_set, eval_set + + +def mask_patches(target_patches, patchilizer, mode): + indices = list(range(len(target_patches))) + random.shuffle(indices) + selected_indices = indices[: math.ceil(M3_MASK_RATIO * len(indices))] + sorted_indices = sorted(selected_indices) + input_patches = torch.tensor(target_patches) + + if mode == "eval": + choice = "original" + else: + choice = random.choices( + ["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1] + )[0] + + if choice == "mask": + input_patches[sorted_indices] = torch.tensor( + [patchilizer.mask_token_id] * PATCH_SIZE + ) + elif choice == "shuffle": + for idx in sorted_indices: + patch = input_patches[idx] + try: + index_eos = (patch == patchilizer.eos_token_id).nonzero().item() + except: + index_eos = len(patch) + + indices = list(range(1, index_eos)) + random.shuffle(indices) + indices = [0] + indices + list(range(index_eos, len(patch))) + input_patches[idx] = patch[indices] + + 
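+ # Build a 0/1 mask over patch positions marking which patches were
+ # selected for the masked-prediction objective; M3Model uses this mask to
+ # gather the matching encoder features and target patches.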
selected_indices = torch.zeros(len(target_patches)) + selected_indices[sorted_indices] = 1.0 + + return input_patches, selected_indices + + +def remove_instrument_info(item): + # remove instrument information from symbolic music + lines = re.findall(r".*?\n|.*$", item) + lines = list(filter(None, lines)) + if lines[0].split(" ")[0] == "ticks_per_beat": + type = "mtf" + else: + type = "abc" + + cleaned_lines = [] + for line in lines: + if type == "abc" and line.startswith("V:"): + # find the position of " nm=" or " snm=" + nm_pos = line.find(" nm=") + snm_pos = line.find(" snm=") + # keep the part before " nm=" or " snm=" + if nm_pos != -1: + line = line[:nm_pos] + elif snm_pos != -1: + line = line[:snm_pos] + if nm_pos != -1 or snm_pos != -1: + line += "\n" + elif type == "mtf" and line.startswith("program_change"): + line = " ".join(line.split(" ")[:-1]) + " 0\n" + + cleaned_lines.append(line) + + return "".join(cleaned_lines) diff --git a/aria/eval/mert/__init__.py b/aria/eval/mert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aria/eval/mert/emb.py b/aria/eval/mert/emb.py new file mode 100644 index 00000000..bcee3333 --- /dev/null +++ b/aria/eval/mert/emb.py @@ -0,0 +1,147 @@ +import torch +import tempfile +import shlex +import os +import torchaudio + +import torchaudio.transforms as T +import torch.nn.functional as F +import torch.nn as nn + +from ariautils.midi import MidiDict +from ariautils.tokenizer import AbsTokenizer + +from transformers import Wav2Vec2FeatureExtractor, AutoModel + + +def seq_to_audio_path( + seq: list, tokenizer: AbsTokenizer, pianoteq_exec_path: str +): + mid_temp = tempfile.NamedTemporaryFile(suffix=".mid", delete=False) + mid_path = mid_temp.name + mid_temp.close() + + mid = tokenizer.detokenize(seq) + mid.to_midi().save(mid_path) + + audio_temp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + audio_path = audio_temp.name + audio_temp.close() # Close so CLI can write to it + + preset = "NY Steinway D Classical Recording" + + pianoteq_cmd = f"{shlex.quote(pianoteq_exec_path)} --preset {shlex.quote(preset)} --rate 24000 --midi {mid_path} --wav {audio_path}" + os.system(pianoteq_cmd) + + os.remove(mid_path) + + return audio_path + + +def compute_audio_embedding( + audio_path: str, model: nn.Module, processor, delete_audio: bool = False +) -> torch.Tensor: + """ + Loads the MERT-v1-330M model and processor, reads an mp3 file, + segments the audio into 5-second chunks, computes a segment embedding by averaging + over the time dimension (for each layer) and across layers, and then aggregates + the segment embeddings using average pooling to produce a final embedding. + + Parameters: + file_path (str): Path to the mp3 audio file. + + Returns: + torch.Tensor: The final audio embedding. 
+ """ + # Load the mp3 file and convert to mono if necessary (waveform shape: [channels, time]) + waveform, sr = torchaudio.load(audio_path) + + if waveform.size(0) > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + # Resample if needed (target_sr for MERT-v1-330M is typically 24000 Hz) + target_sr = processor.sampling_rate + if sr != target_sr: + resampler = T.Resample(orig_freq=sr, new_freq=target_sr) + waveform = resampler(waveform) + + # Remove channel dimension to get [n_samples] + waveform = waveform.squeeze(0) + + # Define the segment length for 5 seconds + segment_length = target_sr * 5 + total_samples = waveform.size(0) + segments = [] + + # Split the waveform into segments; pad the final segment if needed + for start in range(0, total_samples, segment_length): + segment = waveform[start : start + segment_length] + if segment.size(0) < segment_length: + padding = segment_length - segment.size(0) + segment = F.pad(segment, (0, padding)) + segments.append(segment.numpy()) + + # Process all segments in one batch. The processor accepts a list of numpy arrays. + inputs = processor(segments, sampling_rate=target_sr, return_tensors="pt") + inputs = {k: v.cuda() for k, v in inputs.items()} + + # Forward pass through the model in batch mode + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True) + + # outputs.hidden_states is a tuple of tensors (one per layer) of shape: + # [batch_size, time_steps, feature_dim] for each layer. + # Stack them to get shape: [num_layers, batch_size, time_steps, feature_dim] + hidden_states = torch.stack(outputs.hidden_states) + + # Average over the time dimension for each segment in each layer: + # result shape: [num_layers, batch_size, feature_dim] + layer_time_avg = hidden_states.mean(dim=2) + + # Average over layers to obtain one embedding per segment: + # result shape: [batch_size, feature_dim] + segment_embeddings = layer_time_avg.mean(dim=0) + + # Finally, average the segment embeddings to get a final representation: + # shape: [feature_dim] + final_embedding = segment_embeddings.mean(dim=0) + + if delete_audio is True: + os.remove(audio_path) + + return final_embedding + + +def load_mert_model(): + + return AutoModel.from_pretrained( + "m-a-p/MERT-v1-330M", trust_remote_code=True + ).cuda(), Wav2Vec2FeatureExtractor.from_pretrained( + "m-a-p/MERT-v1-330M", trust_remote_code=True + ) + + +def main(): + model = AutoModel.from_pretrained( + "m-a-p/MERT-v1-330M", trust_remote_code=True + ).cuda() + processor = Wav2Vec2FeatureExtractor.from_pretrained( + "m-a-p/MERT-v1-330M", trust_remote_code=True + ) + + tokenizer = AbsTokenizer() + mid_dict = MidiDict.from_midi("/home/loubb/Dropbox/shared/test.mid") + seq = tokenizer.tokenize(mid_dict) + + audio_path = seq_to_audio_path(seq, tokenizer) + emb = compute_audio_embedding( + audio_path=audio_path, + model=model, + processor=processor, + delete_audio=True, + ) + print(emb.shape) + + +if __name__ == "__main__": + main() diff --git a/aria/inference/__init__.py b/aria/inference/__init__.py index f87bb30b..ceac4b4f 100644 --- a/aria/inference/__init__.py +++ b/aria/inference/__init__.py @@ -1 +1,58 @@ -from .model import TransformerLM +import torch + +from ariautils.tokenizer import AbsTokenizer +from ariautils.midi import MidiDict + + +def sample_min_p(probs: torch.Tensor, p_base: float) -> torch.Tensor: + """See - https://arxiv.org/pdf/2407.01082""" + p_max, _ = torch.max(probs, dim=-1, keepdim=True) + p_scaled = p_base * p_max + mask = probs >= p_scaled + + masked_probs = 
probs.clone() + masked_probs[~mask] = 0.0 + masked_probs.div_(masked_probs.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(masked_probs, num_samples=1) + + return next_token + + +def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + probs_sort[mask] = 0.0 + + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + + return next_token + + +def get_cfg_prompt(prompts: list): + cfg_prompts = [] + for prompt in prompts: + cfg_prompts.append(prompt) + cfg_prompts.append(prompt) + + return cfg_prompts + + +def get_inference_prompt( + midi_dict: MidiDict, tokenizer: AbsTokenizer, prompt_len_ms: int +): + midi_dict.note_msgs = [ + msg + for msg in midi_dict.note_msgs + if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms + ] + + if len(midi_dict.note_msgs) == 0: + return [("prefix", "instrument", "piano"), tokenizer.bos_tok] + + seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) + seq.remove(tokenizer.eos_tok) + + return seq diff --git a/aria/inference/model.py b/aria/inference/model_cuda.py similarity index 87% rename from aria/inference/model.py rename to aria/inference/model_cuda.py index d4bf4392..8dbfbd4f 100644 --- a/aria/inference/model.py +++ b/aria/inference/model_cuda.py @@ -1,4 +1,4 @@ -"""Inference implementation with torch-compiler friendly kv-cache.""" +"""Inference implementation for torch (cuda) backend""" import torch import torch.nn as nn @@ -34,99 +34,6 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out -class TransformerLM(nn.Module): - def __init__(self, model_config: ModelConfig): - super().__init__() - self.model_config = model_config - self.max_seq_len = model_config.max_seq_len - self.model = Transformer(model_config) - self.lm_head = nn.Linear( - model_config.d_model, model_config.vocab_size, bias=False - ) - - def forward( - self, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, - ): - hidden_states = self.model( - idxs=idxs, - input_pos=input_pos, - pad_idxs=pad_idxs, - ) - logits = self.lm_head(hidden_states) - - return logits - - def setup_cache( - self, - batch_size, - max_seq_len=4096, - dtype=torch.bfloat16, - ): - # Init cache - for b in self.model.encode_layers: - b.kv_cache = KVCache( - max_batch_size=batch_size, - max_seq_length=max_seq_len, - n_heads=self.model_config.n_heads, - head_dim=self.model_config.d_model // self.model_config.n_heads, - dtype=dtype, - ).cuda() - - self.model.freqs_cis = precompute_freqs_cis( - seq_len=max_seq_len, - n_elem=self.model_config.d_model // self.model_config.n_heads, - base=500000, - dtype=dtype, - ).cuda() - self.model.causal_mask = torch.tril( - torch.ones(max_seq_len, max_seq_len, dtype=torch.bool) - ).cuda() - - -class Transformer(nn.Module): - def __init__(self, model_config: ModelConfig) -> None: - super().__init__() - self.model_config = model_config - - self.tok_embeddings = nn.Embedding( - num_embeddings=model_config.vocab_size, - embedding_dim=model_config.d_model, - ) - self.encode_layers = nn.ModuleList( - TransformerBlock(model_config) for _ in range(model_config.n_layers) - ) - self.out_layer_norm = nn.LayerNorm(model_config.d_model) - - self.freqs_cis = None - self.casual_mask = None - - def forward( - self, - idxs: torch.Tensor, - 
input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, - ): - assert self.freqs_cis is not None, "Caches must be initialized first" - - mask = self.causal_mask[None, None, input_pos] - - if pad_idxs is not None: - mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1)) - - freqs_cis = self.freqs_cis[input_pos] - - x = self.tok_embeddings(idxs) - for layer in self.encode_layers: - x = layer(x, input_pos, freqs_cis, mask) - - x = self.out_layer_norm(x) - - return x - - class TransformerBlock(nn.Module): def __init__(self, model_config: ModelConfig) -> None: super().__init__() @@ -169,7 +76,6 @@ def __init__(self, model_config: ModelConfig) -> None: self.norm1 = nn.LayerNorm(model_config.d_model) self.norm2 = nn.LayerNorm(model_config.d_model) - # TODO: Fill in args self.kv_cache = None def forward( @@ -238,6 +144,123 @@ def _ff_block(self, x: torch.Tensor): ) +class Transformer(nn.Module): + def __init__(self, model_config: ModelConfig) -> None: + super().__init__() + self.model_config = model_config + + self.tok_embeddings = nn.Embedding( + num_embeddings=model_config.vocab_size, + embedding_dim=model_config.d_model, + ) + self.encode_layers = nn.ModuleList( + TransformerBlock(model_config) for _ in range(model_config.n_layers) + ) + self.out_layer_norm = nn.LayerNorm(model_config.d_model) + + self.freqs_cis = None + self.causal_mask = None + + def fill_condition_kv(self, emb: torch.Tensor): + assert self.freqs_cis is not None, "Caches must be initialized first" + assert self.model_config.emb_size is not None + + input_pos = torch.tensor([0], device=emb.device) + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + + x = emb.unsqueeze(dim=1) + + for layer in self.encode_layers: + x = layer(x, input_pos, freqs_cis, mask) + + def forward( + self, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, + ): + assert self.freqs_cis is not None, "Caches must be initialized first" + + mask = self.causal_mask[None, None, input_pos] + + if pad_idxs is not None: + mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1)) + + freqs_cis = self.freqs_cis[input_pos] + + x = self.tok_embeddings(idxs) + for layer in self.encode_layers: + x = layer(x, input_pos, freqs_cis, mask) + + x = self.out_layer_norm(x) + + return x + + +class TransformerLM(nn.Module): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) + self.lm_head = nn.Linear( + model_config.d_model, model_config.vocab_size, bias=False + ) + + if model_config.emb_size is not None: + self.embedding_adapter = nn.Linear( + model_config.emb_size, model_config.d_model, bias=False + ) + + def forward( + self, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, + ): + hidden_states = self.model( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + ) + logits = self.lm_head(hidden_states) + + return logits + + def fill_condition_kv(self, cond_emb: torch.Tensor): + assert self.model_config.emb_size is not None + + adapted_emb = self.embedding_adapter(cond_emb) + self.model.fill_condition_kv(emb=adapted_emb) + + def setup_cache( + self, + batch_size: int, + max_seq_len=8096, + dtype=torch.bfloat16, + ): + assert batch_size >= 1 + for b in self.model.encode_layers: + b.kv_cache = KVCache( + max_batch_size=batch_size, + max_seq_length=max_seq_len, + n_heads=self.model_config.n_heads, + 
head_dim=self.model_config.d_model // self.model_config.n_heads, + dtype=dtype, + ).cuda() + + self.model.freqs_cis = precompute_freqs_cis( + seq_len=max_seq_len, + n_elem=self.model_config.d_model // self.model_config.n_heads, + base=500000, + dtype=dtype, + ).cuda() + self.model.causal_mask = torch.tril( + torch.ones(max_seq_len, max_seq_len, dtype=torch.bool) + ).cuda() + + def precompute_freqs_cis( seq_len: int, n_elem: int, @@ -255,7 +278,8 @@ def precompute_freqs_cis( return cache.to(dtype=dtype) -@torch.jit.script +# TODO: Fix +# @torch.jit.script def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: """ In-place RoPE. Credits to Katherine Crowson: diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py new file mode 100644 index 00000000..169b30bc --- /dev/null +++ b/aria/inference/model_mlx.py @@ -0,0 +1,284 @@ +"""Inference implementation for mlx backend""" + +import mlx.core as mx +import mlx.nn as nn + +from aria.model import ModelConfig + + +class KVCache(nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + dtype: mx.Dtype = mx.float32, + ): + super().__init__() + self.dtype = dtype + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.k_cache = mx.zeros(cache_shape, dtype=dtype) + self.v_cache = mx.zeros(cache_shape, dtype=dtype) + + def update(self, input_pos: mx.array, k_val: mx.array, v_val: mx.array): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class TransformerBlock(nn.Module): + def __init__( + self, + model_config: ModelConfig, + ): + super().__init__() + self.d_model = model_config.d_model + self.n_heads = model_config.n_heads + self.d_head = self.d_model // self.n_heads + self.max_seq_len = model_config.max_seq_len + self.scale = self.d_head**-0.5 + + # Att + self.mixed_qkv = nn.Linear( + input_dims=model_config.d_model, + output_dims=3 * model_config.d_model, + bias=False, + ) + self.att_proj_linear = nn.Linear( + input_dims=model_config.d_model, + output_dims=model_config.d_model, + bias=False, + ) + + # FF + self.ff_gate_proj = nn.Linear( + input_dims=model_config.d_model, + output_dims=model_config.d_model * model_config.ff_mult, + bias=False, + ) + self.ff_up_proj = nn.Linear( + input_dims=model_config.d_model, + output_dims=model_config.d_model * model_config.ff_mult, + bias=False, + ) + self.ff_down_proj = nn.Linear( + input_dims=model_config.d_model * model_config.ff_mult, + output_dims=model_config.d_model, + bias=False, + ) + + # Pre layer norms + self.norm1 = nn.LayerNorm(model_config.d_model) + self.norm2 = nn.LayerNorm(model_config.d_model) + + self.kv_cache = None + + def __call__( + self, + x: mx.array, + input_pos: mx.array, + offset: int, + mask: mx.array, + ): + assert self.kv_cache is not None, "Cache not initialized" + + x += self._att_block( + x=self.norm1(x), + input_pos=input_pos, + offset=offset, + mask=mask, + ) + x = x + self._ff_block(self.norm2(x)) + + return x + + def get_kv(self, k: mx.array, v: mx.array, input_pos: mx.array): + k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos) + + return k, v + + def _att_block( + self, + x: mx.array, + input_pos: mx.array, + offset: int, + mask: mx.array, + ): + + qkv_splits = self.mixed_qkv(x).split(3, axis=2) + q, k, v = qkv_splits[0], qkv_splits[1], qkv_splits[2] + + 
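+ # Split the fused QKV projection into per-head tensors, apply RoPE with
+ # the running offset, then read/write the preallocated KV cache before
+ # scaled-dot-product attention.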
batch_size, seq_len, _ = q.shape + q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head) + k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head) + v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head) + + q = apply_rotary_emb_mlx(q, offset=offset) + k = apply_rotary_emb_mlx(k, offset=offset) + q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v)) + + k, v = self.get_kv(k, v, input_pos=input_pos) + wv = mx.fast.scaled_dot_product_attention( + q=q, + k=k, + v=v, + scale=self.scale, + mask=mask, + ) + + # (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d) + wv = wv.transpose(0, 2, 1, 3).reshape( + batch_size, seq_len, self.n_heads * self.d_head + ) + + return self.att_proj_linear(wv) + + def _ff_block(self, x: mx.array): + return self.ff_down_proj( + nn.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x) + ) + + +class Transformer(nn.Module): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + self.tok_embeddings = nn.Embedding( + num_embeddings=model_config.vocab_size, + dims=model_config.d_model, + ) + self.encode_layers = [ + TransformerBlock(model_config) for _ in range(model_config.n_layers) + ] + self.out_layer_norm = nn.LayerNorm(model_config.d_model) + + def fill_condition_kv(self, emb: mx.array): + assert self.causal_mask is not None, "Caches must be initialized first" + assert self.model_config.emb_size is not None + + input_pos = mx.array([0], dtype=mx.int32) + mask = self.causal_mask[None, None, input_pos] + offset = 0 + + x = mx.expand_dims(emb, axis=1) + + for layer in self.encode_layers: + x = layer(x, input_pos, offset, mask) + + def __call__( + self, + idxs: mx.array, + input_pos: mx.array, + offset: int, + pad_idxs: mx.array | None = None, + ): + assert self.causal_mask is not None, "Caches must be initialized first" + + mask = self.causal_mask[None, None, input_pos] + + if pad_idxs is not None: + pad_mask = mx.expand_dims(mx.expand_dims(pad_idxs, axis=1), axis=1) + mask = mask & ~pad_mask + + x = self.tok_embeddings(idxs) + for layer in self.encode_layers: + x = layer(x, input_pos, offset, mask) + + x = self.out_layer_norm(x) + + return x + + +class TransformerLM(nn.Module): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) # Implement + self.lm_head = nn.Linear( + model_config.d_model, model_config.vocab_size, bias=False + ) + + if model_config.emb_size is not None: + self.embedding_adapter = nn.Linear( + model_config.emb_size, model_config.d_model, bias=False + ) + + def __call__( + self, + idxs: mx.array, + input_pos: mx.array, + offset: int, + pad_idxs: mx.array | None = None, + ): + hidden_states = self.model( + idxs=idxs, + input_pos=input_pos, + offset=offset, + pad_idxs=pad_idxs, + ) + logits = self.lm_head(hidden_states) + + return logits + + def fill_condition_kv(self, cond_emb: mx.array): + assert self.model_config.emb_size is not None + + adapted_emb = self.embedding_adapter(cond_emb) + self.model.fill_condition_kv(emb=adapted_emb) + + def setup_cache( + self, + batch_size, + max_seq_len=8096, + dtype=mx.float32, + ): + # Init cache + for b in self.model.encode_layers: + b.kv_cache = KVCache( + max_batch_size=batch_size, + max_seq_length=max_seq_len, + n_heads=self.model_config.n_heads, + head_dim=self.model_config.d_model // self.model_config.n_heads, + dtype=dtype, + ) + + self.model.causal_mask = mx.tril( + mx.ones((max_seq_len, 
max_seq_len), dtype=mx.bool_) + ) + + +def apply_rotary_emb_mlx( + x: mx.array, + offset: int = 0, +) -> mx.array: + # Original x shape: (b_sz, s_len, n_head, d_head) + original_shape = x.shape + b_sz, s_len, n_head, d_head = original_shape + + # Transpose to (b_sz, n_head, s_len, d_head) + x_permuted = x.transpose(0, 2, 1, 3) + # Reshape for mx.fast.rope: (b_sz * n_head, s_len, d_head) + x_reshaped = x_permuted.reshape(-1, s_len, d_head) + + rotated_x_reshaped = mx.fast.rope( + x_reshaped, + dims=d_head, + traditional=False, + base=500000, + scale=1.0, + offset=offset, + ) + + rotated_x_permuted = rotated_x_reshaped.reshape(b_sz, n_head, s_len, d_head) + rotated_x = rotated_x_permuted.transpose(0, 2, 1, 3) + + return rotated_x diff --git a/aria/inference/sample_cuda.py b/aria/inference/sample_cuda.py new file mode 100644 index 00000000..909bd8d7 --- /dev/null +++ b/aria/inference/sample_cuda.py @@ -0,0 +1,361 @@ +"""Contains generation/sampling code""" + +import torch +import torch._inductor.config + +from tqdm import tqdm + +from aria.inference import sample_min_p, sample_top_p +from aria.inference.model_cuda import TransformerLM +from ariautils.tokenizer import Tokenizer, AbsTokenizer + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True + +DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + + +def get_cfg_prompt(prompts: list): + cfg_prompts = [] + for prompt in prompts: + cfg_prompts.append(prompt) + cfg_prompts.append(prompt) + + return cfg_prompts + + +@torch.inference_mode() +def decode_one( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +) -> torch.Tensor: + assert input_pos.shape[-1] == 1 + + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + )[:, -1] + + return logits + + +@torch.inference_mode() +def prefill( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +) -> torch.Tensor: + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + ) + + return logits + + +def update_seq_ids_( + seq: torch.Tensor, + idx: int, + next_token_ids: torch.Tensor, + dim_tok_inserted: list, + eos_tok_seen: list, + max_len: int, + force_end: bool, + tokenizer: Tokenizer, +): + # Insert dim and pad toks + for _idx in range(seq.shape[0]): + if eos_tok_seen[_idx] == True: + next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.pad_tok] + elif ( + force_end + and idx >= max_len - 130 + and dim_tok_inserted[_idx] is False + and tokenizer.id_to_tok[next_token_ids[_idx].item()][0] + not in ("dur", "onset") + ): + next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok] + + # Update dim_tok_inserted and eos_tok_seen + if next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: + dim_tok_inserted[_idx] = True + elif next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.eos_tok]: + eos_tok_seen[_idx] = True + + seq[:, idx] = next_token_ids + + +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def sample_batch( + model: TransformerLM, + tokenizer: Tokenizer, + prompt: list, + num_variations: list, + max_new_tokens: int, + temp: float, + force_end: bool = False, + top_p: float | None = None, + min_p: float | None = None, + compile: bool = False, +): + assert top_p is not None or min_p is not None + assert 0.0 <= temp <= 2.0 + if top_p is not None: + assert 0.5 <= top_p 
<= 1.0 + if min_p is not None: + assert 0.0 <= min_p <= 1.0 + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompt_len = len(prompt) + + model = model.cuda() + model.eval() + dim_tok_inserted = [False for _ in range(num_variations)] + eos_tok_seen = [False for _ in range(num_variations)] + total_len = prompt_len + max_new_tokens + seq = torch.stack( + [ + torch.tensor( + tokenizer.encode( + prompt + [tokenizer.pad_tok] * (total_len - prompt_len) + ) + ) + for _ in range(num_variations) + ] + ).cuda() + + if compile is True: + global decode_one + decode_one = torch.compile( + decode_one, + mode="reduce-overhead", + fullgraph=True, + ) + + model.setup_cache( + batch_size=num_variations, + max_seq_len=total_len, + dtype=DTYPE, + ) + + print( + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gen_len={max_new_tokens}" + ) + + for idx in ( + pbar := tqdm( + range(prompt_len, total_len), + total=total_len - prompt_len, + leave=False, + ) + ): + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=torch.arange(0, idx, device=seq.device), + )[:, -1] + else: + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=torch.tensor( + [(idx) - 1], + device=seq.device, + dtype=torch.int, + ), + ) + + if temp > 0.0: + probs = torch.softmax(logits / temp, dim=-1) + if min_p is not None: + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = sample_top_p(probs, top_p).flatten() + else: + next_token_ids = torch.argmax(logits, dim=-1).flatten() + + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()] + decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results + + +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def sample_batch_cfg( + model: TransformerLM, + tokenizer: AbsTokenizer, + prompt: list, + num_variations: list, + max_new_tokens: int, + cfg_gamma: float, + embedding: list[float], + temp: float, + force_end=False, + top_p: float | None = None, + min_p: float | None = None, + compile: bool = False, +): + assert top_p is not None or min_p is not None + assert 0.0 <= temp <= 2.0 + assert 0.0 <= cfg_gamma <= 10.0 + if top_p is not None: + assert 0.5 <= top_p <= 1.0 + if min_p is not None: + assert 0.0 <= min_p <= 1.0 + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompt_len = len(prompt) + num_variations = 2 * num_variations # For CFG + + model = model.cuda() + model.eval() + dim_tok_inserted = [False for _ in range(num_variations)] + eos_tok_seen = [False for _ in range(num_variations)] + total_len = prompt_len + max_new_tokens + seq = torch.stack( + [ + torch.tensor( + tokenizer.encode( + prompt + [tokenizer.pad_tok] * (total_len - prompt_len) + ) + ) + for _ in range(num_variations) + ] + ).cuda() + + if compile is True: + global decode_one + decode_one = torch.compile( + decode_one, + mode="reduce-overhead", + fullgraph=True, + ) + + model.setup_cache( + batch_size=num_variations, + max_seq_len=total_len, + dtype=DTYPE, + ) + + 
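+ # Classifier-free guidance setup: every variation occupies two adjacent
+ # batch rows. The conditioning embedding is written into KV position 0
+ # for all rows, and pad_idxs masks that position for the odd rows so they
+ # decode unconditionally; the two streams are later mixed as
+ # logits_cfg = gamma * logits_cond + (1 - gamma) * logits_uncond.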
condition_embedding = torch.tensor( + [embedding for _ in range(num_variations)], device=seq.device + ) + model.fill_condition_kv(cond_emb=condition_embedding) + embedding_offset = 1 + pad_idxs = torch.zeros_like(seq, dtype=torch.bool) + pad_idxs[1::2, 0] = True + + print( + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, cfg={cfg_gamma}, gen_len={max_new_tokens}" + ) + + CFG_WARM_UP_STEPS = min(250, max_new_tokens) + curr_step = 0 + for idx in ( + pbar := tqdm( + range(prompt_len, total_len), + total=total_len - prompt_len, + leave=False, + ) + ): + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=torch.arange( + embedding_offset, + idx + embedding_offset, + device=seq.device, + ), + pad_idxs=pad_idxs, + )[:, -1] + else: + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH + ): + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=torch.tensor( + [(idx + embedding_offset) - 1], + device=seq.device, + dtype=torch.int, + ), + pad_idxs=pad_idxs, + ) + + curr_step += 1 + _cfg_gamma = min(cfg_gamma, (curr_step / CFG_WARM_UP_STEPS) * cfg_gamma) + + logits_cfg = _cfg_gamma * logits[::2] + (1 - _cfg_gamma) * logits[1::2] + logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + + if temp > 0.0: + probs = torch.softmax(logits_cfg / temp, dim=-1) + if min_p is not None: + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = sample_top_p(probs, top_p).flatten() + else: + next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten() + + next_token_ids = next_token_ids.repeat_interleave(2) + + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2] + decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results diff --git a/aria/inference/sample_mlx.py b/aria/inference/sample_mlx.py new file mode 100644 index 00000000..c6a6e258 --- /dev/null +++ b/aria/inference/sample_mlx.py @@ -0,0 +1,337 @@ +"""Contains generation/sampling code (mlx)""" + +import torch +import numpy as np +import mlx.core as mx + +from tqdm import tqdm + +from aria.inference import sample_min_p, sample_top_p +from aria.inference.model_mlx import TransformerLM +from ariautils.tokenizer import AbsTokenizer + +DTYPE = mx.float32 + + +def decode_one( + model: TransformerLM, + idxs: mx.array, + input_pos: mx.array, + pad_idxs: mx.array | None = None, +): + assert input_pos.shape[-1] == 1 + + logits = model( + idxs=idxs, + input_pos=input_pos, + offset=input_pos[0], + pad_idxs=pad_idxs, + )[:, -1] + + return logits + + +def prefill( + model: TransformerLM, + idxs: mx.array, + input_pos: mx.array, + pad_idxs: mx.array | None = None, +): + logits = model( + idxs=idxs, + input_pos=input_pos, + offset=input_pos[0], + pad_idxs=pad_idxs, + ) + + return logits + + +def update_seq_ids_( + seq: mx.array, + idx: int, + next_token_ids: mx.array, + dim_tok_inserted: list, + eos_tok_seen: list, + max_len: int, + force_end: bool, + tokenizer: AbsTokenizer, +): + # Insert dim and pad toks + for _idx in range(seq.shape[0]): + if eos_tok_seen[_idx] == True: + next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.pad_tok] + elif ( 
+ force_end + and idx >= max_len - 130 + and dim_tok_inserted[_idx] is False + and tokenizer.id_to_tok[next_token_ids[_idx].item()][0] + not in ("dur", "onset") + ): + next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok] + + # Update dim_tok_inserted and eos_tok_seen + if next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: + dim_tok_inserted[_idx] = True + elif next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.eos_tok]: + eos_tok_seen[_idx] = True + + seq[:, idx] = next_token_ids + + +def sample_batch( + model: TransformerLM, + tokenizer: AbsTokenizer, + prompt: list, + num_variations: list, + max_new_tokens: int, + temp: float = 0.95, + force_end: bool = False, + top_p: float | None = None, + min_p: float | None = None, +): + assert top_p is not None or min_p is not None + assert 0.0 <= temp <= 2.0 + if top_p is not None: + assert 0.5 <= top_p <= 1.0 + if min_p is not None: + assert 0.0 <= min_p <= 1.0 + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompt_len = len(prompt) + + model.eval() + dim_tok_inserted = [False for _ in range(num_variations)] + eos_tok_seen = [False for _ in range(num_variations)] + total_len = prompt_len + max_new_tokens + + seq = mx.stack( + [ + mx.array( + tokenizer.encode( + prompt + [tokenizer.pad_tok] * (total_len - prompt_len) + ), + dtype=mx.int32, + ) + for _ in range(num_variations) + ], + ) + model.setup_cache( + batch_size=num_variations, + max_seq_len=total_len, + dtype=DTYPE, + ) + print( + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, gen_len={max_new_tokens}" + ) + + for idx in ( + pbar := tqdm( + range(prompt_len, total_len), + total=total_len - prompt_len, + leave=False, + ) + ): + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=mx.arange(0, idx, dtype=mx.int32), + )[:, -1] + else: + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=mx.array( + [idx - 1], + dtype=mx.int32, + ), + ) + + if temp > 0.0: + probs = mx.softmax(logits / temp, axis=-1) + if min_p is not None: + next_token_ids = sample_min_p_mlx(probs, min_p).flatten() + else: + next_token_ids = sample_top_p_mlx(probs, top_p).flatten() + else: + next_token_ids = mx.argmax(logits, axis=-1).flatten() + + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()] + decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results + + +def sample_batch_cfg( + model: TransformerLM, + tokenizer: AbsTokenizer, + prompt: list, + num_variations: list, + max_new_tokens: int, + cfg_gamma: float, + embedding: list[float], + temp: float, + force_end=False, + top_p: float | None = None, + min_p: float | None = None, +): + assert top_p is not None or min_p is not None + assert 0.0 <= temp <= 2.0 + assert 0.0 <= cfg_gamma <= 10.0 + if top_p is not None: + assert 0.5 <= top_p <= 1.0 + if min_p is not None: + assert 0.0 <= min_p <= 1.0 + if force_end: + assert max_new_tokens > 130, "prompt too long to use force_end=True" + + prompt_len = len(prompt) + num_variations = 2 * num_variations # For CFG + + model.eval() + dim_tok_inserted = [False for _ in range(num_variations)] + 
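+ # eos_tok_seen mirrors dim_tok_inserted: rows that have emitted EOS are
+ # padded for the remaining decoding steps inside update_seq_ids_.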
eos_tok_seen = [False for _ in range(num_variations)] + total_len = prompt_len + max_new_tokens + seq = mx.stack( + [ + mx.array( + tokenizer.encode( + prompt + [tokenizer.pad_tok] * (total_len - prompt_len) + ), + dtype=mx.int32, + ) + for _ in range(num_variations) + ] + ) + + model.setup_cache( + batch_size=num_variations, + max_seq_len=total_len, + dtype=DTYPE, + ) + + condition_embedding = mx.array( + [embedding for _ in range(num_variations)], + dtype=DTYPE, + ) + model.fill_condition_kv(cond_emb=condition_embedding) + embedding_offset = 1 + pad_idxs = mx.zeros(seq.shape, dtype=mx.bool_) + pad_idxs[1::2, 0] = True + + print( + f"Using hyperparams: temp={temp}, top_p={top_p}, min_p={min_p}, cfg={cfg_gamma}, gen_len={max_new_tokens}" + ) + + CFG_WARM_UP_STEPS = min(250, max_new_tokens) + curr_step = 0 + for idx in ( + pbar := tqdm( + range(prompt_len, total_len), + total=total_len - prompt_len, + leave=False, + ) + ): + if idx == prompt_len: + logits = prefill( + model, + idxs=seq[:, :idx], + input_pos=mx.arange( + embedding_offset, idx + embedding_offset, dtype=mx.int32 + ), + pad_idxs=pad_idxs, + )[:, -1] + else: + logits = decode_one( + model, + idxs=seq[:, idx - 1 : idx], + input_pos=mx.array( + [(idx + embedding_offset) - 1], + dtype=mx.int32, + ), + pad_idxs=pad_idxs, + ) + + curr_step += 1 + _cfg_gamma = min(cfg_gamma, (curr_step / CFG_WARM_UP_STEPS) * cfg_gamma) + + logits_cfg = _cfg_gamma * logits[::2] + (1 - _cfg_gamma) * logits[1::2] + logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + + if temp > 0.0: + probs = mx.softmax(logits_cfg / temp, axis=-1) + + if min_p is not None: + next_token_ids = sample_min_p_mlx(probs, min_p).flatten() + else: + next_token_ids = sample_top_p_mlx(probs, top_p).flatten() + else: + next_token_ids = mx.argmax(logits_cfg, axis=-1).flatten() + + next_token_ids = mx.repeat(next_token_ids, repeats=2) + + update_seq_ids_( + seq=seq, + idx=idx, + next_token_ids=next_token_ids, + dim_tok_inserted=dim_tok_inserted, + eos_tok_seen=eos_tok_seen, + max_len=total_len, + force_end=force_end, + tokenizer=tokenizer, + ) + + if all(seen_eos is True for seen_eos in eos_tok_seen): + break + + decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2] + decoded_results = [ + ( + res[: res.index(tokenizer.eos_tok) + 1] + if tokenizer.eos_tok in res + else res + ) + for res in decoded_results + ] + + return decoded_results + + +def sample_min_p_mlx(probs: mx.array, p_base: float) -> mx.array: + """See - https://arxiv.org/pdf/2407.01082""" + + probs_t = torch.from_numpy(np.array(probs)) + next_token_t = sample_min_p(probs=probs_t, p_base=p_base) + + return mx.array(next_token_t, dtype=mx.int32) + + +def sample_top_p_mlx(probs: mx.array, top_p: float) -> mx.array: + + probs_t = torch.from_numpy(np.array(probs)) + next_token_t = sample_top_p(probs=probs_t, top_p=top_p) + + return mx.array(next_token_t, dtype=mx.int32) diff --git a/aria/model.py b/aria/model.py index b4753f7d..573f5484 100644 --- a/aria/model.py +++ b/aria/model.py @@ -1,4 +1,4 @@ -"""Training implementation.""" +"""Training model implementation.""" from dataclasses import dataclass from typing import Optional @@ -19,20 +19,23 @@ class ModelConfig: drop_p: float max_seq_len: int grad_checkpoint: bool + resid_dropout: float = 0.0 vocab_size: Optional[int] = None + class_size: Optional[int] = None + emb_size: Optional[dict] = None def set_vocab_size(self, vocab_size: int): self.vocab_size = vocab_size class FusedEncoderBlock(nn.Module): - def __init__(self, model_config: 
ModelConfig): + def __init__(self, model_config: ModelConfig, resid_dropout: float = 0.0): super().__init__() - self.drop_p = model_config.drop_p self.n_heads = model_config.n_heads self.d_head = model_config.d_model // model_config.n_heads self.max_seq_len = model_config.max_seq_len + self.resid_dropout = resid_dropout # Attention self.mixed_qkv = nn.Linear( @@ -68,8 +71,11 @@ def __init__(self, model_config: ModelConfig): self.norm2 = nn.LayerNorm(model_config.d_model) def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): - x = x + self._att_block(self.norm1(x), freqs_cis) - x = x + self._ff_block(self.norm2(x)) + att_out = self._att_block(self.norm1(x), freqs_cis) + x = x + F.dropout(att_out, p=self.resid_dropout, training=self.training) + + ff_out = self._ff_block(self.norm2(x)) + x = x + F.dropout(ff_out, p=self.resid_dropout, training=self.training) return x @@ -115,10 +121,10 @@ def _ff_block(self, x: torch.Tensor): class Transformer(nn.Module): - """Transformer decoder with no language model head. + """Transformer decoder without a language model head. Args: - model_config (ModelConfig): Model config settings. + model_config (ModelConfig): Model configuration settings. """ def __init__(self, model_config: ModelConfig): @@ -133,29 +139,40 @@ def __init__(self, model_config: ModelConfig): self.out_layer_norm = nn.LayerNorm(model_config.d_model) self.encode_layers = nn.ModuleList() - for _ in range(model_config.n_layers): - self.encode_layers.append(FusedEncoderBlock(model_config)) + + for layer_index in range(model_config.n_layers): + if model_config.resid_dropout > 0: + layer_dropout = model_config.resid_dropout * ( + layer_index / (model_config.n_layers - 1) + ) + else: + layer_dropout = 0.0 + + self.encode_layers.append( + FusedEncoderBlock(model_config, resid_dropout=layer_dropout) + ) def forward( self, src: torch.Tensor, + emb: torch.Tensor | None = None, ): - """Forward pass of Transformer. + """Perform a forward pass through the transformer. Args: - src (torch.tensor): Input to encoder block, of shape (batch_size, - seq_len, d_model). - attn_mask (Optional[torch.tensor]): Attention mask of shape - (batch_size, seq_len). Defaults to None. - past_kv (Optional[list[KVCache]]): a list of kv caches. The list index - corresponds to the layer index. + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). + emb (Optional[torch.Tensor]): Optional extra embedding with shape (batch_size, emb_dim). Returns: - torch.tensor: Model outputs with shape (batch_size, seq_len, - d_model). + torch.Tensor: Output tensor with shape (batch_size, seq_len, d_model). """ + hidden_states = self.tok_embeddings(src) + if emb is not None: + emb = emb[:, None, :] + hidden_states = torch.cat([emb, hidden_states[:, :-1, :]], dim=1) + if self.freqs_cis is None: self.freqs_cis = precompute_freqs_cis( seq_len=self.model_config.max_seq_len, @@ -165,7 +182,7 @@ def forward( ).to(src.device) freqs_cis = self.freqs_cis[: src.shape[1]] - if self.model_config.grad_checkpoint is True: + if self.model_config.grad_checkpoint is True and self.training: for layer in self.encode_layers: def create_custom_forward(module): @@ -189,14 +206,15 @@ def custom_forward(*args): class TransformerLM(nn.Module): - """Transformer decoder with head for language modelling. + """Transformer decoder with a language modeling head. Args: - model_config (ModelConfig): Model config settings. + model_config (ModelConfig): Model configuration settings (vocab_size must be defined). 
""" def __init__(self, model_config: ModelConfig): super().__init__() + assert model_config.vocab_size is not None self.max_seq_len = model_config.max_seq_len self.model = Transformer(model_config) @@ -208,26 +226,155 @@ def forward( self, src: torch.Tensor, ): - """Forward pass of Transformer decoder with LM head. + """Compute language modeling logits. Args: - src (torch.tensor): Input to encoder block, of shape (batch_size, - seq_len, d_model). - attn_mask (Optional[torch.tensor]): Attention mask of shape - (batch_size, seq_len). Defaults to None. - past_kv (Optional[list[KVCache]]): a list of kv caches. The list index - corresponds to the layer index. + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). Returns: - torch.tensor: Forward pass of src through Transformer and LM head. - Has shape (batch_size, seq_len, vocab_size). + torch.Tensor: Logits with shape (batch_size, seq_len, vocab_size). """ + hidden = self.model(src) logits = self.lm_head(hidden) return logits +class TransformerCL(nn.Module): + """Transformer decoder with a classification head. + + Args: + model_config (ModelConfig): Model configuration settings (class_size must be defined). + """ + + def __init__(self, model_config: ModelConfig): + super().__init__() + assert model_config.class_size is not None + + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) + self.class_head = nn.Linear( + model_config.d_model, model_config.class_size, bias=False + ) + + def forward( + self, + src: torch.Tensor, + ): + """Compute classification logits. + + Args: + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). + + Returns: + torch.Tensor: Classification logits with shape (batch_size, seq_len, class_size). + """ + + hidden = self.model(src) + logits = self.class_head(hidden) + + return logits + + +class TransformerLM_CND(nn.Module): + """Transformer decoder with a language modeling head and optional conditioning. + + Args: + model_config (ModelConfig): Model configuration settings (vocab_size and emb_size must be defined). + """ + + def __init__(self, model_config: ModelConfig): + super().__init__() + assert model_config.vocab_size is not None + + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) + self.lm_head = nn.Linear( + model_config.d_model, model_config.vocab_size, bias=False + ) + self.embedding_adapter = nn.Linear( + model_config.emb_size, model_config.d_model, bias=False + ) + + def forward( + self, + src: torch.Tensor, + emb: torch.Tensor | None = None, + ): + """Compute language modeling logits with optional conditioning. + + Args: + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). + emb (Optional[torch.Tensor]): Optional conditioning embedding with shape (batch_size, emb_size). + + Returns: + torch.Tensor: Logits with shape (batch_size, seq_len, vocab_size). + Note that if the emb is provided, the seq_len will be seq_len -1. + + """ + + if emb is not None: + # Embedding is prepended to sequence via the adapter. We slice the + # logits so that the logits format still matches src. 
+ emb = self.embedding_adapter(emb) + hidden = self.model(src, emb) + logits = self.lm_head(hidden) + + return logits[:, 1:, :] + else: + # Dummy adapter pass so torch DDP sees all parameters used (avoids unused-parameter errors) + dummy_input = torch.zeros( + src.size(0), + self.embedding_adapter.in_features, + device=src.device, + ) + dummy_output = self.embedding_adapter(dummy_input) + dummy_loss = dummy_output.sum() * 0.0 + + hidden = self.model(src, None) + logits = self.lm_head(hidden) + logits = logits + dummy_loss + + return logits + + +class TransformerEMB(nn.Module): + """Transformer decoder with an embedding head. + + Args: + model_config (ModelConfig): Model configuration settings (emb_size must be defined). + """ + + def __init__(self, model_config: ModelConfig): + super().__init__() + assert model_config.emb_size is not None + + self.max_seq_len = model_config.max_seq_len + self.model = Transformer(model_config) + self.emb_head = nn.Linear( + model_config.d_model, model_config.emb_size, bias=False + ) + + def forward( + self, + src: torch.Tensor, + ): + """Compute output embeddings from the transformer. + + Args: + src (torch.Tensor): Input tensor of token indices with shape (batch_size, seq_len). + + Returns: + torch.Tensor: Output embeddings with shape (batch_size, seq_len, emb_size). + """ + + hidden = self.model(src) + emb = self.emb_head(hidden) + + return emb + + def precompute_freqs_cis( seq_len: int, n_elem: int, diff --git a/aria/run.py b/aria/run.py index 182d1f53..71c319b7 100644 --- a/aria/run.py +++ b/aria/run.py @@ -2,220 +2,455 @@ import argparse import os -import re +import json import sys -def _parse_sample_args(): - argp = argparse.ArgumentParser(prog="aria sample") - argp.add_argument("-m", help="name of model config file") - argp.add_argument("-c", help="path to model checkpoint") - argp.add_argument("-p", help="path to midi file") +def _parse_generate_args(): + argp = argparse.ArgumentParser(prog="aria generate") argp.add_argument( - "-temp", + "--backend", + choices=["torch_cuda", "mlx"], + default="torch_cuda", + help="backend for inference", + ) + argp.add_argument( + "--checkpoint_path", + help="path to model used for decoding", + required=True, + ) + argp.add_argument( + "--prompt_midi_path", + help="path to midi file", + required=True, + ) + argp.add_argument( + "--prompt_duration", + help="length of the input MIDI prompt, in seconds", + type=int, + default=15, + ) + argp.add_argument( + "--variations", + help="number of variations to generate", + type=int, + default=1, + ) + argp.add_argument( + "--temp", help="sampling temperature value", type=float, required=False, - default=0.95, + default=0.98, ) argp.add_argument( - "-top_p", - help="sampling top_p value", + "--min_p", + help="sampling min_p value", type=float, + default=0.035, required=False, - default=0.95, ) argp.add_argument( - "-cfg", - help="sampling cfg gamma value", + "--top_p", + help="sampling top_p value", type=float, required=False, ) argp.add_argument( - "-metadata", - nargs=2, - metavar=("KEY", "VALUE"), - action="append", - help="manually add metadata key-value pair when sampling", + "--end", action="store_true", help="generate ending for piece" ) argp.add_argument( - "-var", - help="number of variations", + "--length", type=int, - default=1, + help="number of tokens to generate per variation", + default=2048, + ) + argp.add_argument( + "--compile", + action="store_true", + help="use torch compiler to generate cudagraph for inference", + ) + argp.add_argument( + "--save_dir", + type=str, + default=".", + help="directory to save generated
MIDI files", + ) + + return argp.parse_args(sys.argv[2:]) + + +def _parse_conditioned_generate_args(): + argp = argparse.ArgumentParser(prog="aria generate") + argp.add_argument( + "--backend", + choices=["torch_cuda", "mlx"], + default="torch_cuda", + help="backend for inference", + ) + argp.add_argument( + "--checkpoint_path", + help="path to model used for decoding", + required=True, ) argp.add_argument( - "-trunc", - help="length (in seconds) of the prompt", + "--prompt_midi_path", + help="path to midi file", + required=True, + ) + argp.add_argument( + "--prompt_duration", + help="length of the input MIDI prompt, in seconds", type=int, - default=20, + default=15, ) - argp.add_argument("-e", action="store_true", help="enable force end") - argp.add_argument("-l", type=int, help="generation length", default=1024) argp.add_argument( - "-guidance_path", type=str, help="path to guidance MIDI", required=False + "--embedding_model_checkpoint_path", + help="path to model checkpoint used for embeddings", + required=True, ) argp.add_argument( - "-guidance_start_ms", - help="guidance interval start (ms)", + "--embedding_midi_path", + help="path to MIDI file used for conditioning", + required=True, + ) + argp.add_argument( + "--variations", + help="number of variations to generate", type=int, + default=1, + ) + argp.add_argument( + "--temp", + help="sampling temperature value", + type=float, required=False, + default=0.98, ) argp.add_argument( - "-guidance_end_ms", - help="guidance interval end (ms)", - type=int, + "--cfg", + help="sampling cfg gamma value", + type=float, + default=1.0, + ) + argp.add_argument( + "--min_p", + help="sampling min_p value", + type=float, + default=0.035, required=False, ) - argp.add_argument("-compile", action="store_true", help="compile cudagraph") + argp.add_argument( + "--top_p", + help="sampling top_p value", + type=float, + required=False, + ) + argp.add_argument( + "--end", action="store_true", help="generate ending for piece" + ) + argp.add_argument( + "--length", + type=int, + help="number of tokens to generate per variation", + default=2048, + ) + argp.add_argument( + "--compile", + action="store_true", + help="use torch compiler to generate cudagraph for inference", + ) + argp.add_argument( + "--save_dir", + type=str, + default=".", + help="directory to save generated MIDI files", + ) return argp.parse_args(sys.argv[2:]) -def sample(args): - """Entrypoint for sampling""" +def _get_prompt( + midi_path: str, + prompt_duration_s: int, +): + from ariautils.midi import MidiDict + from ariautils.tokenizer import AbsTokenizer + from aria.inference import get_inference_prompt - from torch.cuda import is_available as cuda_is_available - from aria.inference import TransformerLM + return get_inference_prompt( + midi_dict=MidiDict.from_midi(midi_path), + tokenizer=AbsTokenizer(), + prompt_len_ms=1e3 * prompt_duration_s, + ) + + +def _load_embedding_model(checkpoint_path: str): + from safetensors.torch import load_file + + from ariautils.tokenizer import AbsTokenizer + from aria.model import TransformerEMB, ModelConfig + from aria.config import load_model_config + + model_config = ModelConfig(**load_model_config(name="medium-emb")) + model_config.set_vocab_size(AbsTokenizer().vocab_size) + model = TransformerEMB(model_config) + + state_dict = load_file(filename=checkpoint_path) + model.load_state_dict(state_dict=state_dict, strict=True) + + return model + + +def _load_inference_model_torch( + checkpoint_path: str, + config_name: str, + strict: bool = True, +): + from 
safetensors.torch import load_file + + from ariautils.tokenizer import AbsTokenizer + from aria.inference.model_cuda import TransformerLM + from aria.model import ModelConfig + from aria.config import load_model_config + + model_config = ModelConfig(**load_model_config(name=config_name)) + model_config.set_vocab_size(AbsTokenizer().vocab_size) + model = TransformerLM(model_config) + + state_dict = load_file(filename=checkpoint_path) + model.load_state_dict(state_dict=state_dict, strict=strict) + + return model + + +def _load_inference_model_mlx( + checkpoint_path: str, + config_name: str, + strict: bool = True, +): + import mlx.core as mx + + from ariautils.tokenizer import AbsTokenizer + from aria.inference.model_mlx import TransformerLM from aria.model import ModelConfig - from aria.config import load_model_config, load_config - from aria.tokenizer import InferenceAbsTokenizer - from aria.sample import ( - sample_batch_cfg, - sample_batch, - get_inference_prompt, + from aria.config import load_model_config + + model_config = ModelConfig(**load_model_config(name=config_name)) + model_config.set_vocab_size(AbsTokenizer().vocab_size) + model = TransformerLM(model_config) + model.load_weights(checkpoint_path, strict=strict) + mx.eval(model.parameters()) + + return model + + +def generate(args): + from ariautils.tokenizer import AbsTokenizer + + num_variations = args.variations + prompt_duration_s = args.prompt_duration + backend = args.backend + max_new_tokens = args.length + + assert num_variations > 0 + assert prompt_duration_s >= 0 + assert max_new_tokens > 0 + assert os.path.isdir(args.save_dir) + + tokenizer = AbsTokenizer() + prompt = _get_prompt( + args.prompt_midi_path, + prompt_duration_s=prompt_duration_s, ) - from ariautils.midi import MidiDict - from aria.utils import _load_weight + max_new_tokens = min(8096 - len(prompt), max_new_tokens) - if not cuda_is_available(): - raise Exception("CUDA device is not available.") + if backend == "torch_cuda": + from torch.cuda import is_available + from aria.inference.sample_cuda import sample_batch as sample_batch_t - model_state = _load_weight(args.c, "cuda") - model_state = { - k.replace("_orig_mod.", ""): v for k, v in model_state.items() - } + assert is_available(), "CUDA not available" - manual_metadata = {k: v for k, v in args.metadata} if args.metadata else {} - valid_metadata = load_config()["data"]["metadata"]["manual"] - for k, v in manual_metadata.copy().items(): - assert k in valid_metadata.keys(), f"{manual_metadata} is invalid" - if v not in valid_metadata[k]: - print(f"Ignoring invalid manual metadata: {k}") - print(f"Please choose from {valid_metadata[k]}") - del manual_metadata[k] - - num_variations = args.var - truncate_len = args.trunc - force_end = args.e - model_name = args.m - - tokenizer = InferenceAbsTokenizer() - model_config = ModelConfig(**load_model_config(model_name)) - model_config.set_vocab_size(tokenizer.vocab_size) - model_config.grad_checkpoint = False - model = TransformerLM(model_config).cuda() - - try: - model.load_state_dict(model_state) - except Exception as e: - print( - "Failed to load model_state. This is likely due to an incompatibility " - "between the checkpoint file (-c) and model name/config (-m)." 
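+ # torch_cuda path: load the LM with the "medium" config; strict=False tolerates key mismatches between the checkpoint and the inference wrapper.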
+ model = _load_inference_model_torch( + checkpoint_path=args.checkpoint_path, + config_name="medium", + strict=False, + ) + results = sample_batch_t( + model=model, + tokenizer=tokenizer, + prompt=prompt, + num_variations=num_variations, + max_new_tokens=max_new_tokens, + temp=args.temp, + force_end=args.end, + top_p=args.top_p, + min_p=args.min_p, + compile=args.compile, ) - raise e + elif backend == "mlx": + from aria.inference.sample_mlx import sample_batch as sample_batch_mlx - assert args.l > 0, "Generation length must be positive." - max_new_tokens = args.l + model = _load_inference_model_mlx( + checkpoint_path=args.checkpoint_path, + config_name="medium", + strict=False, + ) + results = sample_batch_mlx( + model=model, + tokenizer=tokenizer, + prompt=prompt, + num_variations=num_variations, + max_new_tokens=max_new_tokens, + temp=args.temp, + force_end=args.end, + top_p=args.top_p, + min_p=args.min_p, + ) - # Load and format prompts and metadata - midi_dict = MidiDict.from_midi(mid_path=args.p) - if args.guidance_path: - guidance_midi_dict = MidiDict.from_midi(mid_path=args.guidance_path) - else: - guidance_midi_dict = None + for idx, tokenized_seq in enumerate(results): + res_midi_dict = tokenizer.detokenize(tokenized_seq) + res_midi = res_midi_dict.to_midi() + res_midi.save(os.path.join(args.save_dir, f"res_{idx + 1}.mid")) - for k, v in manual_metadata.items(): - midi_dict.metadata[k] = v + print(f"Results saved to {os.path.realpath(args.save_dir)}") - print(f"Extracted metadata: {midi_dict.metadata}") - print( - f"Instruments: {set([MidiDict.get_program_to_instrument()[msg['data']] for msg in midi_dict.instrument_msgs])}" + +def _get_embedding( + embedding_model_checkpoints_path: str, + embedding_midi_path: str, +): + from aria.embedding import get_global_embedding_from_midi + + model = _load_embedding_model( + checkpoint_path=embedding_model_checkpoints_path + ).cpu() + global_embedding = get_global_embedding_from_midi( + model=model, + midi_path=embedding_midi_path, + device="cpu", ) - prompt_seq, guidance_seq = get_inference_prompt( - tokenizer=tokenizer, - midi_dict=midi_dict, - truncate_len=truncate_len, - guidance_start_ms=args.guidance_start_ms, - guidance_end_ms=args.guidance_end_ms, - guidance_midi_dict=guidance_midi_dict, + return global_embedding.tolist() + + +def conditioned_generate(args): + from ariautils.tokenizer import AbsTokenizer + + num_variations = args.variations + prompt_duration_s = args.prompt_duration + backend = args.backend + max_new_tokens = args.length + + assert num_variations > 0 + assert prompt_duration_s >= 0 + assert max_new_tokens > 0 + assert os.path.isdir(args.save_dir) + + tokenizer = AbsTokenizer() + prompt = _get_prompt( + args.prompt_midi_path, + prompt_duration_s=prompt_duration_s, + ) + embedding = _get_embedding( + embedding_model_checkpoints_path=args.embedding_model_checkpoint_path, + embedding_midi_path=args.embedding_midi_path, ) + max_new_tokens = min(8096 - len(prompt), max_new_tokens) - if len(prompt_seq) + args.l > model_config.max_seq_len: - print( - "WARNING: Required context exceeds max_seq_len supported by model" + if backend == "torch_cuda": + from torch.cuda import is_available + from aria.inference.sample_cuda import ( + sample_batch_cfg as sample_batch_cfg_t, ) - prompts = [prompt_seq for _ in range(num_variations)] - - samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples") - if os.path.isdir(samples_dir) is False: - os.mkdir(samples_dir) - if guidance_seq: - 
tokenizer.detokenize(guidance_seq).to_midi().save( - os.path.join(samples_dir, f"guidance.mid") + + assert is_available(), "CUDA not available" + + model = _load_inference_model_torch( + checkpoint_path=args.checkpoint_path, + config_name="medium-emb", + strict=True, ) - if args.cfg is not None and guidance_seq is not None: - results = sample_batch_cfg( + results = sample_batch_cfg_t( model=model, tokenizer=tokenizer, - prompts=prompts, + prompt=prompt, + num_variations=num_variations, max_new_tokens=max_new_tokens, cfg_gamma=args.cfg, - force_end=force_end, - temperature=args.temp, + embedding=embedding, + temp=args.temp, + force_end=args.end, top_p=args.top_p, + min_p=args.min_p, compile=args.compile, ) - else: - results = sample_batch( + + elif backend == "mlx": + from aria.inference.sample_mlx import ( + sample_batch_cfg as sample_batch_cfg_mlx, + ) + + model = _load_inference_model_mlx( + checkpoint_path=args.checkpoint_path, + config_name="medium-emb", + strict=True, + ) + results = sample_batch_cfg_mlx( model=model, tokenizer=tokenizer, - prompts=prompts, + prompt=prompt, + num_variations=num_variations, max_new_tokens=max_new_tokens, - force_end=force_end, - temperature=args.temp, + cfg_gamma=args.cfg, + embedding=embedding, + temp=args.temp, + force_end=args.end, top_p=args.top_p, - compile=args.compile, + min_p=args.min_p, ) for idx, tokenized_seq in enumerate(results): res_midi_dict = tokenizer.detokenize(tokenized_seq) res_midi = res_midi_dict.to_midi() - res_midi.save(os.path.join(samples_dir, f"res_{idx + 1}.mid")) + res_midi.save(os.path.join(args.save_dir, f"res_{idx + 1}.mid")) - print("Results saved to samples/") + print(f"Results saved to {os.path.realpath(args.save_dir)}") def _parse_midi_dataset_args(): argp = argparse.ArgumentParser(prog="aria midi-dataset") - argp.add_argument("dir", help="directory containing midi files") - argp.add_argument("save_path", help="path to save dataset") - argp.add_argument("-r", action="store_true", help="recursively search dirs") argp.add_argument( - "-s", action="store_true", help="shuffle dataset", default=False + "dir", + help="directory containing midi files", ) argp.add_argument( - "-metadata", + "save_path", + help="path to save dataset", + ) + argp.add_argument( + "--recursive", + action="store_true", + help="recursively search dirs", + ) + argp.add_argument( + "--shuffle", + action="store_true", + help="shuffle dataset", + ) + argp.add_argument( + "--split", + type=float, + help="create train/val split", + required=False, + ) + argp.add_argument( + "--metadata", nargs=2, metavar=("KEY", "VALUE"), action="append", help="manually add metadata key-value pair when building dataset", ) - argp.add_argument( - "-split", type=float, help="create train/val split", required=False - ) - return argp.parse_args(sys.argv[2:]) @@ -228,10 +463,10 @@ def build_midi_dataset(args): MidiDataset.build_to_file( dir=args.dir, save_path=args.save_path, - recur=args.r, + recur=args.recursive, overwrite=True, manual_metadata=manual_metadata, - shuffle=args.s, + shuffle=args.shuffle, ) if args.split: @@ -243,15 +478,47 @@ def build_midi_dataset(args): ) +# TODO: Add turn - to -- flags def _parse_pretrain_dataset_args(): argp = argparse.ArgumentParser(prog="aria pretrain-dataset") - argp.add_argument("-load_path", help="path midi_dict dataset") - argp.add_argument("-save_dir", help="path to save dataset") argp.add_argument( - "-tokenizer_name", help="tokenizer name", choices=["abs", "rel"] + "--load_path", + help="path midi_dict dataset", + required=True, 
+ ) + argp.add_argument( + "--save_dir", + help="path to save dataset", + required=True, + ) + argp.add_argument( + "--tokenizer_name", + help="tokenizer name", + choices=["abs", "rel"], + required=True, + ) + argp.add_argument( + "--seq_len", + help="sequence length (tokens)", + type=int, + default=4096, + ) + argp.add_argument( + "--num_epochs", + help="number of epochs to build", + type=int, + default=1, + ) + argp.add_argument( + "--sep_sequences", + help="start each with a new entry", + action="store_true", + ) + argp.add_argument( + "--embedding_dataset_path", + help="path to embedding dataset - same format as EvaluationDataset", + required=False, ) - argp.add_argument("-l", help="max sequence length", type=int, default=4096) - argp.add_argument("-e", help="num epochs", type=int, default=1) return argp.parse_args(sys.argv[2:]) @@ -265,39 +532,24 @@ def build_pretraining_dataset(args): elif args.tokenizer_name == "rel": tokenizer = RelTokenizer() - PretrainingDataset.build( - tokenizer=tokenizer, - save_dir=args.save_dir, - max_seq_len=args.l, - num_epochs=args.e, - midi_dataset_path=args.load_path, - ) - - -def _parse_finetune_dataset_args(): - argp = argparse.ArgumentParser(prog="aria finetune-dataset") - argp.add_argument( - "-midi_dataset_path", - help="path to midi_dict dataset", - ) - argp.add_argument("-save_dir", help="path to save dataset") - argp.add_argument("-l", help="max sequence length", type=int, default=4096) - argp.add_argument("-e", help="num epochs", type=int, default=1) - - return argp.parse_args(sys.argv[2:]) - + if args.embedding_dataset_path is not None: + with open(args.embedding_dataset_path, "r") as f: + file_embeddings = { + data["metadata"]["abs_load_path"]: data["emb"] + for data in map(json.loads, f) + } -def build_finetune_dataset(args): - from aria.tokenizer import InferenceAbsTokenizer - from aria.datasets import FinetuningDataset + else: + file_embeddings = None - tokenizer = InferenceAbsTokenizer() - FinetuningDataset.build( + PretrainingDataset.build( tokenizer=tokenizer, save_dir=args.save_dir, - max_seq_len=args.l, - num_epochs=args.e, - midi_dataset_path=args.midi_dataset_path, + max_seq_len=args.seq_len, + num_epochs=args.num_epochs, + midi_dataset_path=args.load_path, + separate_sequences=args.sep_sequences, + file_embeddings=file_embeddings, ) @@ -308,29 +560,27 @@ def main(): "command", help="command to run", choices=( - "sample", + "generate", + "conditioned-generate", "midi-dataset", "pretrain-dataset", - "finetune-dataset", ), ) - # parse_args defaults to [1:] for args, but you need to - # exclude the rest of the args too, or validation will fail args = parser.parse_args(sys.argv[1:2]) if not hasattr(args, "command"): parser.print_help() print("Unrecognized command") exit(1) - elif args.command == "sample": - sample(args=_parse_sample_args()) + elif args.command == "generate": + generate(args=_parse_generate_args()) + elif args.command == "conditioned-generate": + conditioned_generate(args=_parse_conditioned_generate_args()) elif args.command == "midi-dataset": build_midi_dataset(args=_parse_midi_dataset_args()) elif args.command == "pretrain-dataset": build_pretraining_dataset(args=_parse_pretrain_dataset_args()) - elif args.command == "finetune-dataset": - build_finetune_dataset(args=_parse_finetune_dataset_args()) else: print("Unrecognized command") parser.print_help() diff --git a/aria/sample.py b/aria/sample.py deleted file mode 100644 index 01f9abea..00000000 --- a/aria/sample.py +++ /dev/null @@ -1,405 +0,0 @@ -"""Contains 
generation/sampling code""" - -import copy -import torch -import torch._dynamo.config -import torch._inductor.config - -from typing import List -from tqdm import tqdm - -from aria.inference import TransformerLM -from aria.tokenizer import InferenceAbsTokenizer -from ariautils.tokenizer import Tokenizer, AbsTokenizer -from ariautils.midi import MidiDict - -torch._inductor.config.coordinate_descent_tuning = True -torch._inductor.config.triton.unique_kernel_names = True -torch._inductor.config.fx_graph_cache = True - - -def get_cfg_prompt(prompts: list, pad_tok: str, guidance_end_tok: str): - cfg_prompts = [] - for prompt in prompts: - prompt_no_guidance = prompt[prompt.index(guidance_end_tok) + 1 :] - prompt_no_guidance = [pad_tok] * ( - len(prompt) - len(prompt_no_guidance) - ) + prompt_no_guidance - cfg_prompts.append(prompt) - cfg_prompts.append(prompt_no_guidance) - - return cfg_prompts - - -@torch.inference_mode() -def decode_one( - model: TransformerLM, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, -): - logits = model.forward( - idxs=idxs, - input_pos=input_pos, - pad_idxs=pad_idxs, - )[:, -1] - - return logits - - -@torch.inference_mode() -def prefill( - model: TransformerLM, - idxs: torch.Tensor, - input_pos: torch.Tensor, - pad_idxs: torch.Tensor | None = None, -): - logits = model.forward(idxs=idxs, input_pos=input_pos, pad_idxs=pad_idxs)[ - :, -1 - ] - - return logits - - -def update_seq_ids_( - seq: torch.Tensor, - idx: int, - next_token_ids: torch.Tensor, - dim_tok_inserted: list, - eos_tok_seen: list, - max_len: int, - force_end: bool, - tokenizer: Tokenizer, -): - # Insert dim and pad toks - for _idx in range(seq.shape[0]): - if eos_tok_seen[_idx] == True: - next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.pad_tok] - elif ( - force_end - and idx >= max_len - 130 - and dim_tok_inserted[_idx] is False - and tokenizer.id_to_tok[next_token_ids[_idx].item()][0] - not in ("dur", "onset") - ): - next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok] - - # Update dim_tok_inserted and eos_tok_seen - if next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: - dim_tok_inserted[_idx] = True - elif next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.eos_tok]: - eos_tok_seen[_idx] = True - - seq[:, idx] = next_token_ids - - -# TODO: Not working -@torch.autocast( - "cuda", - dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, -) -@torch.inference_mode() -def sample_batch( - model: TransformerLM, - tokenizer: Tokenizer, - prompts: List[list], - max_new_tokens: int, - force_end=False, - temperature: float = 0.95, - top_p: float = 0.95, - compile: bool = False, -): - if force_end: - assert max_new_tokens > 130, "prompt too long to use force_end=True" - - _prompt_len = len(prompts[0]) - _num_prompts = len(prompts) - assert all([len(p) == _prompt_len for p in prompts]) - - model.eval() - dim_tok_inserted = [False for _ in range(_num_prompts)] - eos_tok_seen = [False for _ in range(_num_prompts)] - total_len = _prompt_len + max_new_tokens - seq = torch.stack( - [ - torch.tensor( - tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) - ) - for p in prompts - ] - ).cuda() - - if compile is True: - global decode_one - decode_one = torch.compile( - decode_one, - mode="reduce-overhead", - fullgraph=True, - ) - - model.setup_cache( - batch_size=_num_prompts, - max_seq_len=total_len, - dtype=( - torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - ), - ) - - print( - f"Using hyperparams: 
temp={temperature}, top_p={top_p}, gen_len={max_new_tokens}" - ) - - for idx in ( - pbar := tqdm( - range(_prompt_len, total_len), - total=total_len - _prompt_len, - leave=False, - ) - ): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - if idx == _prompt_len: - logits = prefill( - model, - idxs=seq[:, :idx], - input_pos=torch.arange(0, idx, device=seq.device), - ) - else: - logits = decode_one( - model, - idxs=seq[:, idx - 1 : idx], - input_pos=torch.tensor( - [idx - 1], device=seq.device, dtype=torch.int - ), - ) - - if tokenizer.name == "inference_abs": - logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( - "-inf" - ) - - if temperature > 0.0: - probs = torch.softmax(logits / temperature, dim=-1) - next_token_ids = sample_top_p(probs, top_p).flatten() - else: - next_token_ids = torch.argmax(logits, dim=-1).flatten() - - update_seq_ids_( - seq=seq, - idx=idx, - next_token_ids=next_token_ids, - dim_tok_inserted=dim_tok_inserted, - eos_tok_seen=eos_tok_seen, - max_len=total_len, - force_end=force_end, - tokenizer=tokenizer, - ) - - if all(seen_eos is True for seen_eos in eos_tok_seen): - break - - decoded_results = [tokenizer.decode(s) for s in seq.tolist()] - decoded_results = [ - ( - res[: res.index(tokenizer.eos_tok) + 1] - if tokenizer.eos_tok in res - else res - ) - for res in decoded_results - ] - - return decoded_results - - -@torch.autocast( - "cuda", - dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, -) -@torch.inference_mode() -def sample_batch_cfg( - model: TransformerLM, - tokenizer: InferenceAbsTokenizer, - prompts: List[list], - max_new_tokens: int, - cfg_gamma: float, - force_end=False, - temperature: float = 0.95, - top_p: float = 0.95, - compile: bool = False, -): - assert 0.0 <= cfg_gamma <= 2.0 - assert 0.0 <= temperature <= 2.0 - assert 0.5 <= top_p <= 1.0 - assert tokenizer.name == "inference_abs" - if force_end: - assert max_new_tokens > 130, "prompt too long to use force_end=True" - - prompts = get_cfg_prompt( - prompts, tokenizer.pad_tok, tokenizer.guidance_end_tok - ) - - _prompt_len = len(prompts[0]) - _num_prompts = len(prompts) - assert all([len(p) == _prompt_len for p in prompts]) - - model.eval() - total_len = _prompt_len + max_new_tokens - seq = torch.stack( - [ - torch.tensor( - tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) - ) - for p in prompts - ] - ).cuda() - dim_tok_inserted = [False for _ in range(_num_prompts)] - eos_tok_seen = [False for _ in range(_num_prompts)] - - if compile is True: - global decode_one - decode_one = torch.compile( - decode_one, - mode="reduce-overhead", - fullgraph=True, - ) - - model.setup_cache( - batch_size=_num_prompts, - max_seq_len=total_len, - dtype=( - torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - ), - ) - - print( - f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" - ) - - for idx in ( - pbar := tqdm( - range(_prompt_len, total_len), - total=total_len - _prompt_len, - leave=False, - ) - ): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - if idx == _prompt_len: - logits = prefill( - model, - idxs=seq[:, :idx], - input_pos=torch.arange(0, idx, device=seq.device), - pad_idxs=(seq == tokenizer.pad_id), - ) - else: - logits = decode_one( - model, - idxs=seq[:, idx - 1 : idx], - input_pos=torch.tensor( - [idx - 1], device=seq.device, dtype=torch.int - ), - pad_idxs=(seq == tokenizer.pad_id), - ) - - logits_cfg = cfg_gamma * logits[::2] + 
(1 - cfg_gamma) * logits[1::2] - logits_cfg[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( - "-inf" - ) - logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") - - if temperature > 0.0: - probs = torch.softmax(logits_cfg / temperature, dim=-1) - next_token_ids = sample_top_p(probs, top_p).flatten() - else: - next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten() - - next_token_ids = next_token_ids.repeat_interleave(2) - update_seq_ids_( - seq=seq, - idx=idx, - next_token_ids=next_token_ids, - dim_tok_inserted=dim_tok_inserted, - eos_tok_seen=eos_tok_seen, - max_len=total_len, - force_end=force_end, - tokenizer=tokenizer, - ) - - if all(seen_eos is True for seen_eos in eos_tok_seen): - break - - decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2] - decoded_results = [ - ( - res[: res.index(tokenizer.eos_tok) + 1] - if tokenizer.eos_tok in res - else res - ) - for res in decoded_results - ] - - return decoded_results - - -def sample_top_p(probs, p): - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - - -def get_inference_prompt( - tokenizer: InferenceAbsTokenizer, - midi_dict: MidiDict, - truncate_len: int, - guidance_start_ms: int, - guidance_end_ms: int, - guidance_midi_dict: MidiDict | None = None, -): - assert tokenizer.name == "inference_abs" - - if guidance_midi_dict is not None: - assert guidance_start_ms is not None and guidance_start_ms >= 0 - assert guidance_end_ms is not None and guidance_end_ms >= 0 - assert ( - tokenizer._config["guidance"]["min_ms"] - <= guidance_end_ms - guidance_start_ms - <= tokenizer._config["guidance"]["max_ms"] - ) - - prompt_seq = tokenizer.tokenize( - midi_dict=midi_dict, - prompt_intervals_ms=( - [[0, truncate_len * 1e3]] if truncate_len > 0 else [] - ), - guidance_midi_dict=guidance_midi_dict, - guidance_start_ms=guidance_start_ms, - guidance_end_ms=guidance_end_ms, - ) - - if tokenizer.prompt_end_tok in prompt_seq: - prompt_seq = prompt_seq[ - : prompt_seq.index(tokenizer.prompt_end_tok) + 1 - ] - else: - print("No notes found in prompt region") - prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.bos_tok) + 1] - - if tokenizer.dim_tok in prompt_seq: - prompt_seq.remove(tokenizer.dim_tok) - - if ( - guidance_midi_dict is not None - and tokenizer.guidance_start_tok in prompt_seq - ): - guidance_seq = copy.deepcopy(prompt_seq) - guidance_seq = guidance_seq[ - : guidance_seq.index(tokenizer.guidance_end_tok) - ] - guidance_seq[0] = ("prefix", "instrument", "piano") - else: - guidance_seq = None - - return prompt_seq, guidance_seq diff --git a/aria/tokenizer.py b/aria/tokenizer.py deleted file mode 100644 index c142405d..00000000 --- a/aria/tokenizer.py +++ /dev/null @@ -1,303 +0,0 @@ -"""Tokenizer for MIDI conditioned completions""" - -import copy -import random -import functools - -from typing import Callable - -from aria.config import load_config -from ariautils.midi import MidiDict -from ariautils.tokenizer import AbsTokenizer as _AbsTokenizer - - -class InferenceAbsTokenizer(_AbsTokenizer): - def __init__(self): - super().__init__() - - self.name = "inference_abs" - self._config = load_config()["tokenizer"]["inference_abs"] - - self.prompt_start_tok = "" - self.prompt_end_tok = "" - 
self.guidance_start_tok = "" - self.guidance_end_tok = "" - - self.add_tokens_to_vocab( - [ - self.prompt_start_tok, - self.prompt_end_tok, - self.guidance_start_tok, - self.guidance_end_tok, - ] - ) - self.special_tokens.append(self.prompt_start_tok) - self.special_tokens.append(self.prompt_end_tok) - self.special_tokens.append(self.guidance_start_tok) - self.special_tokens.append(self.guidance_end_tok) - - def _get_guidance_interval_ms(self, guidance_midi_dict: MidiDict): - first_note_onset_ms = guidance_midi_dict.tick_to_ms( - guidance_midi_dict.note_msgs[0]["tick"] - ) - last_note_onset_ms = guidance_midi_dict.tick_to_ms( - guidance_midi_dict.note_msgs[-1]["tick"] - ) - guidance_segment_length_ms = random.randint( - self._config["guidance"]["min_ms"], - min(self._config["guidance"]["max_ms"], last_note_onset_ms), - ) - guidance_start_ms = random.randint( - first_note_onset_ms, - last_note_onset_ms - guidance_segment_length_ms, - ) - guidance_end_ms = guidance_start_ms + guidance_segment_length_ms - - return guidance_start_ms, guidance_end_ms - - def _get_guidance_seq( - self, - guidance_midi_dict: MidiDict, - guidance_start_ms: int | None = None, - guidance_end_ms: int | None = None, - ): - assert guidance_midi_dict.note_msgs is not None - - # Need to validate these numbers - if guidance_start_ms is None: - assert guidance_end_ms is None - guidance_start_ms, guidance_end_ms = self._get_guidance_interval_ms( - guidance_midi_dict=guidance_midi_dict - ) - - slice_note_msgs = [] - for note_msg in guidance_midi_dict.note_msgs: - start_ms = guidance_midi_dict.tick_to_ms(note_msg["data"]["start"]) - if guidance_start_ms <= start_ms <= guidance_end_ms: - slice_note_msgs.append(note_msg) - - slice_midi_dict = copy.deepcopy(guidance_midi_dict) - slice_midi_dict.note_msgs = slice_note_msgs - - if len(slice_midi_dict.note_msgs) == 0: - # Catches not note in interval - return [] - - guidance_seq = self._tokenize_midi_dict( - midi_dict=slice_midi_dict, - remove_preceding_silence=True, - ) - - if self.dim_tok in guidance_seq: - guidance_seq.remove(self.dim_tok) - - guidance_seq = guidance_seq[ - guidance_seq.index(self.bos_tok) - + 1 : guidance_seq.index(self.eos_tok) - ] - - return ( - [self.guidance_start_tok] + guidance_seq + [self.guidance_end_tok] - ) - - def _add_prompt_tokens( - self, seq: list, prompt_start_ms: int, prompt_end_ms: int - ): - res = copy.deepcopy(seq) - prompt_tok_inserted = False - time_tok_cnt = 0 - curr_time_ms = 0 - for idx, (tok_1, tok_2) in enumerate(zip(seq, seq[1:])): - if tok_1 == self.time_tok: - time_tok_cnt += 1 - elif isinstance(tok_1, tuple) and tok_1[0] in self.instruments_wd: - assert isinstance(tok_2, tuple) and tok_2[0] == "onset" - - # Adjust time - curr_time_ms = (self.abs_time_step_ms * time_tok_cnt) + tok_2[1] - - if ( - curr_time_ms >= prompt_start_ms - and prompt_tok_inserted == False - ): - res.insert(idx, self.prompt_start_tok) - prompt_tok_inserted = True - elif ( - curr_time_ms > prompt_end_ms and prompt_tok_inserted == True - ): - res.insert(idx + 1, self.prompt_end_tok) - break - - return res - - def tokenize( - self, - midi_dict: MidiDict, - prompt_intervals_ms: list[tuple[int, int]], - guidance_midi_dict: MidiDict | None = None, - guidance_start_ms: int | None = None, - guidance_end_ms: int | None = None, - ): - seq = self._tokenize_midi_dict( - midi_dict=midi_dict, remove_preceding_silence=True - ) - first_note_ms = midi_dict.tick_to_ms( - midi_dict.note_msgs[0]["data"]["start"] - ) - - for prompt_start_ms, prompt_end_ms in prompt_intervals_ms: 
- if prompt_end_ms > first_note_ms: - seq = self._add_prompt_tokens( - seq, - prompt_start_ms=prompt_start_ms - first_note_ms, - prompt_end_ms=prompt_end_ms - first_note_ms, - ) - - if guidance_midi_dict is not None: - guidance_seq = self._get_guidance_seq( - guidance_midi_dict=guidance_midi_dict, - guidance_start_ms=guidance_start_ms, - guidance_end_ms=guidance_end_ms, - ) - else: - guidance_seq = [] - - return guidance_seq + seq - - def detokenize(self, tokenized_seq: list, **kwargs): - if self.guidance_end_tok in tokenized_seq: - seq = tokenized_seq[tokenized_seq.index(self.guidance_end_tok) :] - else: - seq = tokenized_seq - - return super()._detokenize_midi_dict(seq, **kwargs) - - def export_data_aug(self): - return [ - self.export_guidance_tempo_aug(max_tempo_aug=0.2, mixup=True), - self.export_guidance_pitch_aug(3), - self.export_guidance_velocity_aug(2), - ] - - def export_guidance_aug_fn(self, aug_fn): - """Transforms augmentation function to only apply to guidance seq""" - - def _guidance_seq_aug_fn( - src: list, - _aug_fn: Callable, - pad_tok: str, - **kwargs, - ) -> list: - - initial_seq_len = len(src) - if self.guidance_start_tok in src and self.guidance_end_tok in src: - guidance_seq = src[ - src.index(self.guidance_start_tok) - + 1 : src.index(self.guidance_end_tok) - ] - seq = src[src.index(self.guidance_end_tok) + 1 :] - - if len(guidance_seq) == 0: - return src - else: - return src - - augmented_guidance_seq = _aug_fn(guidance_seq) - res = ( - [self.guidance_start_tok] - + augmented_guidance_seq - + [self.guidance_end_tok] - + seq - ) - - # Pad or truncate to original sequence length as necessary - res = res[:initial_seq_len] - res += [pad_tok] * (initial_seq_len - len(res)) - - return res - - return functools.partial( - _guidance_seq_aug_fn, - _aug_fn=aug_fn, - pad_tok=self.pad_tok, - ) - - def export_guidance_pitch_aug(self, max_pitch_aug: int): - """Apply pitch augmentation to the guidance sequence""" - - return self.export_guidance_aug_fn( - self.export_pitch_aug(max_pitch_aug=max_pitch_aug) - ) - - def export_guidance_velocity_aug(self, max_num_aug_steps: int): - """Apply velocity augmentation to the guidance sequence""" - - return self.export_guidance_aug_fn( - self.export_velocity_aug(max_num_aug_steps=max_num_aug_steps) - ) - - def export_guidance_tempo_aug(self, max_tempo_aug: int, mixup: bool): - """Apply tempo augmentation to the guidance sequence""" - - return self.export_guidance_aug_fn( - self.export_tempo_aug(max_tempo_aug=max_tempo_aug, mixup=mixup) - ) - - def split(self, seq: list, seq_len: int): - def _process_chunk(_chunk: list): - # Ensure first token is note token - while True: - if _chunk[0] == self.bos_tok: - break - elif ( - isinstance(_chunk[0], tuple) - and _chunk[0][0] in self.instruments_wd - ): - break - else: - _chunk.pop(0) - - # Insert prompt_start_tok if it is missing (but required) - for idx in range(len(_chunk)): - tok = _chunk[idx] - - if tok == self.prompt_start_tok: - break - elif tok == self.prompt_end_tok: - if _chunk[0] == self.bos_tok: - _chunk.insert(1, self.prompt_start_tok) - else: - _chunk.insert(0, self.prompt_start_tok) - break - - return _chunk - - guidance = [] - if self.guidance_start_tok in seq: - guidance_start = seq.index(self.guidance_start_tok) - guidance_end = seq.index(self.guidance_end_tok) - guidance = seq[guidance_start : guidance_end + 1] - seq = seq[guidance_end + 1 :] - - prefix = [] - while seq: - tok = seq[0] - if tok != self.bos_tok and tok[0] == "prefix": - prefix.append(seq.pop(0)) - else: - break - - 
chunks = [ - _process_chunk(seq[idx : idx + seq_len]) - for idx in range(0, len(seq) - 100, seq_len) - ] - - res = [] - for chunk in chunks: - sub_seq = guidance + prefix + chunk - sub_seq = sub_seq[:seq_len] - sub_seq += [self.pad_tok] * (seq_len - len(sub_seq)) - - res.append(sub_seq) - - return res diff --git a/aria/training/__init__.py b/aria/training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aria/training/classifier_finetune.py b/aria/training/classifier_finetune.py new file mode 100644 index 00000000..b8d4b9b2 --- /dev/null +++ b/aria/training/classifier_finetune.py @@ -0,0 +1,728 @@ +import torch +import os +import mmap +import argparse +import logging +import accelerate +import json + +from aria.config import load_model_config +from aria.utils import _load_weight +from ariautils.tokenizer import AbsTokenizer +from aria.model import TransformerCL, ModelConfig + +from torch import nn +from torch.utils.data import DataLoader, Dataset + +from accelerate.logging import get_logger +from typing import Callable +from logging.handlers import RotatingFileHandler +from tqdm import tqdm + +CATEGORY_TAGS = { + "genre": { + "classical": 0, + "jazz": 1, + }, + "music_period": { + "baroque": 0, + "classical": 1, + "romantic": 2, + "impressionist": 3, + }, + "composer": { + "beethoven": 0, + "debussy": 1, + "brahms": 2, + "rachmaninoff": 3, + "schumann": 4, + "mozart": 5, + "liszt": 6, + "bach": 7, + "chopin": 8, + "schubert": 9, + }, + "form": { + "nocturne": 0, + "sonata": 1, + "improvisation": 2, + "etude": 3, + "fugue": 4, + "waltz": 5, + }, + "pianist": { + "hisaishi": 0, + "hancock": 1, + "bethel": 2, + "einaudi": 3, + "clayderman": 4, + "ryuichi": 5, + "yiruma": 6, + "hillsong": 7, + }, + "emotion": { + "happy": 0, + "sad": 1, + "calm": 2, + "tense": 3, + }, +} + + +def setup_logger(project_dir: str): + # Get logger and reset all handlers + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", + ) + fh = RotatingFileHandler( + os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return get_logger(__name__) + + +def setup_project_dir(project_dir: str | None): + if not project_dir: + # Create project directory + if not os.path.isdir("./experiments"): + os.mkdir("./experiments") + + project_dirs = [ + _dir + for _dir in os.listdir("./experiments") + if os.path.isdir(os.path.join("experiments", _dir)) + ] + + ind = 0 + while True: + if str(ind) not in project_dirs: + break + else: + ind += 1 + + project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) + assert not os.path.isdir(project_dir_abs) + os.mkdir(project_dir_abs) + + elif project_dir: + if os.path.isdir(project_dir): + assert ( + len(os.listdir(project_dir)) == 0 + ), "Provided project directory is not empty" + project_dir_abs = os.path.abspath(project_dir) + elif os.path.isfile(project_dir): + raise FileExistsError( + "The provided path points toward an existing file" + ) + else: + try: + os.mkdir(project_dir) + except Exception as e: + raise Exception( + f"Failed to create project directory at {project_dir}" + ) from e + + project_dir_abs = os.path.abspath(project_dir) + + 
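+ # Both branches end with an absolute project path; per-epoch checkpoints are later written (via accelerator.save_state) to the checkpoints/ subdirectory created here.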
os.mkdir(os.path.join(project_dir_abs, "checkpoints")) + + return project_dir_abs + + +class FinetuningDataset(Dataset): + def __init__( + self, + load_path: str, + tag_to_id: dict, + metadata_category: str, + max_seq_len: int, + per_file: bool = False, + ): + self.load_path = load_path + self.tag_to_id = tag_to_id + self.metadata_category = metadata_category + self.max_seq_len = max_seq_len + self.per_file = per_file + self._transform = None + self.tokenizer = AbsTokenizer() + self.index = [] + + assert metadata_category in CATEGORY_TAGS.keys() + assert all( + tag_to_id[_t] == _id + for _t, _id in CATEGORY_TAGS[metadata_category].items() + ) + + self.file_buff = open(self.load_path, "rb") + self.mmap_obj = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + while True: + pos = self.mmap_obj.tell() + line = self.mmap_obj.readline() + if not line: + break + self.index.append(pos) + + def set_transform(self, transform: Callable | list[Callable]): + if isinstance(transform, Callable): + self._transform = transform + elif isinstance(transform, list): + # Check validity + for fn in transform: + assert isinstance(fn, Callable), "Invalid function" + + # Define new transformation function (apply fn in order) + def _new_transform(x): + for fn in transform: + x = fn(x) + return x + + self._transform = _new_transform + else: + raise ValueError("Must provide function or list of functions.") + + def __getitem__(self, idx: int): + def _format(tok): + # Required because json formats tuples into lists + if isinstance(tok, list): + return tuple(tok) + return tok + + pos = self.index[idx] + self.mmap_obj.seek(pos) + raw_data = self.mmap_obj.readline().decode("utf-8") + json_data = json.loads(raw_data) + + metadata = json_data["metadata"] + tag = metadata[self.metadata_category] + + assert tag in self.tag_to_id, metadata + tag_tensor = torch.tensor(self.tag_to_id[tag]) + + if self.per_file: + seq_list = json_data["seqs"] + else: + seq_list = [json_data["seq"]] + + seq_tensors = [] + pos_tensors = [] + for seq in seq_list: + seq = [_format(tok) for tok in seq] + + if self._transform: + seq = self._transform(seq) + + seq = seq[: self.max_seq_len] + if self.tokenizer.eos_tok not in seq: + assert self._transform is not None + seq[-1] = self.tokenizer.eos_tok + + eos_index = seq.index(self.tokenizer.eos_tok) + pos_tensor = torch.tensor(eos_index) + + assert len(seq) <= self.max_seq_len + + seq = seq + [self.tokenizer.pad_tok] * (self.max_seq_len - len(seq)) + encoded_seq = self.tokenizer.encode(seq) + seq_tensor = torch.tensor(encoded_seq) + + assert seq_tensor[pos_tensor.item()].item() == 1 # EOS ID check + + seq_tensors.append(seq_tensor) + pos_tensors.append(pos_tensor) + + seq_tensor = torch.stack(seq_tensors) + pos_tensor = torch.stack(pos_tensors) + + return seq_tensor, pos_tensor, tag_tensor + + def __len__(self): + return len(self.index) + + @classmethod + def export_worker_init_fn(cls): + def worker_init_fn(worker_id: int): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + + if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: + dataset.mmap_obj.close() + + f = open(dataset.load_path, "rb") + dataset.mmap_obj = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + return worker_init_fn + + +def _get_optim( + lr: float, + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, + warmup: int = 100, + end_ratio: float = 0.1, +): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-5, + 
) + + total_steps = num_epochs * steps_per_epoch + + if warmup > 0: + warmup_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, + end_factor=1, + total_iters=warmup, + ) + linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=total_steps - warmup, + ) + lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_lrs, linear_decay_lrs], + milestones=[warmup], + ) + else: + lr_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=total_steps, + ) + + return optimizer, lr_scheduler + + +def get_optim( + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, +): + LR = 1e-5 + END_RATIO = 0.1 + WARMUP_STEPS = 0 + + return _get_optim( + lr=LR, + model=model, + num_epochs=num_epochs, + steps_per_epoch=steps_per_epoch, + warmup=WARMUP_STEPS, + end_ratio=END_RATIO, + ) + + +def get_dataloaders( + train_data_path: str, + val_data_path: str, + metadata_category: str, + tag_to_id: dict, + batch_size: int, + num_workers: int, + apply_aug: bool = False, + max_seq_len: int = 1024, +): + train_dataset = FinetuningDataset( + load_path=train_data_path, + tag_to_id=tag_to_id, + metadata_category=metadata_category, + max_seq_len=max_seq_len, + ) + val_dataset = FinetuningDataset( + load_path=val_data_path, + tag_to_id=tag_to_id, + metadata_category=metadata_category, + max_seq_len=max_seq_len, + per_file=True, + ) + + if apply_aug: + print("Applying dataset augmentation") + train_dataset.set_transform(AbsTokenizer().export_data_aug()) + + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + worker_init_fn=FinetuningDataset.export_worker_init_fn(), + ) + val_loader = DataLoader( + val_dataset, + batch_size=1, + shuffle=False, + num_workers=num_workers, + worker_init_fn=FinetuningDataset.export_worker_init_fn(), + ) + + return train_loader, val_loader + + +def _train( + num_epochs: int, + accelerator: accelerate.Accelerator, + model: TransformerCL, + train_dataloader: DataLoader, + val_dataloader: DataLoader, + optimizer: torch.optim.Optimizer, + tag_to_id: dict, + scheduler: torch.optim.lr_scheduler.LRScheduler = None, + project_dir: str | None = None, +): + def make_checkpoint( + _accelerator: accelerate.Accelerator, _epoch: int, _step: int + ): + if accelerator.is_main_process: + checkpoint_dir = os.path.join( + project_dir, + "checkpoints", + f"epoch{_epoch}_step{_step}", + ) + + logger.info( + f"EPOCH {_epoch}/{num_epochs}: Saving checkpoint - {checkpoint_dir}" + ) + _accelerator.save_state(checkpoint_dir) + + def train_loop(dataloader: DataLoader, _epoch: int): + loss = torch.tensor([0.0]) + avg_train_loss = 0 + trailing_loss = 0 + loss_buffer = [] + + try: + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + except Exception: + lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) + + model.train() + for __step, batch in ( + pbar := tqdm( + enumerate(dataloader), + total=len(dataloader), + initial=0, + leave=False, + ) + ): + pbar.set_postfix_str( + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing={round(trailing_loss, 4)}" + ) + + with accelerator.accumulate(model): + step = __step + 1 + + seqs, eos_pos, labels = batch + seqs = seqs.squeeze(1) + eos_pos = eos_pos.squeeze(1) + + logits = model(seqs) # (b_sz, s_len, class_size) + logits = logits[ + torch.arange(logits.shape[0], device=logits.device), eos_pos + ] 
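+ # Each sequence is classified from the logit vector at its EOS position, giving a (b_sz, class_size) tensor for the cross-entropy loss.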
+ loss = loss_fn(logits, labels) + + # Calculate statistics + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + avg_train_loss = sum(loss_buffer) / len(loss_buffer) + + # Logging + logger.debug( + f"EPOCH {_epoch} STEP {step}: " + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing_loss={round(trailing_loss, 4)}, " + f"average_loss={round(avg_train_loss, 4)}" + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + return avg_train_loss + + def val_loop(dataloader: DataLoader, _epoch: int, tag_to_id: dict): + model.eval() + pad_id = AbsTokenizer().pad_id + preds = [] + labels = [] + + with torch.inference_mode(): + pbar = tqdm( + dataloader, desc=f"Validation Epoch {_epoch}", leave=False + ) + for batch in pbar: + seqs, pos, tag = batch + seqs = seqs.squeeze(0) # (n, max_seq_len) + pos = pos.squeeze(0) # (n,) + + logits = model(seqs) # (n, seq_len, class_size) + logits = logits[ + torch.arange(logits.shape[0], device=logits.device), pos + ] + probs = torch.softmax(logits, dim=-1) # (n, class_size) + + non_pad_counts = ( + (seqs != pad_id).sum(dim=1, keepdim=True).float() + ) + weighted_probs = probs * non_pad_counts + aggregated_probs = weighted_probs.sum(dim=0) + predicted_label = aggregated_probs.argmax().item() + + preds.append(predicted_label) + labels.append(tag.item()) + + tmp_acc = sum(p == t for p, t in zip(preds, labels)) / len( + preds + ) + pbar.set_postfix_str(f"acc={round(tmp_acc, 4)}") + + accuracy = sum(p == t for p, t in zip(preds, labels)) / len(labels) + + # Compute per-class F1 scores + id_to_tag = {v: k for k, v in tag_to_id.items()} + # Initialize counts per class + metrics = {tag: {"TP": 0, "FP": 0, "FN": 0} for tag in tag_to_id.keys()} + for true_id, pred_id in zip(labels, preds): + true_tag = id_to_tag[true_id] + pred_tag = id_to_tag[pred_id] + if true_id == pred_id: + metrics[true_tag]["TP"] += 1 + else: + metrics[true_tag]["FN"] += 1 + metrics[pred_tag]["FP"] += 1 + + class_metrics = {} + f1_scores = [] + for tag, counts in metrics.items(): + TP = counts["TP"] + FP = counts["FP"] + FN = counts["FN"] + precision = TP / (TP + FP) if (TP + FP) > 0 else 0 + recall = TP / (TP + FN) if (TP + FN) > 0 else 0 + f1 = ( + (2 * precision * recall / (precision + recall)) + if (precision + recall) > 0 + else 0 + ) + class_metrics[tag] = { + "precision": precision, + "recall": recall, + "F1": f1, + } + f1_scores.append(f1) + + macro_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0 + + logger.info( + f"Validation Epoch {_epoch}: accuracy={round(accuracy, 4)}, macro-F1={round(macro_f1, 4)}" + ) + logger.info(f"Class metrics: {class_metrics}") + + return accuracy, macro_f1, class_metrics + + logger = get_logger(__name__) + loss_fn = nn.CrossEntropyLoss() + TRAILING_LOSS_STEPS = 20 + + epoch_metrics = [] + for __epoch in range(num_epochs): + train_loop(dataloader=train_dataloader, _epoch=__epoch) + acc, macro_f1, class_metrics = val_loop( + dataloader=val_dataloader, _epoch=__epoch, tag_to_id=tag_to_id + ) + epoch_metrics.append( + { + "accuracy": acc, + "macro_f1": macro_f1, + "class_metrics": class_metrics, + } + ) + + return epoch_metrics + + +def train( + model_name: str, + metadata_category: str, + apply_aug: bool, + train_data_path: str, + val_data_path: str, + num_workers: int, + num_epochs: int, + batch_size: 
int, + grad_acc_steps: int, + project_dir: str | None = None, + checkpoint_path: str | None = None, + dataset_size: int | None = None, +): + accelerator = accelerate.Accelerator( + project_dir=project_dir, + gradient_accumulation_steps=grad_acc_steps, + ) + + tag_to_id = CATEGORY_TAGS[metadata_category] + + if accelerator.is_main_process: + project_dir = setup_project_dir(project_dir) + logger = setup_logger(os.path.join(project_dir)) + else: + # In other processes, we won't create logs + project_dir = project_dir or "./experiments" + logger = get_logger(__name__) + + logger.info(f"Project directory: {project_dir}") + logger.info(f"Metadata category: {metadata_category}") + logger.info(f"Dataset size: {dataset_size}") + logger.info(f"Applying aug: {apply_aug}") + logger.info( + f"Training config:epochs={num_epochs}, batch_size={batch_size}, num_workers={num_workers}" + ) + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config(model_name)) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerCL(model_config) + + assert model_config.class_size == len(tag_to_id.keys()) + + if checkpoint_path is not None: + logger.info(f"Loading checkpoint from {checkpoint_path}") + model_state = _load_weight(checkpoint_path) + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + model.load_state_dict(model_state, strict=False) + torch.nn.init.normal_( + model.model.tok_embeddings.weight.data[1:2], mean=0.0, std=0.02 + ) # Re-init EOS tok + + else: + logger.info("No checkpoint path provided") + + model.compile() + + train_dataloader, val_dataloader = get_dataloaders( + train_data_path=train_data_path, + val_data_path=val_data_path, + metadata_category=metadata_category, + tag_to_id=tag_to_id, + batch_size=batch_size, + num_workers=num_workers, + apply_aug=apply_aug, + ) + + optimizer, scheduler = get_optim( + model=model, + num_epochs=num_epochs, + steps_per_epoch=len(train_dataloader), + ) + + ( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) = accelerator.prepare( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) + + epoch_metrics = _train( + num_epochs=num_epochs, + accelerator=accelerator, + model=model, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + optimizer=optimizer, + tag_to_id=tag_to_id, + scheduler=scheduler, + project_dir=project_dir, + ) + + max_accuracy = ( + max(metric["accuracy"] for metric in epoch_metrics) + if epoch_metrics + else 0.0 + ) + logger.info(f"Max accuracy: {max_accuracy}") + results = { + "metadata_category": metadata_category, + "dataset_size": dataset_size, + "epoch_metrics": epoch_metrics, + "max_accuracy": max_accuracy, + } + with open(os.path.join(project_dir, "results.json"), "w") as f: + json.dump(results, f, indent=4) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Finetune a model for classification." 
+ ) + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--metadata_category", type=str, required=True) + parser.add_argument("--dataset_size", type=int, required=False) + parser.add_argument("--apply_aug", action="store_true") + parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--train_data_path", type=str, required=True) + parser.add_argument("--val_data_path", type=str, required=True) + parser.add_argument("--batch_size", type=int) + parser.add_argument("--num_epochs", type=int) + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--grad_acc_steps", type=int, default=1) + parser.add_argument("--project_dir", type=str, default=None) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + train( + model_name=args.model_name, + metadata_category=args.metadata_category, + dataset_size=args.dataset_size, + apply_aug=args.apply_aug, + checkpoint_path=args.checkpoint_path, + train_data_path=args.train_data_path, + val_data_path=args.val_data_path, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + num_workers=args.num_workers, + grad_acc_steps=args.grad_acc_steps, + project_dir=args.project_dir, + ) diff --git a/aria/training/contrastive_finetune.py b/aria/training/contrastive_finetune.py new file mode 100644 index 00000000..2ac32eaa --- /dev/null +++ b/aria/training/contrastive_finetune.py @@ -0,0 +1,660 @@ +import torch +import os +import mmap +import argparse +import logging +import random +import copy +import accelerate +import json + +from aria.config import load_model_config +from aria.utils import _load_weight +from ariautils.tokenizer import AbsTokenizer +from ariautils.midi import MidiDict +from aria.model import TransformerEMB, ModelConfig + +from torch import nn +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset + +from accelerate.logging import get_logger +from logging.handlers import RotatingFileHandler +from tqdm import tqdm + + +def setup_logger(project_dir: str): + logger = logging.getLogger(__name__) + for h in logger.handlers[:]: + logger.removeHandler(h) + + logger.propagate = False + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", + ) + fh = RotatingFileHandler( + os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return get_logger(__name__) + + +def setup_project_dir(project_dir: str | None): + if not project_dir: + # Create project directory + if not os.path.isdir("./experiments"): + os.mkdir("./experiments") + + project_dirs = [ + _dir + for _dir in os.listdir("./experiments") + if os.path.isdir(os.path.join("experiments", _dir)) + ] + + ind = 0 + while True: + if str(ind) not in project_dirs: + break + else: + ind += 1 + + project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) + assert not os.path.isdir(project_dir_abs) + os.mkdir(project_dir_abs) + + elif project_dir: + if os.path.isdir(project_dir): + assert ( + len(os.listdir(project_dir)) == 0 + ), "Provided project directory is not empty" + project_dir_abs = os.path.abspath(project_dir) + elif os.path.isfile(project_dir): + raise FileExistsError( + "The provided path points toward an existing file" + ) + else: + try: + 
os.mkdir(project_dir) + except Exception as e: + raise Exception( + f"Failed to create project directory at {project_dir}" + ) from e + + project_dir_abs = os.path.abspath(project_dir) + + os.mkdir(os.path.join(project_dir_abs, "checkpoints")) + + return project_dir_abs + + +class ContrastiveDataset(Dataset): + def __init__( + self, + load_path: str, + min_number_slice_notes: int, + max_number_slice_notes: int, + max_seq_len: int, + apply_aug: bool = False, + ): + self.load_path = load_path + self.min_number_slice_notes = min_number_slice_notes + self.max_number_slice_notes = max_number_slice_notes + self.max_seq_len = max_seq_len + self.apply_aug = apply_aug + + self.tokenizer = AbsTokenizer() + + if apply_aug is True: + self.aug_fns = self.tokenizer.export_data_aug() + else: + self.aug_fns = None + + self.index = [] + self.file_buff = open(self.load_path, "rb") + self.mmap_obj = mmap.mmap( + self.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + while True: + pos = self.mmap_obj.tell() + line = self.mmap_obj.readline() + if not line: + break + self.index.append(pos) + + def get_slice( + self, + midi_dict: MidiDict, + min_num_notes: int, + max_num_notes: int, + max_seq_len: int, + apply_aug: bool = False, + ): + _midi_dict = copy.deepcopy(midi_dict) + slice_length = random.randint(min_num_notes, max_num_notes) + if len(_midi_dict.note_msgs) <= min_num_notes: + idx = 0 + else: + idx = random.randint(0, len(_midi_dict.note_msgs) - min_num_notes) + + _midi_dict.note_msgs = _midi_dict.note_msgs[idx : idx + slice_length] + _midi_dict.metadata = {} + + tokenized_slice = self.tokenizer.tokenize(_midi_dict) + + if apply_aug: + assert self.aug_fns + for fn in self.aug_fns: + tokenized_slice = fn(tokenized_slice) + + while self.tokenizer.pad_tok in tokenized_slice: + tokenized_slice.remove(self.tokenizer.pad_tok) + + if self.tokenizer.dim_tok in tokenized_slice: + tokenized_slice.remove(self.tokenizer.dim_tok) + + # Use EOS tok for classification head + tokenized_slice = tokenized_slice[:max_seq_len] + tokenized_slice += [self.tokenizer.pad_tok] * ( + max_seq_len - len(tokenized_slice) + ) + if self.tokenizer.eos_tok not in tokenized_slice: + tokenized_slice[-1] = self.tokenizer.eos_tok + + pos = tokenized_slice.index(self.tokenizer.eos_tok) + + return tokenized_slice, pos + + def __getitem__(self, idx: int): + file_pos = self.index[idx] + self.mmap_obj.seek(file_pos) + + raw_data = self.mmap_obj.readline().decode("utf-8") + json_data = json.loads(raw_data) + midi_dict = MidiDict.from_msg_dict(json_data) + + slice_seq_1, slice_pos_1 = self.get_slice( + midi_dict=midi_dict, + min_num_notes=self.min_number_slice_notes, + max_num_notes=self.max_number_slice_notes, + max_seq_len=self.max_seq_len, + apply_aug=self.apply_aug, + ) + slice_seq_2, slice_pos_2 = self.get_slice( + midi_dict=midi_dict, + min_num_notes=self.min_number_slice_notes, + max_num_notes=self.max_number_slice_notes, + max_seq_len=self.max_seq_len, + apply_aug=self.apply_aug, + ) + + assert len(slice_seq_1) <= self.max_seq_len + assert len(slice_seq_2) <= self.max_seq_len + assert slice_pos_1 < self.max_seq_len + assert slice_pos_2 < self.max_seq_len + assert slice_seq_1[slice_pos_1] == self.tokenizer.eos_tok + assert slice_seq_2[slice_pos_2] == self.tokenizer.eos_tok + + slices_enc = torch.tensor( + [ + self.tokenizer.encode(slice_seq_1), + self.tokenizer.encode(slice_seq_2), + ] + ) + + slices_pos = torch.tensor([slice_pos_1, slice_pos_2]) + + return slices_enc, slices_pos + + def __len__(self): + return len(self.index) + + 
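+    # The classmethod below builds the worker_init_fn that get_dataloaders
+    # passes to DataLoader, e.g. (sketch of the existing usage):
+    #   DataLoader(dataset, batch_size=..., num_workers=4,
+    #              worker_init_fn=ContrastiveDataset.export_worker_init_fn())
+    # Each worker re-opens its own file handle and mmap so it reads from its
+    # own handle rather than one inherited from the parent process.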
@classmethod + def export_worker_init_fn(cls): + def worker_init_fn(worker_id: int): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + + if hasattr(dataset, "mmap_obj") and dataset.mmap_obj: + dataset.mmap_obj.close() + + dataset.file_buff = open(dataset.load_path, "rb") + dataset.mmap_obj = mmap.mmap( + dataset.file_buff.fileno(), 0, access=mmap.ACCESS_READ + ) + + return worker_init_fn + + +def _get_optim( + lr: float, + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, + warmup: int = 100, + end_ratio: int = 0.1, +): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-5, + ) + + warmup_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.000001, + end_factor=1, + total_iters=warmup, + ) + linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1, + end_factor=end_ratio, + total_iters=(num_epochs * steps_per_epoch) - warmup, + ) + + lr_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_lrs, linear_decay_lrs], + milestones=[warmup], + ) + + return optimizer, lr_scheduler + + +def get_optim( + model: nn.Module, + num_epochs: int, + steps_per_epoch: int, +): + LR = 1e-5 + END_RATIO = 0.1 + WARMUP_STEPS = 1000 + + return _get_optim( + lr=LR, + model=model, + num_epochs=num_epochs, + steps_per_epoch=steps_per_epoch, + warmup=WARMUP_STEPS, + end_ratio=END_RATIO, + ) + + +def get_dataloaders( + train_data_path: str, + val_data_path: str, + batch_size: int, + num_workers: int, + min_number_slice_notes: int = 100, + max_number_slice_notes: int = 650, + max_seq_len: int = 2048, +): + train_dataset = ContrastiveDataset( + load_path=train_data_path, + min_number_slice_notes=min_number_slice_notes, + max_number_slice_notes=max_number_slice_notes, + max_seq_len=max_seq_len, + ) + val_dataset = ContrastiveDataset( + load_path=val_data_path, + min_number_slice_notes=min_number_slice_notes, + max_number_slice_notes=max_number_slice_notes, + max_seq_len=max_seq_len, + ) + + train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + worker_init_fn=ContrastiveDataset.export_worker_init_fn(), + ) + val_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + worker_init_fn=ContrastiveDataset.export_worker_init_fn(), + ) + + return train_loader, val_loader + + +# TODO: This might not be 100% correct (verify CEL calculation) +def symmetric_nt_xent_loss_cosine( + z1: torch.Tensor, z2: torch.Tensor, temperature=0.5 +): + bsz = z1.shape[0] + + z1 = F.normalize(z1, dim=1) # First view + z2 = F.normalize(z2, dim=1) # Second view + + sim_matrix = ( + F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1) + / temperature + ) + + labels = torch.arange(bsz, device=z1.device) + + loss1 = F.cross_entropy(sim_matrix, labels) + loss2 = F.cross_entropy(sim_matrix.T, labels) + + return (loss1 + loss2) / 2.0 + + +def _train( + num_epochs: int, + accelerator: accelerate.Accelerator, + model: TransformerEMB, + train_dataloader: DataLoader, + val_dataloader: DataLoader, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler = None, + project_dir: str | None = None, +): + def make_checkpoint( + _accelerator: accelerate.Accelerator, _epoch: int, _step: int + ): + if accelerator.is_main_process: + checkpoint_dir = os.path.join( + project_dir, + "checkpoints", + f"epoch{_epoch}_step{_step}", + ) + + 
logger.info( + f"EPOCH {_epoch}/{num_epochs}: Saving checkpoint - {checkpoint_dir}" + ) + _accelerator.save_state(checkpoint_dir) + + def train_loop( + dataloader: DataLoader, + _epoch: int, + steps_per_checkpoint: int | None = None, + ): + loss = torch.tensor([0.0]) + avg_train_loss = 0 + trailing_loss = 0 + loss_buffer = [] + + try: + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + except Exception: + pass + else: + lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"]) + + model.train() + for __step, batch in ( + pbar := tqdm( + enumerate(dataloader), + total=len(dataloader), + initial=0, + leave=False, + ) + ): + pbar.set_postfix_str( + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing={round(trailing_loss, 4)}" + ) + + with accelerator.accumulate(model): + step = __step + 1 + seqs, eos_pos = batch + + seqs = seqs.contiguous() + bsz = seqs.size(0) + seqs_flat = seqs.view(2 * bsz, seqs.size(-1)) + + outputs = model(seqs_flat) + z1_full = outputs[0::2] + z2_full = outputs[1::2] + + batch_indices = torch.arange(bsz, device=z1_full.device) + eos_pos_1 = eos_pos[:, 0] + eos_pos_2 = eos_pos[:, 1] + + z1 = z1_full[batch_indices, eos_pos_1] + z2 = z2_full[batch_indices, eos_pos_2] + + loss = symmetric_nt_xent_loss_cosine(z1, z2) + + # Calculate statistics + loss_buffer.append(accelerator.gather(loss).mean(dim=0).item()) + trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len( + loss_buffer[-TRAILING_LOSS_STEPS:] + ) + avg_train_loss = sum(loss_buffer) / len(loss_buffer) + + # Logging + logger.debug( + f"EPOCH {_epoch} STEP {step}: " + f"lr={lr_for_print}, " + f"loss={round(loss.item(), 4)}, " + f"trailing_loss={round(trailing_loss, 4)}, " + f"average_loss={round(avg_train_loss, 4)}" + ) + + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() + lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0]) + + if steps_per_checkpoint: + if step % steps_per_checkpoint == 0: + make_checkpoint( + _accelerator=accelerator, + _epoch=_epoch, + _step=step, + ) + + return avg_train_loss + + def val_loop(dataloader: DataLoader, _epoch: int): + model.eval() + val_loss_buffer = [] + + with torch.no_grad(): + pbar = tqdm( + dataloader, desc=f"Validation Epoch {_epoch}", leave=False + ) + for batch in pbar: + seqs, eos_pos = batch + + seqs = seqs.contiguous() + bsz = seqs.size(0) + seqs_flat = seqs.view(2 * bsz, seqs.size(-1)) + + outputs = model(seqs_flat) + z1_full = outputs[0::2] + z2_full = outputs[1::2] + + batch_indices = torch.arange(bsz, device=z1_full.device) + eos_pos_1 = eos_pos[:, 0] + eos_pos_2 = eos_pos[:, 1] + + z1 = z1_full[batch_indices, eos_pos_1] + z2 = z2_full[batch_indices, eos_pos_2] + + loss = symmetric_nt_xent_loss_cosine(z1, z2) + # Gather loss from all devices (if applicable) + val_loss_buffer.append( + accelerator.gather(loss).mean(dim=0).item() + ) + + current_avg_loss = sum(val_loss_buffer) / len(val_loss_buffer) + + pbar.set_postfix_str(f"avg_loss={round(current_avg_loss,4)}") + + avg_val_loss = sum(val_loss_buffer) / len(val_loss_buffer) + + logger.info( + f"Validation Epoch {_epoch}: average_loss={round(avg_val_loss, 4)}" + ) + return avg_val_loss + + logger = get_logger(__name__) + TRAILING_LOSS_STEPS = 100 + + for _epoch_num in range(num_epochs): + train_loop(dataloader=train_dataloader, _epoch=_epoch_num) + make_checkpoint( + _accelerator=accelerator, _epoch=_epoch_num + 1, _step=0 + ) + val_loop(dataloader=val_dataloader, _epoch=_epoch_num) + + +def train( + model_name: 
str, + train_data_path: str, + val_data_path: str, + num_workers: int, + num_epochs: int, + batch_size: int, + grad_acc_steps: int, + project_dir: str | None = None, + checkpoint_path: str | None = None, +): + accelerator = accelerate.Accelerator( + project_dir=project_dir, + gradient_accumulation_steps=grad_acc_steps, + ) + + if accelerator.is_main_process: + project_dir = setup_project_dir(project_dir) + logger = setup_logger(os.path.join(project_dir)) + else: + # In other processes, we won't create logs + project_dir = project_dir or "./experiments" + logger = get_logger(__name__) + + logger.info(f"Project directory: {project_dir}") + logger.info( + f"Training config: epochs={num_epochs}, batch_size={batch_size}, num_workers={num_workers}" + ) + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config(model_name)) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerEMB(model_config) + + if checkpoint_path is not None: + logger.info(f"Loading checkpoint from {checkpoint_path}") + model_state = _load_weight(checkpoint_path) + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + model.load_state_dict(model_state, strict=False) + else: + logger.info("No checkpoint path provided") + + train_dataloader, val_dataloader = get_dataloaders( + train_data_path=train_data_path, + val_data_path=val_data_path, + batch_size=batch_size, + num_workers=num_workers, + ) + + optimizer, scheduler = get_optim( + model=model, + num_epochs=num_epochs, + steps_per_epoch=len(train_dataloader), + ) + + ( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) = accelerator.prepare( + model, + train_dataloader, + val_dataloader, + optimizer, + scheduler, + ) + + _train( + num_epochs=num_epochs, + accelerator=accelerator, + model=model, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + optimizer=optimizer, + scheduler=scheduler, + project_dir=project_dir, + ) + + +def test_dataset(): + tokenizer = AbsTokenizer() + dataset = ContrastiveDataset( + load_path="/mnt/ssd1/aria/data/mididict-ft_val.jsonl", + min_number_slice_notes=150, + max_number_slice_notes=300, + max_seq_len=1024, + apply_aug=True, + ) + + for idx, (enc, pos) in enumerate(dataset): + seq_1 = enc[0].tolist() + midi_dict_1 = tokenizer.detokenize(tokenizer.decode(seq_1)) + midi_dict_1.to_midi().save("/home/loubb/Dropbox/shared/test1.mid") + + seq_2 = enc[1].tolist() + midi_dict_2 = tokenizer.detokenize(tokenizer.decode(seq_2)) + midi_dict_2.to_midi().save("/home/loubb/Dropbox/shared/test2.mid") + + print(enc.shape) + print(pos.shape, pos) + input("") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Finetune a model contrastive_embeddings" + ) + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--checkpoint_path", type=str, default=None) + parser.add_argument("--train_data_path", type=str, required=True) + parser.add_argument("--val_data_path", type=str, required=True) + parser.add_argument("--batch_size", type=int) + parser.add_argument("--num_epochs", type=int) + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--grad_acc_steps", type=int, default=1) + parser.add_argument("--project_dir", type=str, default=None) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + train( + model_name=args.model_name, + checkpoint_path=args.checkpoint_path, + train_data_path=args.train_data_path, + val_data_path=args.val_data_path, 
+ batch_size=args.batch_size, + num_epochs=args.num_epochs, + num_workers=args.num_workers, + grad_acc_steps=args.grad_acc_steps, + project_dir=args.project_dir, + ) + + # test_dataset() diff --git a/aria/train.py b/aria/training/train.py similarity index 83% rename from aria/train.py rename to aria/training/train.py index 5733e494..67001a22 100644 --- a/aria/train.py +++ b/aria/training/train.py @@ -3,14 +3,13 @@ import csv import argparse import logging +import random import torch import accelerate from torch import nn as nn from torch.utils.data import DataLoader -from torch.utils.flop_counter import FlopCounterMode -from triton.testing import do_bench from accelerate.logging import get_logger from safetensors.torch import load_file from logging.handlers import RotatingFileHandler @@ -18,13 +17,11 @@ from typing import List from aria.config import load_model_config -from aria.model import ModelConfig, TransformerLM +from aria.model import ModelConfig, TransformerLM, TransformerLM_CND from ariautils.tokenizer import Tokenizer, AbsTokenizer, RelTokenizer -from aria.tokenizer import InferenceAbsTokenizer from aria.datasets import ( TrainingDataset, PretrainingDataset, - FinetuningDataset, ) from aria.utils import _load_weight @@ -41,25 +38,25 @@ # For example usage you could run the pre-training script with: # # accelerate launch [arguments] aria/train.py train \ -# small \ -# -train_data data/train \ -# -val_data data/val \ -# -epochs 10 \ -# -bs 32 \ -# -workers 8 +# medium \ +# --train_data data/train \ +# --val_data data/val \ +# --epochs 10 \ +# --bs 32 \ +# --workers 8 # # You could resume a run from an accelerate checkpoint with: # # accelerate launch [arguments] aria/train.py resume \ -# small \ -# -train_data data/train \ -# -val_data data/val \ -# -cp_dir models/epoch5_step0 \ -# -r_step 0 \ -# -r_epoch 5 \ -# -epochs 5 \ -# -bs 32 \ -# -workers 8 +# medium \ +# --train_data data/train \ +# --val_data data/val \ +# --cp_dir models/epoch5_step0 \ +# --r_step 0 \ +# --r_epoch 5 \ +# --epochs 5 \ +# --bs 32 \ +# --workers 8 def setup_logger(project_dir: str): @@ -156,7 +153,7 @@ def _get_optim( num_epochs: int, steps_per_epoch: int, warmup: int = 100, - end_ratio: int = 0.1, + end_ratio: float = 0.1, ): optimizer = torch.optim.AdamW( model.parameters(), @@ -213,31 +210,18 @@ def get_dataloaders( tokenizer: Tokenizer, batch_size: int, num_workers: int, + use_embeddings: bool, init_epoch: int | None = None, apply_aug: bool = True, - finetune: bool = False, ): - logger = logging.getLogger(__name__) - if finetune == False: - train_dataset = PretrainingDataset( - dir_paths=train_data_dirs, - tokenizer=tokenizer, - ) - val_dataset = PretrainingDataset( - dir_paths=val_data_dir, - tokenizer=tokenizer, - ) - elif finetune == True: - train_dataset = FinetuningDataset( - dir_paths=train_data_dirs, - tokenizer=tokenizer, - ) - val_dataset = FinetuningDataset( - dir_paths=val_data_dir, - tokenizer=tokenizer, - ) - else: - raise ValueError + train_dataset = PretrainingDataset( + dir_paths=train_data_dirs, + tokenizer=tokenizer, + ) + val_dataset = PretrainingDataset( + dir_paths=val_data_dir, + tokenizer=tokenizer, + ) if init_epoch: train_dataset.init_epoch(idx=init_epoch) @@ -262,6 +246,12 @@ def get_dataloaders( shuffle=False, ) + if use_embeddings is True: + _src, _tgt, _mask, _emb = train_dataset[0] + _src, _tgt, _mask, __emb = val_dataset[0] + assert _emb.numel() != 0, "Embeddings not present in train dataset" + assert __emb.numel() != 0, "Embeddings not present in val dataset" + return 
train_dataloader, val_dataloader @@ -271,6 +261,7 @@ def _train( model: TransformerLM, train_dataloader: DataLoader, val_dataloader: DataLoader, + use_embeddings: bool, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler = None, steps_per_checkpoint: int | None = None, @@ -278,34 +269,6 @@ def _train( resume_epoch: int | None = None, project_dir: str | None = None, ): - def profile_flops(dataloader: DataLoader): - def _bench(): - for batch in dataloader: - src, tgt = batch # (b_sz, s_len), (b_sz, s_len, v_sz) - logits = model(src) # (b_sz, s_len, v_sz) - logits = logits.transpose(1, 2) - loss = loss_fn(logits, tgt) - - # Backwards step - omit optimizer.step() - accelerator.backward(loss) - optimizer.zero_grad() - break - - logger.info( - f"Model has " - f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " - "parameters" - ) - - # logger.info("Profiling FLOP") - # flop_counter = FlopCounterMode(display=False) - # _bench() - - # with flop_counter: - # _bench() - # total_flop = sum(flop_counter.get_flop_counts()["Global"].values()) - # logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF") - def make_checkpoint( _accelerator: accelerate.Accelerator, _epoch: int, _step: int ): @@ -353,8 +316,19 @@ def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0): with accelerator.accumulate(model): step = __step + _resume_step + 1 - src, tgt, mask = batch # (b_sz, s_len), (b_sz, s_len, v_sz) - logits = model(src) # (b_sz, s_len, v_sz) + src, tgt, mask, emb = ( + batch # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb) + ) + + use_embeddings_cond = use_embeddings and (random.random() > 0.5) + + if use_embeddings_cond is True: + logits = model(src=src, emb=emb) # (b_sz, s_len - 1, v_sz) + tgt = tgt[:, :-1] # (b_sz, s_len - 1) + mask = mask[:, :-1] # (b_sz, s_len - 1) + else: + logits = model(src) # (b_sz, s_len, v_sz) + logits = logits.transpose( 1, 2 ) # Transpose for CrossEntropyLoss @@ -418,8 +392,18 @@ def val_loop(dataloader, _epoch: int): leave=False, ) ): - src, tgt, mask = batch # (b_sz, s_len), (b_sz, s_len, v_sz) - logits = model(src) # (b_sz, s_len, v_sz) + src, tgt, mask, emb = ( + batch # (b_sz, s_len), (b_sz, s_len), (b_sz, s_len), (b_sz, d_emb) + ) + use_embeddings_cond = use_embeddings and (random.random() > 0.5) + + if use_embeddings_cond is True: + logits = model(src=src, emb=emb) # (b_sz, s_len - 1, v_sz) + tgt = tgt[:, :-1] # (b_sz, s_len - 1) + mask = mask[:, :-1] # (b_sz, s_len - 1) + else: + logits = model(src) # (b_sz, s_len, v_sz) + logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss loss = loss_fn(logits, tgt) @@ -451,7 +435,12 @@ def val_loop(dataloader, _epoch: int): PAD_ID = train_dataloader.dataset.tokenizer.pad_id logger = get_logger(__name__) # Accelerate logger loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, reduction="none") - profile_flops(dataloader=train_dataloader) + + logger.info( + f"Model has " + f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " + "parameters" + ) if accelerator.is_main_process: loss_csv = open(os.path.join(project_dir, "loss.csv"), "w") @@ -504,10 +493,12 @@ def val_loop(dataloader, _epoch: int): epoch_csv.close() +# TODO: Add use_embeddings logic to this code path def resume_train( model_name: str, train_data_paths: str, val_data_path: str, + use_embeddings: bool, num_workers: int, batch_size: int, grad_acc_steps: int, @@ -533,8 +524,6 @@ def resume_train( tokenizer_name = 
get_tokenizer_name(train_data_paths, val_data_path) if tokenizer_name == "abs": tokenizer = AbsTokenizer() - elif tokenizer_name == "inference_abs": - tokenizer = InferenceAbsTokenizer() elif tokenizer_name == "rel": tokenizer = RelTokenizer() else: @@ -560,6 +549,7 @@ def resume_train( logger.info( f"Using training config: " f"model_name={model_name}, " + f"use_embeddings={use_embeddings}, " f"epochs={epochs}, " f"batch_size={batch_size}, " f"grad_acc_steps={grad_acc_steps}, " @@ -575,7 +565,12 @@ def resume_train( # Init model model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) - model = TransformerLM(model_config) + + if use_embeddings: + model = TransformerLM_CND(model_config) + else: + model = TransformerLM(model_config) + model.compile() train_dataloader, val_dataloader = get_dataloaders( @@ -586,6 +581,7 @@ def resume_train( batch_size=batch_size, num_workers=num_workers, apply_aug=True, + use_embeddings=use_embeddings, ) optimizer, scheduler = get_optim( model, @@ -624,6 +620,7 @@ def resume_train( model=model, train_dataloader=train_dataloader, val_dataloader=val_dataloader, + use_embeddings=use_embeddings, optimizer=optimizer, scheduler=scheduler, steps_per_checkpoint=steps_per_checkpoint, @@ -637,6 +634,7 @@ def train( model_name: str, train_data_paths: List[str], val_data_path: str, + use_embeddings: bool, num_workers: int, batch_size: int, grad_acc_steps: int, @@ -659,8 +657,6 @@ def train( tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) if tokenizer_name == "abs": tokenizer = AbsTokenizer() - elif tokenizer_name == "inference_abs": - tokenizer = InferenceAbsTokenizer() elif tokenizer_name == "rel": tokenizer = RelTokenizer() else: @@ -678,6 +674,7 @@ def train( logger.info( f"Using training config: " f"model_name={model_name}, " + f"use_embeddings={use_embeddings}, " f"checkpoint_path={checkpoint_path}, " if checkpoint_path else "" @@ -693,18 +690,24 @@ def train( # Init model model_config = ModelConfig(**load_model_config(model_name)) model_config.set_vocab_size(tokenizer.vocab_size) - model = TransformerLM(model_config) + + if use_embeddings is True: + model = TransformerLM_CND(model_config) + else: + model = TransformerLM(model_config) + model.compile() logger.info(f"Loaded model with config: {load_model_config(model_name)}") if checkpoint_path: try: model.load_state_dict(_load_weight(checkpoint_path)) - except Exception as e: - raise Exception( - f"Failed to load checkpoint: {e}\n" - "This could be due to a mismatch between the tokenizer used " - "to build the pre-training and fine-tuning datasets" + except RuntimeError as e: + print(e) + logger.info( + f"Failed to load {model_name} into {model_name}, attempting with strict=False" ) + model.load_state_dict(_load_weight(checkpoint_path), strict=False) + logger.info(f"Loaded finetune checkpoint located at: {checkpoint_path}") train_dataloader, val_dataloader = get_dataloaders( @@ -714,7 +717,7 @@ def train( batch_size=batch_size, num_workers=num_workers, apply_aug=True, - finetune=True if checkpoint_path is not None else False, + use_embeddings=use_embeddings, ) assert ( @@ -752,6 +755,7 @@ def train( model=model, train_dataloader=train_dataloader, val_dataloader=val_dataloader, + use_embeddings=use_embeddings, optimizer=optimizer, scheduler=scheduler, steps_per_checkpoint=steps_per_checkpoint, @@ -795,23 +799,30 @@ def _load_state_dict(_tokenizer: Tokenizer): def parse_resume_args(): argp = argparse.ArgumentParser(prog="python aria/train.py 
resume") argp.add_argument("model", help="name of model config file") - argp.add_argument("-train_data", nargs="+", help="path to train dir") - argp.add_argument("-val_data", help="path to val dir") - argp.add_argument("-cp_dir", help="checkpoint dir", type=str, required=True) - argp.add_argument("-r_step", help="resume step", type=int, required=True) - argp.add_argument("-r_epoch", help="resume epoch", type=int, required=True) - argp.add_argument("-epochs", help="train epochs", type=int, required=True) - argp.add_argument("-bs", help="batch size", type=int, default=32) argp.add_argument( - "-grad_acc_steps", + "--train_data", nargs="+", help="path to train dir", required=True + ) + argp.add_argument("--val_data", help="path to val dir", required=True) + argp.add_argument( + "--cp_dir", help="checkpoint dir", type=str, required=True + ) + argp.add_argument( + "--use_embeddings", help="prepend embeddings", action="store_true" + ) + argp.add_argument("--r_step", help="resume step", type=int, required=True) + argp.add_argument("--r_epoch", help="resume epoch", type=int, required=True) + argp.add_argument("--epochs", help="train epochs", type=int, required=True) + argp.add_argument("--bs", help="batch size", type=int, default=32) + argp.add_argument( + "--grad_acc_steps", help="gradient accumulation steps", type=int, default=1, ) - argp.add_argument("-workers", help="number workers", type=int, default=1) - argp.add_argument("-pdir", help="project dir", type=str, required=False) + argp.add_argument("--workers", help="number workers", type=int, default=1) + argp.add_argument("--pdir", help="project dir", type=str, required=False) argp.add_argument( - "-spc", help="steps per checkpoint", type=int, required=False + "--spc", help="steps per checkpoint", type=int, required=False ) return argp.parse_args(sys.argv[2:]) @@ -820,23 +831,28 @@ def parse_resume_args(): def parse_train_args(): argp = argparse.ArgumentParser(prog="python aria/train.py train") argp.add_argument("model", help="name of model config file") - argp.add_argument("-train_data", nargs="+", help="path to train dir") - argp.add_argument("-val_data", help="path to val dir") argp.add_argument( - "-cp_path", help="path to checkpoint", required=False, default=None + "--train_data", nargs="+", help="path to train dir", required=True + ) + argp.add_argument("--val_data", help="path to val dir", required=True) + argp.add_argument( + "--cp_path", help="path to checkpoint", required=False, default=None + ) + argp.add_argument( + "--use_embeddings", help="prepend embeddings", action="store_true" ) - argp.add_argument("-epochs", help="train epochs", type=int, required=True) - argp.add_argument("-bs", help="batch size", type=int, default=32) + argp.add_argument("--epochs", help="train epochs", type=int, required=True) + argp.add_argument("--bs", help="batch size", type=int, default=32) argp.add_argument( - "-grad_acc_steps", + "--grad_acc_steps", help="gradient accumulation steps", type=int, default=1, ) - argp.add_argument("-workers", help="number workers", type=int, default=1) - argp.add_argument("-pdir", help="project dir", type=str, required=False) + argp.add_argument("--workers", help="number workers", type=int, default=1) + argp.add_argument("--pdir", help="project dir", type=str, required=False) argp.add_argument( - "-spc", help="steps per checkpoint", type=int, required=False + "--spc", help="steps per checkpoint", type=int, required=False ) return argp.parse_args(sys.argv[2:]) @@ -860,6 +876,7 @@ def parse_train_args(): train( 
model_name=train_args.model, train_data_paths=train_args.train_data, + use_embeddings=train_args.use_embeddings, val_data_path=train_args.val_data, num_workers=train_args.workers, batch_size=train_args.bs, @@ -875,6 +892,7 @@ def parse_train_args(): model_name=resume_args.model, train_data_paths=resume_args.train_data, val_data_path=resume_args.val_data, + use_embeddings=resume_args.use_embeddings, num_workers=resume_args.workers, batch_size=resume_args.bs, grad_acc_steps=resume_args.grad_acc_steps, diff --git a/config/accelerate.yaml b/config/accelerate.yaml deleted file mode 100644 index 066be398..00000000 --- a/config/accelerate.yaml +++ /dev/null @@ -1,16 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: 'NO' -downcast_bf16: 'no' -gpu_ids: all -machine_rank: 0 -main_training_function: main -mixed_precision: bf16 -num_machines: 1 -num_processes: 1 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/config/config.json b/config/config.json index 0d32671e..dff8372a 100644 --- a/config/config.json +++ b/config/config.json @@ -140,7 +140,7 @@ "metadata": { "functions": { "aria_midi_json": { - "run": true, + "run": false, "args": {} }, "composer_filename": { @@ -174,39 +174,6 @@ "form": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], "composer": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] } - }, - "finetuning": { - "guidance_prob": 0.5, - "min_noisy_interval_ms": 5000, - "max_noisy_interval_ms": 60000, - "min_clean_interval_ms": 60000, - "max_clean_interval_ms": 200000, - "noising": { - "activation_prob": 0.5, - "remove_notes": { - "activation_prob": 0.25, - "min_ratio": 0.0, - "max_ratio": 0.15 - }, - "adjust_velocity": { - "activation_prob": 0.25, - "min_adjust": 1, - "max_adjust": 20 - }, - "adjust_onsets": { - "activation_prob": 0.25, - "min_adjust_s": 0.005, - "max_adjust_s": 0.05, - "max_ratio": 0.0, - "min_ratio": 0.2 - }, - "quantize_onsets": { - "activation_prob": 0.05, - "min_quant_s": 0.05, - "max_quant_s": 0.1, - "max_vel_delta": 30 - } - } } }, "tokenizer": { @@ -219,5 +186,4 @@ } } - } diff --git a/config/models/large.json b/config/models/large.json deleted file mode 100644 index 44014f30..00000000 --- a/config/models/large.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "d_model": 2048, - "n_heads": 32, - "n_layers": 16, - "ff_mult": 4, - "drop_p": 0.0, - "max_seq_len": 8192, - "grad_checkpoint": true -} \ No newline at end of file diff --git a/config/models/medium-composer.json b/config/models/medium-composer.json new file mode 100644 index 00000000..e3245c43 --- /dev/null +++ b/config/models/medium-composer.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 10, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-emb.json b/config/models/medium-emb.json new file mode 100644 index 00000000..f4f6579e --- /dev/null +++ b/config/models/medium-emb.json @@ -0,0 +1,10 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "emb_size": 512 +} diff --git a/config/models/medium-emotion.json b/config/models/medium-emotion.json new file mode 100644 index 00000000..27a896b7 --- /dev/null +++ b/config/models/medium-emotion.json @@ -0,0 +1,11 @@ +{ + 
"d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 4, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-form.json b/config/models/medium-form.json new file mode 100644 index 00000000..2d9a656d --- /dev/null +++ b/config/models/medium-form.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 6, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-genre.json b/config/models/medium-genre.json new file mode 100644 index 00000000..31d2bdc6 --- /dev/null +++ b/config/models/medium-genre.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 2, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-music_period.json b/config/models/medium-music_period.json new file mode 100644 index 00000000..27a896b7 --- /dev/null +++ b/config/models/medium-music_period.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 4, + "resid_dropout": 0.2 +} diff --git a/config/models/medium-pianist.json b/config/models/medium-pianist.json new file mode 100644 index 00000000..73b179b2 --- /dev/null +++ b/config/models/medium-pianist.json @@ -0,0 +1,11 @@ +{ + "d_model": 1536, + "n_heads": 24, + "n_layers": 16, + "ff_mult": 4, + "drop_p": 0.0, + "max_seq_len": 8192, + "grad_checkpoint": true, + "class_size": 8, + "resid_dropout": 0.2 +} diff --git a/demo/calibrate.py b/demo/calibrate.py new file mode 100644 index 00000000..74d07207 --- /dev/null +++ b/demo/calibrate.py @@ -0,0 +1,360 @@ +import argparse +import sys +import threading +import time + +import mido + +MIDDLE_C = 60 +C_MAJOR_CHORD = [MIDDLE_C, 64, 67, 72] # C4, E4, G4, C5 + + +def schedule_note_off(port: mido.ports.BaseOutput, note: int, delay: float): + """Schedules a non-blocking MIDI note-off message.""" + + def _off(): + port.send(mido.Message("note_off", note=note, velocity=0)) + + t = threading.Timer(delay, _off) + t.daemon = True # Allow main program to exit even if timers are pending + t.start() + + +def strike( + port: mido.ports.BaseOutput, velocity: int, offset_ms: int, notes: list[int] +): + """ + Performs a "3-2-1-GO!" countdown, sending MIDI notes with a precise offset. + The note-on message is sent `offset_ms` *before* "GO!" is printed. + """ + offset_sec = offset_ms / 1000.0 + + print("3") + time.sleep(1) + print("2") + time.sleep(1) + print("1") + + # Use monotonic time for a clock that is not affected by system time changes + go_time = time.monotonic() + 1.0 + note_on_time = go_time - offset_sec + + # Wait until the calculated time to send the MIDI message + sleep_duration = note_on_time - time.monotonic() + if sleep_duration > 0: + time.sleep(sleep_duration) + + for note in notes: + port.send(mido.Message("note_on", note=note, velocity=velocity)) + schedule_note_off(port, note, delay=0.5) + + # Wait for the exact moment to print "GO!" 
+ sleep_duration = go_time - time.monotonic() + if sleep_duration > 0: + time.sleep(sleep_duration) + + print("GO!\n") + + +def note_repetition_trial( + port: mido.ports.BaseOutput, + velocity: int, + notes: list[int], + note_length_ms: int, + gap_ms: int, +): + """Plays a note or chord repeatedly for a 3-second trial period.""" + print("Playing 3-second loop...") + + note_length_sec = note_length_ms / 1000.0 + gap_sec = gap_ms / 1000.0 + end_time = time.monotonic() + 3.0 + + while time.monotonic() < end_time: + # Ensure there's enough time for one full note cycle before the end + if time.monotonic() + note_length_sec + gap_sec > end_time: + break + + for note in notes: + port.send(mido.Message("note_on", note=note, velocity=velocity)) + + time.sleep(note_length_sec) + + for note in notes: + port.send(mido.Message("note_off", note=note, velocity=0)) + + if gap_sec > 0: + time.sleep(gap_sec) + + print("...loop finished.\n") + + +def calibrate_output_latency( + port_name: str, + velocity: int, + step_ms: int, + initial_offset_ms: int, + chord_mode: bool, +): + """Interactive loop to find the ideal hardware latency offset.""" + notes = C_MAJOR_CHORD if chord_mode else [MIDDLE_C] + offset_ms = initial_offset_ms + + try: + with mido.open_output(port_name) as port: + print(f"Opened MIDI output: {port_name}\n") + while True: + strike(port, velocity, offset_ms, notes) + print(f"Current offset: {offset_ms} ms") + cmd = ( + input("[u]p / [d]own / [r]epeat / [q]uit: ").strip().lower() + ) + + if cmd == "u": + offset_ms += step_ms + elif cmd == "d": + offset_ms = max(0, offset_ms - step_ms) + elif cmd == "q": + break + # Any other key (incl. 'r' or enter) repeats the trial + print() + except (KeyboardInterrupt, SystemExit): + print("\nInterrupted — exiting.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + +def calibrate_note_timing( + port_name: str, + velocity: int, + step_ms: int, + note_length_ms: int, + initial_gap_ms: int, + chord_mode: bool, +): + """Interactive loop to find a comfortable note repetition speed.""" + notes = C_MAJOR_CHORD if chord_mode else [MIDDLE_C] + gap_ms = initial_gap_ms + + try: + with mido.open_output(port_name) as port: + print(f"Opened MIDI output: {port_name}\n") + while True: + note_repetition_trial( + port, velocity, notes, note_length_ms, gap_ms + ) + print(f"Current gap: {gap_ms} ms") + cmd = ( + input("[u]p / [d]own / [r]epeat / [q]uit: ").strip().lower() + ) + + if cmd == "u": + gap_ms += step_ms + elif cmd == "d": + gap_ms = max(0, gap_ms - step_ms) + elif cmd == "q": + break + print() + except (KeyboardInterrupt, SystemExit): + print("\nInterrupted — exiting.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + +def measure_input_latency(port_name: str, timeout_sec: float = 2.0): + """ + 3-2-1-GO countdown → you strike a key on GO. + Prints the latency (note_on arrival – GO). + + • Uses the same MIDI port for input. + • Waits `timeout_sec` seconds for a note-on; repeats if none arrives. 
+ """ + try: + with mido.open_ioport(port_name) as port: + print(f"Opened MIDI I/O port: {port_name}\n") + + while True: + # ── simple countdown ──────────────────────────────────── + for n in ("3", "2", "1"): + print(n) + time.sleep(1) + + go_time = time.monotonic() + print("GO!") + + # wait for first note-on (velocity>0) or timeout + deadline = go_time + timeout_sec + latency_ms = None + while time.monotonic() < deadline: + msg = port.poll() + if msg and msg.type == "note_on" and msg.velocity > 0: + latency_ms = (time.monotonic() - go_time) * 1000.0 + break + + if latency_ms is None: + print("No key press detected – try again.\n") + else: + print(f"Input latency: {latency_ms:.1f} ms\n") + + if input("[r]etry / [q]uit: ").strip().lower() == "q": + break + print() + + except (KeyboardInterrupt, SystemExit): + print("\nInterrupted — exiting.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + +def list_midi_ports() -> None: + """Prints a list of available MIDI output ports.""" + print("Available MIDI output ports:") + try: + port_names = mido.get_output_names() + if not port_names: + print(" (No ports found)") + for name in port_names: + print(f" - {name}") + except Exception as e: + print(f"Could not retrieve MIDI ports: {e}") + + +def parse_args(): + """Parses command-line arguments for the calibration tool.""" + parser = argparse.ArgumentParser( + description="A tool to calibrate Disklavier latency and note timing.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # ── global option ────────────────────────────────────────────────────── + parser.add_argument( + "--list-ports", + action="store_true", + help="List available MIDI output ports and exit.", + ) + + # ── options common to *all* modes ───────────────────────────────────── + parent = argparse.ArgumentParser(add_help=False) + parent.add_argument("--port", "-p", required=True, help="MIDI port name.") + parent.add_argument( + "--velocity", + "-v", + type=int, + default=80, + help="Note-on velocity (1-127).", + ) + parent.add_argument( + "--step", + "-s", + type=int, + default=10, + help="Adjustment step in ms (latency/timing modes).", + ) + parent.add_argument( + "--chord", + "-c", + action="store_true", + help="Use a C-major chord instead of single note.", + ) + + sub = parser.add_subparsers(dest="command", help="Available commands.") + + # ── output-latency calibration ──────────────────────────────────────── + p_lat = sub.add_parser( + "output", + parents=[parent], + help="Calibrate output latency.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p_lat.add_argument( + "--offset", + "-o", + type=int, + default=100, + help="Initial latency offset in ms.", + ) + + # ── repeated-note timing calibration ────────────────────────────────── + p_tim = sub.add_parser( + "timing", + parents=[parent], + help="Calibrate minimum gap between notes.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p_tim.add_argument( + "--note-length", + "-l", + type=int, + default=500, + help="Note duration in ms.", + ) + p_tim.add_argument( + "--gap", + "-g", + type=int, + default=100, + help="Initial gap between notes in ms.", + ) + + # ── input-latency measurement (new) ─────────────────────────────────── + p_in = sub.add_parser( + "input", + parents=[parent], + help="Measure input latency (countdown → strike).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p_in.add_argument( + "--timeout", + "-t", + type=float, + default=2.0, + help="Seconds to wait for a key press before 
retry.", + ) + + args = parser.parse_args() + + # global flag handler + if args.list_ports: + list_midi_ports() + sys.exit(0) + + if not args.command: + parser.error( + "A command is required: choose 'output', 'timing', or 'input'." + ) + + return args + + +def main(): + """Dispatches to the selected calibration or measurement routine.""" + args = parse_args() + + if args.command == "output": + calibrate_output_latency( + port_name=args.port, + velocity=args.velocity, + step_ms=args.step, + initial_offset_ms=args.offset, + chord_mode=args.chord, + ) + + elif args.command == "timing": + calibrate_note_timing( + port_name=args.port, + velocity=args.velocity, + step_ms=args.step, + note_length_ms=args.note_length, + initial_gap_ms=args.gap, + chord_mode=args.chord, + ) + + elif args.command == "input": + measure_input_latency( + port_name=args.port, + timeout_sec=args.timeout, + ) + + +if __name__ == "__main__": + main() diff --git a/demo/demo.py b/demo/demo.py new file mode 100644 index 00000000..46a3a0b1 --- /dev/null +++ b/demo/demo.py @@ -0,0 +1,1598 @@ +#!/usr/bin/env python3 + +import argparse +import os +import time +import uuid +import copy +import logging +import threading +import queue +import copy +import torch +import mido +import torch._inductor.config + +from torch.cuda import is_available as cuda_is_available +from contextlib import ExitStack + +from ariautils.midi import MidiDict, midi_to_dict +from ariautils.tokenizer import AbsTokenizer +from aria.utils import _load_weight +from aria.inference import TransformerLM +from aria.model import ModelConfig +from aria.config import load_model_config +from aria.sample import sample_min_p + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True + +DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 +MAX_SEQ_LEN = 4096 +PREFILL_CHUNK_SIZE = 32 +RECALC_DUR_PREFILL_CHUNK_SIZE = 8 +RECALC_DUR_BUFFER_MS = 50 + +# Decode first +BEAM_WIDTH = 3 +TIME_TOK_WEIGHTING = -5 +FIRST_ONSET_BUFFER_MS = 25 + +# HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS +# HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg +# HARDWARE: All messages are sent HARDWARE_LATENCY_MS early +MIN_NOTE_DELTA_MS = 100 +MIN_NOTE_LEN_MS = 200 +HARDWARE_LATENCY_MS = 0 + +file_handler = logging.FileHandler("./demo.log", mode="w") +file_handler.setLevel(logging.DEBUG) + + +def get_logger(name: str | None = None) -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + logger.propagate = False + logger.setLevel(logging.DEBUG) + + class MillisecondFormatter(logging.Formatter): + def formatTime(self, record, datefmt=None): + created_ms = int(record.created * 1000) + return str(created_ms) + + if name is not None: + formatter = MillisecondFormatter( + "%(asctime)s: [%(levelname)s] [%(name)s] %(message)s" + ) + else: + formatter = MillisecondFormatter( + "%(asctime)s: [%(levelname)s] %(message)s" + ) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +def get_epoch_time_ms() -> int: + return round(time.time() * 1000) + + +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def prefill( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +) -> torch.Tensor: + 
logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + ) + + return logits + + +@torch.autocast("cuda", dtype=DTYPE) +@torch.inference_mode() +def decode_one( + model: TransformerLM, + idxs: torch.Tensor, + input_pos: torch.Tensor, + pad_idxs: torch.Tensor | None = None, +) -> torch.Tensor: + assert input_pos.shape[-1] == 1 + + logits = model.forward( + idxs=idxs, + input_pos=input_pos, + pad_idxs=pad_idxs, + )[:, -1] + + return logits + + +def _compile_prefill( + model: TransformerLM, + logger: logging.Logger, + chunk_size: int, +): + assert chunk_size > 1 + + global prefill + prefill = torch.compile( + prefill, + mode="reduce-overhead", + fullgraph=True, + ) + start_compile_time_s = time.time() + logger.info(f"Compiling prefill (chunk_size={chunk_size})") + prefill( + model, + idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), + input_pos=torch.arange(0, chunk_size, device="cuda", dtype=torch.int), + ) + logger.info( + f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" + ) + + for _ in range(5): + prefill( + model, + idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), + input_pos=torch.arange( + 0, chunk_size, device="cuda", dtype=torch.int + ), + ) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + prefill( + model, + idxs=torch.ones(1, chunk_size, device="cuda", dtype=torch.int), + input_pos=torch.arange(0, chunk_size, device="cuda", dtype=torch.int), + ) + end_event.record() + end_event.synchronize() + compiled_prefill_ms = start_event.elapsed_time(end_event) + compiled_prefill_its = 1000 / compiled_prefill_ms + logger.info( + f"Compiled prefill benchmark: {compiled_prefill_ms:.2f} ms/it ({compiled_prefill_its:.2f} it/s)" + ) + + return model + + +def _compile_decode_one(model: TransformerLM, logger: logging.Logger): + global decode_one + decode_one = torch.compile( + decode_one, + mode="reduce-overhead", + fullgraph=True, + ) + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + start_compile_time_s = time.time() + logger.info(f"Compiling decode_one") + decode_one( + model, + idxs=torch.tensor([[0]], device="cuda", dtype=torch.int), + input_pos=torch.tensor([0], device="cuda", dtype=torch.int), + ) + logger.info( + f"Finished compiling - took {time.time() - start_compile_time_s:.4f} seconds" + ) + + for _ in range(5): + decode_one( + model, + idxs=torch.tensor([[0]], device="cuda", dtype=torch.int).cuda(), + input_pos=torch.tensor([0], device="cuda", dtype=torch.int), + ) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + decode_one( + model, + idxs=torch.tensor([[0]], device="cuda", dtype=torch.int).cuda(), + input_pos=torch.tensor([0], device="cuda", dtype=torch.int), + ) + end_event.record() + end_event.synchronize() + + compiled_forward_ms = start_event.elapsed_time(end_event) + compiled_forward_its = 1000 / compiled_forward_ms + logger.info( + f"Compiled decode_one benchmark: {compiled_forward_ms:.2f} ms/it ({compiled_forward_its:.2f} it/s)" + ) + + return model + + +@torch.inference_mode() +def compile_model(model: TransformerLM, max_seq_len: int): + logger = get_logger() + assert 10 < max_seq_len <= MAX_SEQ_LEN + + model.eval() + model.setup_cache( + batch_size=1, + max_seq_len=max_seq_len, + dtype=DTYPE, + ) + + model = _compile_decode_one(model=model, logger=logger) + for chunk_size in list({PREFILL_CHUNK_SIZE, 
RECALC_DUR_PREFILL_CHUNK_SIZE}): + model = _compile_prefill( + model=model, logger=logger, chunk_size=chunk_size + ) + + return model + + +def load_model( + checkpoint_path: str, +): + logger = get_logger() + if not cuda_is_available(): + raise Exception("CUDA device is not available.") + + init_start_time_s = time.time() + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config("medium-emb")) + model_config.set_vocab_size(tokenizer.vocab_size) + model_config.grad_checkpoint = False + model = TransformerLM(model_config).cuda() + + logging.info(f"Loading model weights from {checkpoint_path}") + model_state = _load_weight(checkpoint_path, "cuda") + model_state = { + k.replace("_orig_mod.", ""): v for k, v in model_state.items() + } + try: + model.load_state_dict(model_state) + except Exception: + logger.info("Failed to load model, attempting with strict=False...") + model.load_state_dict(model_state, strict=False) + + logger.info( + f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" + ) + + return model + + +def _first_bad_dur_index( + tokenizer: AbsTokenizer, + priming_seq: list, + pred_ids: list, + chunk_start: int, + last_offset_ms: int, + logger: logging.Logger, +): + num_time_toks = priming_seq[:chunk_start].count(tokenizer.time_tok) + local_onset_ms = tokenizer.calc_length_ms( + priming_seq[:chunk_start], onset=True + ) + logger.debug(f"Starting from local onset {local_onset_ms}") + + for pos, tok_id in enumerate( + pred_ids[: len(priming_seq) - chunk_start], start=chunk_start + ): + prim_tok = priming_seq[pos] # Should never error? + pred_tok = tokenizer.id_to_tok[tok_id] + logger.debug(f"prim={prim_tok}, pred={pred_tok}") + + if isinstance(prim_tok, tuple) and prim_tok[0] == "onset": + local_onset_ms = num_time_toks * 5000 + prim_tok[1] + elif prim_tok == tokenizer.time_tok: + num_time_toks += 1 + elif isinstance(prim_tok, tuple) and prim_tok[0] == "dur": + dur_true = prim_tok[1] + dur_pred = pred_tok[1] + if dur_pred > dur_true and ( + local_onset_ms + dur_true + > last_offset_ms - RECALC_DUR_BUFFER_MS + ): + logger.info( + f"Found token to resample at {pos}: {prim_tok} -> {pred_tok}" + ) + return pos + + return None + + +# TODO: I'm still not 100% sure this is bug free. 
+# A good debugging strat would be to run it over and over again until we +# cover all of the edge cases +@torch.inference_mode() +def recalc_dur_tokens_chunked( + model: TransformerLM, + priming_seq: list, + enc_seq: torch.Tensor, + tokenizer: AbsTokenizer, + start_idx: int, +): + """Speculative-decoding inspired duration re-calculation""" + assert start_idx > 0 + logger = get_logger("GENERATE") + + priming_len = len(priming_seq) + last_offset = tokenizer.calc_length_ms(priming_seq) + + idx = start_idx + while idx <= priming_len: + end_idx = idx + RECALC_DUR_PREFILL_CHUNK_SIZE + + window_ids = torch.tensor( + enc_seq[:, idx - 1 : end_idx - 1].tolist(), + device="cuda", + dtype=torch.int, + ) + window_pos = torch.arange( + idx - 1, end_idx - 1, device="cuda", dtype=torch.int + ) + + logger.info( + f"Recalculating chunked durations for positions: {idx-1} - {end_idx-2}" + ) + logger.debug(f"Inserted: {tokenizer.decode(window_ids[0].tolist())}") + logger.debug(f"Positions: {window_pos.tolist()}") + + logits = prefill(model, idxs=window_ids, input_pos=window_pos) + pred_ids = logits.argmax(dim=-1).flatten().tolist() + + bad_pos = _first_bad_dur_index( + tokenizer=tokenizer, + priming_seq=priming_seq, + pred_ids=pred_ids, + chunk_start=idx, + last_offset_ms=last_offset, + logger=logger, + ) + + if bad_pos is None: + idx = end_idx + else: + new_id = pred_ids[bad_pos - idx] + enc_seq[0, bad_pos] = new_id + priming_seq[bad_pos] = tokenizer.id_to_tok[new_id] + idx = bad_pos + + next_logits = logits[:, priming_len - idx] + + return enc_seq, priming_seq, next_logits + + +# TODO: This is now the latency bottleneck. +# Ideas for reducing it: +# - Get rid of the manual time_tok insert stuff, instead just mask logits +# for all invalid tokens, this should force the model to sample a time tok +# if there aren't any other valid options +@torch.inference_mode() +def decode_first_tokens( + model: TransformerLM, + first_token_logits: torch.Tensor, + enc_seq: torch.Tensor, + priming_seq: list, + tokenizer: AbsTokenizer, + generated_tokens_queue: queue.Queue, + first_on_msg_epoch_ms: int, +): + logger = get_logger("GENERATE") + + buffer_ms = FIRST_ONSET_BUFFER_MS + HARDWARE_LATENCY_MS + time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] + + logits = first_token_logits + time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms + idx = len(priming_seq) + 1 + + num_time_toks_required = (time_since_first_onset_ms + buffer_ms) // 5000 + num_time_toks_in_priming_seq = priming_seq.count(tokenizer.time_tok) + num_time_toks_to_add = num_time_toks_required - num_time_toks_in_priming_seq + + logger.info(f"Time since first onset: {time_since_first_onset_ms}ms") + + while num_time_toks_to_add > 0: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + generated_tokens_queue.put(tokenizer.time_tok) + logits = decode_one( + model, + idxs=torch.tensor( + [[time_tok_id]], device="cuda", dtype=torch.int + ), + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) + + logger.info(f"Inserted time_tok at position {idx-1}") + num_time_toks_to_add -= 1 + enc_seq[:, idx - 1] = torch.tensor([[time_tok_id]]).cuda() + idx += 1 + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + + log_probs = torch.log_softmax(logits, dim=-1) + top_log_probs, top_ids = torch.topk(log_probs, k=BEAM_WIDTH, dim=-1) + + if time_tok_id not in top_ids[0].tolist(): + top_ids[0, -1] = time_tok_id + top_log_probs[0, 
-1] = log_probs[0, time_tok_id] + TIME_TOK_WEIGHTING + + top_toks = [tokenizer.id_to_tok[id] for id in top_ids[0].tolist()] + + logger.debug(f"Calculated top {BEAM_WIDTH} tokens={top_toks}") + logger.debug( + f"Calculated top {BEAM_WIDTH} scores={top_log_probs[0].tolist()}" + ) + + masked_onset_ids = [ + tokenizer.tok_to_id[tok] + for tok in tokenizer.onset_tokens + if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) + ] + + logger.debug( + f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + buffer_ms})" + ) + + best_score = float("-inf") + for i in range(BEAM_WIDTH): + tok = top_toks[i] + tok_id = top_ids[0, i].item() + tok_log_prob = top_log_probs[0, i] + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + next_logits = decode_one( + model, + idxs=torch.tensor([[tok_id]], device="cuda", dtype=torch.int), + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) + logger.debug( + f"Sampled logits for positions {idx} by inserting {tok} at position {idx-1}" + ) + + next_log_probs = torch.log_softmax(next_logits, dim=-1) + next_log_probs[:, masked_onset_ids] = float("-inf") + if tok_id == time_tok_id: + next_log_probs[:, time_tok_id] = float("-inf") + + next_tok_log_prob, next_tok_id = torch.max(next_log_probs, dim=-1) + next_tok = tokenizer.id_to_tok[next_tok_id.item()] + score = tok_log_prob + next_tok_log_prob + + logger.info( + f"Calculated tuple {(tok, next_tok)} with scores {(tok_log_prob.item(), next_tok_log_prob.item())} (combined={score.item()})" + ) + + if score > best_score: + best_tok_id_1, best_tok_id_2 = tok_id, next_tok_id.item() + best_tok_1, best_tok_2 = ( + tokenizer.id_to_tok[best_tok_id_1], + tokenizer.id_to_tok[best_tok_id_2], + ) + best_score = score + + logger.info( + f"Chose tuple {(best_tok_1, best_tok_2)} with score {best_score.item()}" + ) + + enc_seq[:, idx - 1] = best_tok_id_1 + enc_seq[:, idx] = best_tok_id_2 + generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_1]) + generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_2]) + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + decode_one( + model, + idxs=torch.tensor( + [[best_tok_id_1]], device="cuda", dtype=torch.int + ), + input_pos=torch.tensor([idx - 1], device="cuda", dtype=torch.int), + ) + + logger.info( + f"Updated KV-Cache by re-inserting {best_tok_1} at position {idx-1}" + ) + logger.info( + f"Inserted {best_tok_2} at position {idx} without updating KV-Cache" + ) + + return enc_seq, idx + 1 + + +def decode_tokens( + model: TransformerLM, + enc_seq: torch.Tensor, + tokenizer: AbsTokenizer, + control_sentinel: threading.Event, + generated_tokens_queue: queue.Queue, + idx: int, + temperature: float, + min_p: float, +): + logger = get_logger("GENERATE") + logger.info( + f"Using sampling parameters: temperature={temperature}, min_p={min_p}" + ) + + while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: + decode_one_start_time_s = time.time() + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + prev_tok_id = enc_seq[0, idx - 1] + prev_tok = tokenizer.id_to_tok[prev_tok_id.item()] + + logits = decode_one( + model, + idxs=torch.tensor( + [[prev_tok_id]], device="cuda", dtype=torch.int + ), + input_pos=torch.tensor( + [idx - 1], device="cuda", dtype=torch.int + ), + ) + + logger.debug( + f"Sampled logits for positions {idx} by inserting {prev_tok} at position {idx-1}" + ) + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + for dur_ms in 
range(0, MIN_NOTE_LEN_MS, 10): + logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") + + if temperature > 0.0: + probs = torch.softmax(logits / temperature, dim=-1) + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = torch.argmax(logits, dim=-1).flatten() + + enc_seq[:, idx] = next_token_ids + next_token = tokenizer.id_to_tok[next_token_ids[0].item()] + logger.debug( + f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" + ) + + if next_token == tokenizer.eos_tok: + logger.info("EOS token produced, exiting...") + generated_tokens_queue.put(next_token) + return + else: + generated_tokens_queue.put(next_token) + idx += 1 + + while not control_sentinel.is_set(): + time.sleep(0.1) + + logger.info("Seen exit signal") + generated_tokens_queue.put(None) + + +@torch.inference_mode() +def generate_tokens( + priming_seq: list, + tokenizer: AbsTokenizer, + model: TransformerLM, + prev_context: list[int], + control_sentinel: threading.Event, + generated_tokens_queue: queue.Queue, + num_preceding_active_pitches: int, + first_on_msg_epoch_ms: int, + temperature: float = 0.97, + min_p: float = 0.03, +): + logger = get_logger("GENERATE") + + generate_start_s = time.time() + priming_seq_len = len(priming_seq) + start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) + enc_seq = torch.tensor( + [ + tokenizer.encode( + priming_seq + + [tokenizer.pad_tok] * (MAX_SEQ_LEN - len(priming_seq)) + ) + ], + device="cuda", + dtype=torch.int, + ) + + logger.debug(f"Priming sequence {priming_seq}") + logger.info(f"Priming sequence length: {priming_seq_len}") + logger.info(f"Prefilling up to (and including) position: {start_idx-1}") + + # In theory we could reuse the logits from prefill + prefill_start_s = time.time() + chunked_prefill( + model=model, + tokenizer=tokenizer, + prev_context=prev_context, + curr_context=enc_seq[0, :start_idx].tolist(), + full=True, + ) + + torch.cuda.synchronize() + logger.info( + f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" + ) + logger.info(f"Starting duration recalculation from position: {start_idx-1}") + + recalculate_dur_start_s = time.time() + enc_seq, priming_seq, next_token_logits = recalc_dur_tokens_chunked( + model=model, + priming_seq=priming_seq, + enc_seq=enc_seq, + tokenizer=tokenizer, + start_idx=start_idx, + ) + + logger.info( + f"Recalculating durations took {(time.time() - recalculate_dur_start_s) * 1000:.2f} milliseconds" + ) + + decode_first_s = time.time() + enc_seq, idx = decode_first_tokens( + model=model, + first_token_logits=next_token_logits, + enc_seq=enc_seq, + priming_seq=priming_seq, + tokenizer=tokenizer, + generated_tokens_queue=generated_tokens_queue, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + ) + + logger.info( + f"Decode first two tokens took {(time.time() - decode_first_s) * 1000:.2f} milliseconds" + ) + logger.info( + f"Time to first token took {(time.time() - generate_start_s) * 1000:.2f} milliseconds" + ) + + decode_tokens( + model=model, + enc_seq=enc_seq, + tokenizer=tokenizer, + control_sentinel=control_sentinel, + generated_tokens_queue=generated_tokens_queue, + idx=idx, + temperature=temperature, + min_p=min_p, + ) + + +def decode_tokens_to_midi( + generated_tokens_queue: queue.Queue, + outbound_midi_msg_queue: queue.Queue, + tokenizer: AbsTokenizer, + first_on_msg_epoch_ms: int, + priming_seq_last_onset_ms: int, +): + logger = get_logger("DECODE") + + assert ( + first_on_msg_epoch_ms + priming_seq_last_onset_ms < 
get_epoch_time_ms() + ) + + logger.info(f"Priming sequence last onset: {priming_seq_last_onset_ms}") + logger.info( + f"Total time elapsed since first onset: {get_epoch_time_ms() - first_on_msg_epoch_ms}" + ) + + pitch_to_prev_msg = {} + note_buffer = [] + num_time_toks = priming_seq_last_onset_ms // 5000 + + while True: + while True: + tok = generated_tokens_queue.get() + if tok is tokenizer.eos_tok: + _uuid = uuid.uuid4() + end_msg = { + "pitch": -1, + "vel": -1, + "epoch_time_ms": offset_epoch_ms + 250, # Last note offset + "uuid": _uuid, + } # pitch=-1 denotes end_msg + outbound_midi_msg_queue.put(end_msg) + logger.info(f"Seen exit signal: EOS token") + logger.debug(f"Put message: {end_msg}") + return + + elif tok is None: + logger.info(f"Seen exit signal") + return + + logger.debug(f"Seen token: {tok}") + note_buffer.append(tok) + + if isinstance(tok, tuple) and tok[0] == "dur": + break + + while note_buffer and note_buffer[0] == tokenizer.time_tok: + logger.debug("Popping time_tok") + num_time_toks += 1 + note_buffer.pop(0) + + assert len(note_buffer) == 3 + logger.debug(f"Decoded note: {note_buffer}") + note_tok, onset_tok, dur_tok = note_buffer + _, pitch, vel = note_tok + _, onset = onset_tok + _, dur = dur_tok + + _uuid = uuid.uuid4() + onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset + offset_epoch_ms = onset_epoch_ms + dur + on_msg = { + "pitch": pitch, + "vel": vel, + "epoch_time_ms": onset_epoch_ms, + "uuid": _uuid, + } + off_msg = { + "pitch": pitch, + "vel": 0, + "epoch_time_ms": offset_epoch_ms, + "uuid": _uuid, + } + + # Not thread safe but in theory should be ok? + if pitch_to_prev_msg.get(pitch) is not None and MIN_NOTE_DELTA_MS > 0: + prev_on, prev_off = pitch_to_prev_msg.get(pitch) + adj_off_time = max( + min( + prev_off["epoch_time_ms"], + onset_epoch_ms - MIN_NOTE_DELTA_MS, + ), + prev_on["epoch_time_ms"], + ) + if adj_off_time != prev_off["epoch_time_ms"]: + logger.debug(f"Adjusting {prev_off}: t={adj_off_time}") + prev_off["epoch_time_ms"] = adj_off_time + prev_off["adjusted"] = True + + pitch_to_prev_msg[pitch] = [on_msg, off_msg] + + outbound_midi_msg_queue.put(on_msg) + outbound_midi_msg_queue.put(off_msg) + logger.debug(f"Put message: {on_msg}") + logger.debug(f"Put message: {off_msg}") + logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") + + note_buffer = [] + + +# TODO: Test the new changes in decode_tokens_to_midi and clean this fn up. 
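+# Scheduling policy used by stream_midi below: queued messages are kept sorted
+# by (epoch_time_ms, vel) and sent once the hardware-latency-adjusted clock has
+# passed their timestamp by at most MAX_DELAY_MS. Messages that are more than
+# MAX_DELAY_MS late are dropped, and a note_off is normally only sent if its
+# uuid still owns the pitch, so retriggered notes are not cut off early.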
+def stream_midi(
+    inbound_midi_msg_queue: queue.Queue,
+    msgs: list[mido.Message],
+    prev_msg_epoch_time_ms: float,
+    midi_output_port: str,
+    control_sentinel: threading.Event,
+    midi_stream_channel: int,
+    results_queue: queue.Queue,
+):
+    logger = get_logger("STREAM")
+    logger.info(
+        f"Sending generated messages on MIDI port: '{midi_output_port}'"
+    )
+    logger.info(
+        f"Applying hardware latency adjustment: {HARDWARE_LATENCY_MS}ms"
+    )
+    MAX_DELAY_MS = 50
+
+    active_pitch_uuid = {}
+    is_pitch_active = {}
+    midi_msgs = []
+
+    with mido.open_output(midi_output_port) as midi_out:
+        while not control_sentinel.is_set():
+            while True:
+                try:
+                    msg = inbound_midi_msg_queue.get_nowait()
+                except queue.Empty:
+                    break
+                else:
+                    logger.debug(f"Received message: {msg}")
+                    midi_msgs.append(msg)
+
+            midi_msgs = sorted(
+                midi_msgs,
+                key=lambda msg: (
+                    msg["epoch_time_ms"],
+                    msg["vel"],
+                ),
+            )
+
+            if control_sentinel.is_set():
+                break
+
+            while midi_msgs:
+                latency_adjusted_epoch_time_ms = (
+                    get_epoch_time_ms() + HARDWARE_LATENCY_MS
+                )
+                msg = midi_msgs[0]
+
+                if (
+                    0
+                    < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]
+                    <= MAX_DELAY_MS
+                ):
+                    if msg["pitch"] == -1:  # End msg
+                        control_sentinel.set()
+                        break
+
+                    mido_msg = mido.Message(
+                        "note_on",
+                        note=msg["pitch"],
+                        velocity=msg["vel"],
+                        channel=0,
+                        time=0,
+                    )
+
+                    if msg["vel"] > 0:
+                        active_pitch_uuid[msg["pitch"]] = msg["uuid"]
+                        should_send_midi_out = True
+                        should_append_to_msgs = True
+                    elif msg.get("adjusted", False) is True:
+                        should_send_midi_out = True
+                        should_append_to_msgs = False
+                    else:
+                        should_send_midi_out = (
+                            active_pitch_uuid.get(msg["pitch"]) == msg["uuid"]
+                        )
+                        should_append_to_msgs = should_send_midi_out
+
+                    if should_send_midi_out is True:
+                        midi_out.send(mido_msg)
+                        is_pitch_active[msg["pitch"]] = msg["vel"] != 0
+                        logger.info(f"Sent message: {mido_msg}")
+                    if should_append_to_msgs is True:
+                        mido_msg_with_time = copy.deepcopy(mido_msg)
+                        mido_msg_with_time.channel = midi_stream_channel
+                        mido_msg_with_time.time = max(
+                            0, msg["epoch_time_ms"] - prev_msg_epoch_time_ms
+                        )
+                        prev_msg_epoch_time_ms = msg["epoch_time_ms"]
+                        msgs.append(mido_msg_with_time)
+
+                    midi_msgs.pop(0)
+
+                elif (
+                    latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]
+                    > MAX_DELAY_MS
+                ):
+                    # Message occurs too far in the past
+                    logger.debug(
+                        f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg['epoch_time_ms']}ms) in the past: {msg}"
+                    )
+                    midi_msgs.pop(0)
+                else:
+                    # Message occurs in the future
+                    break
+
+            time.sleep(0.005)
+
+        remaining_note_off_messages = [
+            msg
+            for msg in midi_msgs
+            if msg["vel"] == 0
+            and active_pitch_uuid.get(msg["pitch"]) == msg["uuid"]
+        ]
+
+        logger.info("Processing remaining note_off messages")
+        for __msg in remaining_note_off_messages:
+            logger.debug(__msg)
+
+        for msg in remaining_note_off_messages:
+            mido_msg = mido.Message(
+                "note_on",
+                note=msg["pitch"],
+                velocity=0,
+                channel=midi_stream_channel,
+                time=msg["epoch_time_ms"] - prev_msg_epoch_time_ms,
+            )
+            prev_msg_epoch_time_ms = msg["epoch_time_ms"]
+            msgs.append(mido_msg)
+
+        results_queue.put(msgs)
+
+        while remaining_note_off_messages:
+            msg = remaining_note_off_messages.pop(0)
+            while True:
+                latency_adjusted_epoch_time_ms = (
+                    get_epoch_time_ms() + HARDWARE_LATENCY_MS
+                )
+
+                if 0 < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"]:
+                    mido_msg = mido.Message(
+                        "note_on",
+                        note=msg["pitch"],
+                        velocity=0,
+                        channel=midi_stream_channel,
+                        time=0,  # Does not matter as only 
used for streaming + ) + midi_out.send(mido_msg) + logger.info(f"Sent message: {mido_msg}") + break + else: + time.sleep(0.01) + + +def stream_msgs( + model: TransformerLM, + tokenizer: AbsTokenizer, + msgs: list[mido.Message], + prev_context: list[int], + midi_output_port: str, + first_on_msg_epoch_ms: int, + control_sentinel: threading.Event, + temperature: float, + min_p: float, + num_preceding_active_pitches: int, + midi_stream_channel: int, + is_ending: bool = False, +): + midi = convert_msgs_to_midi(msgs=msgs) + midi_dict = MidiDict(**midi_to_dict(midi)) + priming_seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) + priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] + + if is_ending is True: + priming_seq.append(tokenizer.dim_tok) + + generated_tokens_queue = queue.Queue() + midi_messages_queue = queue.Queue() + + generate_tokens_thread = threading.Thread( + target=generate_tokens, + kwargs={ + "priming_seq": priming_seq, + "tokenizer": tokenizer, + "model": model, + "prev_context": prev_context, + "control_sentinel": control_sentinel, + "generated_tokens_queue": generated_tokens_queue, + "temperature": temperature, + "min_p": min_p, + "num_preceding_active_pitches": num_preceding_active_pitches, + "first_on_msg_epoch_ms": first_on_msg_epoch_ms, + }, + ) + generate_tokens_thread.start() + + decode_tokens_to_midi_thread = threading.Thread( + target=decode_tokens_to_midi, + kwargs={ + "generated_tokens_queue": generated_tokens_queue, + "outbound_midi_msg_queue": midi_messages_queue, + "tokenizer": tokenizer, + "first_on_msg_epoch_ms": first_on_msg_epoch_ms, + "priming_seq_last_onset_ms": tokenizer.calc_length_ms( + priming_seq, onset=True + ), + }, + ) + decode_tokens_to_midi_thread.start() + + prev_ms_epoch_time_ms = ( + first_on_msg_epoch_ms + + tokenizer.calc_length_ms(priming_seq, onset=False) + if is_ending is False + else first_on_msg_epoch_ms + ) + + stream_midi_results_queue = queue.Queue() + stream_midi_thread = threading.Thread( + target=stream_midi, + kwargs={ + "inbound_midi_msg_queue": midi_messages_queue, + "msgs": msgs, + "prev_msg_epoch_time_ms": prev_ms_epoch_time_ms, + "midi_output_port": midi_output_port, + "control_sentinel": control_sentinel, + "midi_stream_channel": midi_stream_channel, + "results_queue": stream_midi_results_queue, + }, + daemon=True, + ) + stream_midi_thread.start() + + generate_tokens_thread.join() + decode_tokens_to_midi_thread.join() + msgs = stream_midi_results_queue.get() + + if is_ending is True: + stream_midi_thread.join() + + return msgs + + +# TODO: Channel 9 issues here? 
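+# Note on timing in convert_msgs_to_midi below: the captured/generated mido
+# messages carry delta times in milliseconds, and the file is written with
+# tempo=500000 (0.5 s per beat) and ticks_per_beat=500, i.e.
+# 500000 / 500 = 1000 microseconds per tick (one tick per millisecond), so the
+# stored delta times are valid tick values without any rescaling.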
+def convert_msgs_to_midi(msgs: list[mido.Message]): + channel_to_track = { + chan: mido.MidiTrack() + for chan in list(set([msg.channel for msg in msgs])) + } + + for msg in msgs: + channel_to_track[msg.channel].append(msg) + + # Workaround for possibility that track_0 start time != first_on_msg_epoch_ms + for msg in channel_to_track[0]: + if msg.type == "note_on" and msg.velocity > 0: + msg.time = 0 + break + else: + msg.time = 0 + + mid = mido.MidiFile(type=1) + mid.ticks_per_beat = 500 + + for channel, track in channel_to_track.items(): + track.insert(0, mido.MetaMessage("set_tempo", tempo=500000, time=0)) + track.insert( + 0, + mido.Message("program_change", program=0, channel=channel, time=0), + ) + mid.tracks.append(track) + + return mid + + +def _find_divergence( + prev_context: list, + curr_context: list, + logger: logging.Logger, +): + agreement_index = 0 + for prev_val, curr_val in zip(prev_context, curr_context): + if prev_val == curr_val: + agreement_index += 1 + else: + logger.info( + f"Found divergence at position {agreement_index + 1}: {curr_val}, {prev_val}" + ) + break + + return agreement_index, curr_context[agreement_index:] + + +# There is an error here if curr_context < prev_context +@torch.inference_mode() +def chunked_prefill( + model: TransformerLM, + tokenizer: AbsTokenizer, + prev_context: list, + curr_context: list, + full: bool = False, +): + + assert isinstance(curr_context[0], int) + assert tokenizer.pad_id not in prev_context + assert tokenizer.pad_id not in curr_context + + logger = get_logger("PREFILL") + while True: + prefill_idx, prefill_toks = _find_divergence( + prev_context, curr_context, logger=logger + ) + num_prefill_toks = len(prefill_toks) + logger.debug(f"Tokens to prefill: {len(prefill_toks)}") + + if num_prefill_toks > PREFILL_CHUNK_SIZE: + logger.debug( + f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" + ) + + prefill( + model, + idxs=torch.tensor( + [prefill_toks[:PREFILL_CHUNK_SIZE]], + device="cuda", + dtype=torch.int, + ), + input_pos=torch.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + device="cuda", + dtype=torch.int, + ), + ) + prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE] + + elif num_prefill_toks > 0 and full is True: + logger.debug( + f"Prefilling (force) {num_prefill_toks} tokens from idx={prefill_idx}" + ) + prefill_toks += (PREFILL_CHUNK_SIZE - len(prefill_toks)) * [ + tokenizer.pad_id + ] + prefill( + model, + idxs=torch.tensor( + [prefill_toks], device="cuda", dtype=torch.int + ), + input_pos=torch.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + device="cuda", + dtype=torch.int, + ), + ) + prev_context = curr_context + break + else: + break + + logger.info( + f"KV stored up to idx={max(0, len(prev_context)- 1)} (curr_context_len={len(curr_context)})" + ) + + return prev_context + + +def continuous_prefill( + model: TransformerLM, + msgs: list, + received_messages_queue: queue.Queue, + prev_context: list[int], +): + tokenizer = AbsTokenizer() + logger = get_logger("PREFILL") + msg_cnt = 0 + seen_sentinel = False + + while seen_sentinel is False: + while seen_sentinel is False: + try: + msg = received_messages_queue.get_nowait() + except queue.Empty: + break + else: + if msg is None: + logger.info("Seen sentinel in message received messages") + seen_sentinel = True + else: + msgs.append(msg) + msg_cnt += 1 + + if (msg_cnt >= 5 or seen_sentinel) and len(msgs) > 10: + midi = convert_msgs_to_midi(msgs=msgs) + midi_dict = MidiDict(**midi_to_dict(midi)) + curr_context = 
tokenizer.encode( + tokenizer.tokenize(midi_dict, add_dim_tok=False) + ) + prev_context = chunked_prefill( + model=model, + tokenizer=tokenizer, + prev_context=prev_context, + curr_context=curr_context, + full=False, + ) + msg_cnt = 0 + else: + time.sleep(0.01) + + return msgs, prev_context + + +def capture_and_update_kv( + model: TransformerLM, + msgs: list, + prev_context: list, + control_sentinel: threading.Event, + midi_input_port: str, + midi_capture_channel: int, + midi_control_signal: int | None = None, + midi_through_port: str | None = None, + first_msg_epoch_time_ms: int | None = None, +): + received_messages_queue = queue.Queue() + results_queue = queue.Queue() + capture_midi_thread = threading.Thread( + target=capture_midi_input, + kwargs={ + "midi_input_port": midi_input_port, + "control_sentinel": control_sentinel, + "received_messages_queue": received_messages_queue, + "midi_capture_channel": midi_capture_channel, + "midi_control_signal": midi_control_signal, + "midi_through_port": midi_through_port, + "first_msg_epoch_time_ms": first_msg_epoch_time_ms, + "results_queue": results_queue, + }, + ) + capture_midi_thread.start() + + msgs, prev_context = continuous_prefill( + model=model, + msgs=msgs, + received_messages_queue=received_messages_queue, + prev_context=prev_context, + ) + capture_midi_thread.join() + first_on_msg_epoch_ms, num_active_pitches = results_queue.get() + + return msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches + + +def capture_midi_input( + midi_input_port: str, + control_sentinel: threading.Event, + received_messages_queue: queue.Queue, + midi_capture_channel: int, + results_queue: queue.Queue, + midi_control_signal: int | None = None, + midi_through_port: str | None = None, + first_msg_epoch_time_ms: int | None = None, +): + logger = get_logger("CAPTURE") + active_pitches = set() + first_on_msg_epoch_ms = None + prev_msg_epoch_time_ms = first_msg_epoch_time_ms # + + logger.info(f"Listening on MIDI port: '{midi_input_port}'") + logger.info(f"Using MIDI control signal: {midi_control_signal}") + if midi_through_port is not None: + logger.info(f"Sending through on MIDI port: '{midi_through_port}'") + + with ExitStack() as stack: + midi_input = stack.enter_context(mido.open_input(midi_input_port)) + midi_through = ( + stack.enter_context(mido.open_output(midi_through_port)) + if midi_through_port + else None + ) + + while not control_sentinel.is_set(): + msg = midi_input.receive(block=False) + + if msg is None: + time.sleep(0.001) + continue + + if prev_msg_epoch_time_ms is None: + msg_time_ms = 0 + else: + msg_time_ms = get_epoch_time_ms() - prev_msg_epoch_time_ms + + prev_msg_epoch_time_ms = get_epoch_time_ms() + msg.time = msg_time_ms + msg.channel = midi_capture_channel + logger.info(f"Received message: [{msg}]") + + if msg.is_meta is True or msg.type == "program_change": + continue + + if ( + msg.type == "note_on" and msg.velocity == 0 + ) or msg.type == "note_off": + active_pitches.discard(msg.note) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + elif msg.type == "note_on" and msg.velocity > 0: + if first_on_msg_epoch_ms is None: + first_on_msg_epoch_ms = get_epoch_time_ms() + + active_pitches.add(msg.note) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + elif msg.type == "control_change" and msg.control == 64: + received_messages_queue.put(msg) + elif ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value > 0 + ): + 
control_sentinel.set() + + logger.info("Control signal seen") + logger.info(f"Active pitches: {active_pitches}") + num_active_pitches = len(active_pitches) + + if active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=midi_capture_channel, + time=get_epoch_time_ms() - prev_msg_epoch_time_ms, + ) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + + while active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=midi_capture_channel, + time=0, + ) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + + # Turn off pedal + msg = mido.Message( + type="control_change", + control=64, + value=0, + channel=midi_capture_channel, + time=0, + ) + received_messages_queue.put(msg) + if midi_through is not None: + midi_through.send(msg) + + received_messages_queue.put(None) # Sentinel + results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) + + +def play_midi_file(midi_port: str, midi_path: str): + logger = get_logger("FILE") + logger.info(f"Playing file at {midi_path} on MIDI port '{midi_port}'") + time.sleep(1) + active_pitches = [] + with mido.open_output(midi_port) as output_port: + for msg in mido.MidiFile(midi_path).play(): + if msg.type == "note_on" and msg.velocity > 0: + if msg.note in active_pitches: + _off_msg = copy.deepcopy(msg) + _off_msg.velocity = 0 + output_port.send(_off_msg) + else: + active_pitches.append(msg.note) + elif msg.type == "note_off" or ( + msg.type == "note_on" and msg.velocity == 0 + ): + if msg.note in active_pitches: + active_pitches.remove(msg.note) + + logger.debug(f"{msg}") + output_port.send(msg) + + +def listen_for_keypress_control_signal( + control_sentinel: threading.Event, + end_sentinel: threading.Event, +): + logger = get_logger("KEYBOARD") + while True: + time.sleep(1) + _input = input() + logger.info(f'Keypress seen "{_input}"') + control_sentinel.set() + + if _input == "e": + end_sentinel.set() + + +# TODO: Not tested +def listen_for_midi_control_signal( + midi_input_port: str, + control_sentinel: threading.Event, + end_sentinel: threading.Event, + midi_control_signal: int | None = None, + midi_end_signal: int | None = None, +): + with mido.open_input(midi_input_port) as midi_input: + while True: + msg = midi_input.receive(block=False) + if msg is None: + time.sleep(0.01) + elif ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value > 0 + ): + control_sentinel.set() + elif ( + msg.type == "control_change" + and msg.control == midi_end_signal + and msg.value > 0 + ): + control_sentinel.set() + end_sentinel.set() + + +def parse_args(): + argp = argparse.ArgumentParser() + argp.add_argument("-cp", help="path to model checkpoint") + argp.add_argument("-midi_in", required=False, help="MIDI input port") + argp.add_argument("-midi_out", required=True, help="MIDI output port") + argp.add_argument( + "-midi_through", + required=False, + help="MIDI through port for received input", + ) + argp.add_argument( + "-midi_path", + required=False, + help="Use MIDI file instead of MIDI input port", + ) + argp.add_argument( + "-midi_control_signal", + type=int, + help="MIDI control change message for AI takeover", + ) + argp.add_argument( + "-midi_end_signal", + type=int, + help="MIDI control change message to generate ending", + ) + argp.add_argument( + "-temp", + help="sampling temperature value", + type=float, + 
required=False, + default=0.95, + ) + argp.add_argument( + "-min_p", + help="sampling min_p value", + type=float, + required=False, + default=0.03, + ) + argp.add_argument( + "-cfg", + help="sampling cfg gamma value", + type=float, + required=False, + ) + argp.add_argument( + "-metadata", + nargs=2, + metavar=("KEY", "VALUE"), + action="append", + help="manually add metadata key-value pair when sampling", + ) + argp.add_argument( + "-save_path", + type=str, + required=False, + help="Path to save complete MIDI file", + ) + + return argp.parse_args() + + +# TODO: Need functionality for handing case where we run out of model context +# TODO: Make sure channel=9 (drum) case is covered +def main(): + args = parse_args() + logger = get_logger() + tokenizer = AbsTokenizer() + model = load_model(checkpoint_path=args.cp) + model = compile_model(model=model, max_seq_len=MAX_SEQ_LEN) + + assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + if args.midi_path: + midi_input_port = "Midi Through:Midi Through Port-0" + play_file_thread = threading.Thread( + target=play_midi_file, + args=(midi_input_port, args.midi_path), + daemon=True, + ) + play_file_thread.start() + else: + midi_input_port = args.midi_in + + control_sentinel = threading.Event() + end_sentinel = threading.Event() + keypress_thread = threading.Thread( + target=listen_for_keypress_control_signal, + args=[control_sentinel, end_sentinel], + daemon=True, + ) + midi_control_thread = threading.Thread( + target=listen_for_midi_control_signal, + kwargs={ + "midi_input_port": midi_input_port, + "control_sentinel": control_sentinel, + "end_sentinel": end_sentinel, + "midi_control_signal": args.midi_control_signal, + "midi_end_signal": args.midi_end_signal, + }, + daemon=True, + ) + keypress_thread.start() + midi_control_thread.start() + + msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches = ( + capture_and_update_kv( + model=model, + msgs=[], + prev_context=[], + control_sentinel=control_sentinel, + midi_input_port=midi_input_port, + midi_control_signal=args.midi_control_signal, + midi_through_port=args.midi_through, + midi_capture_channel=0, + ) + ) + + itt = 0 + while True: + control_sentinel.clear() + msgs = stream_msgs( + model=model, + tokenizer=tokenizer, + msgs=msgs, + prev_context=prev_context, + midi_output_port=args.midi_out, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + control_sentinel=control_sentinel, + temperature=args.temp, + min_p=args.min_p, + num_preceding_active_pitches=num_active_pitches, + midi_stream_channel=itt, + is_ending=False, + ) + + itt += 1 + control_sentinel.clear() + if end_sentinel.is_set(): + break + + msgs, prev_context, _, num_active_pitches = capture_and_update_kv( + model=model, + msgs=msgs, + prev_context=prev_context, + control_sentinel=control_sentinel, + midi_input_port=midi_input_port, + midi_control_signal=args.midi_control_signal, + midi_through_port=args.midi_through, + midi_capture_channel=itt, + first_msg_epoch_time_ms=first_on_msg_epoch_ms, + ) + + # TODO: There is a bug with the token somewhere? 
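+    # Final pass generates an ending: with is_ending=True, stream_msgs appends
+    # tokenizer.dim_tok to the priming sequence, and the temperature is halved
+    # for a more conservative close.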
+ msgs = stream_msgs( + model=model, + tokenizer=tokenizer, + msgs=msgs, + prev_context=prev_context, + midi_output_port=args.midi_out, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + control_sentinel=control_sentinel, + temperature=args.temp / 2, + min_p=args.min_p, + num_preceding_active_pitches=num_active_pitches, + midi_stream_channel=itt, + is_ending=True, + ) + + if args.save_path: + logger.info(f"Saving result to {args.save_path}") + midi = convert_msgs_to_midi(msgs=msgs) + midi.save(args.save_path) + + +if __name__ == "__main__": + main() diff --git a/demo/demo.sh b/demo/demo.sh new file mode 100644 index 00000000..03f2d897 --- /dev/null +++ b/demo/demo.sh @@ -0,0 +1,12 @@ +MID_PATH="/home/loubb/Dropbox/shared/demo.mid" + +python /home/loubb/work/aria/demo/demo.py \ + -cp /mnt/ssd1/aria/v2/medium-dedupe-pt-cont2/checkpoints/epoch18_step0/model.safetensors \ + -midi_path ${MID_PATH} \ + -midi_out "Midi Through:Midi Through Port-1" \ + -midi_through "Midi Through:Midi Through Port-2" \ + -save_path /home/loubb/Dropbox/shared/output.mid \ + -midi_control_signal 66 \ + -midi_end_signal 67 \ + -temp 0.98 \ + -min_p 0.02 \ No newline at end of file diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py new file mode 100644 index 00000000..5f348201 --- /dev/null +++ b/demo/demo_mlx.py @@ -0,0 +1,1641 @@ +#!/usr/bin/env python3 + +import argparse +import os +import time +import uuid +import copy +import random +import logging +import threading +import queue +import copy +import mido +import torch + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from ariautils.midi import MidiDict, midi_to_dict +from ariautils.tokenizer import AbsTokenizer +from aria.inference.model_mlx import TransformerLM +from aria.model import ModelConfig +from aria.config import load_model_config + +DTYPE = mx.float32 +MAX_SEQ_LEN = 2048 +PREFILL_CHUNK_SIZE_L = 128 +PREFILL_CHUNK_SIZE = 16 +RECALC_DUR_PREFILL_CHUNK_SIZE = 8 +RECALC_DUR_BUFFER_MS = 100 + +BEAM_WIDTH = 3 +TIME_TOK_WEIGHTING = -5 +FIRST_ONSET_BUFFER_MS = -100 # Controls onset timing for first generated note + +# HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS +# HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg +# HARDWARE: All messages are sent HARDWARE_OUTPUT_LATENCY_MS early + +# C4DM Disklavier: +# MIN_NOTE_DELTA_MS = 40 +# MIN_NOTE_LEN_MS = 50 +# HARDWARE_INPUT_LATENCY_MS = 50 +# HARDWARE_OUTPUT_LATENCY_MS = 150 + +# Pianoteq +MIN_NOTE_DELTA_MS = 0 +MIN_NOTE_LEN_MS = 0 +HARDWARE_INPUT_LATENCY_MS = 0 +HARDWARE_OUTPUT_LATENCY_MS = 0 + +MAX_STREAM_DELAY_MS = 25 + +file_handler = logging.FileHandler("./demo.log", mode="w") +file_handler.setLevel(logging.DEBUG) + + +def get_logger(name: str | None = None) -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + logger.propagate = False + logger.setLevel(logging.DEBUG) + + class MillisecondFormatter(logging.Formatter): + def formatTime(self, record, datefmt=None): + created_ms = int(record.created * 1000) + return str(created_ms) + + if name is not None: + formatter = MillisecondFormatter( + "%(asctime)s: [%(levelname)s] [%(name)s] %(message)s" + ) + else: + formatter = MillisecondFormatter( + "%(asctime)s: [%(levelname)s] %(message)s" + ) + + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + logger.addHandler(ch) + + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +def get_epoch_time_ms() -> int: + return round(time.time() * 1000) + + 
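+# The two helpers below wrap the KV-cached model: prefill() feeds a chunk of
+# token ids at explicit positions and returns logits for every position in the
+# chunk, while decode_one() feeds a single id and returns only the logits for
+# the next position. As an illustrative sketch only (prompt_ids is a
+# placeholder), a plain greedy loop over these helpers would look roughly like:
+#
+#   logits = prefill(model, idxs=mx.array([prompt_ids], dtype=mx.int32),
+#                    input_pos=mx.arange(len(prompt_ids), dtype=mx.int32))
+#   next_id = mx.argmax(logits[:, -1], axis=-1).item()
+#   for pos in range(len(prompt_ids), MAX_SEQ_LEN):
+#       logits = decode_one(model, idxs=mx.array([[next_id]], dtype=mx.int32),
+#                           input_pos=mx.array([pos], dtype=mx.int32))
+#       next_id = mx.argmax(logits, axis=-1).item()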
+def prefill( + model: TransformerLM, + idxs: mx.array, + input_pos: mx.array, + pad_idxs: mx.array | None = None, +) -> mx.array: + # pad_idxs is only needed for prepended pad tokens + logits = model( + idxs=idxs, + input_pos=input_pos, + offset=input_pos[0], + pad_idxs=pad_idxs, + ) + + return logits + + +def decode_one( + model: TransformerLM, + idxs: mx.array, + input_pos: mx.array, + pad_idxs: mx.array | None = None, +) -> mx.array: + # pad_idxs is only needed for prepended pad tokens + assert input_pos.shape[-1] == 1 + + logits = model( + idxs=idxs, + input_pos=input_pos, + offset=input_pos[0], + pad_idxs=pad_idxs, + )[:, -1] + + return logits + + +def sample_min_p(probs: mx.array, p_base: float): + """See - https://arxiv.org/pdf/2407.01082""" + + p_max = mx.max(probs, axis=-1, keepdims=True) + p_scaled = p_base * p_max + mask = probs >= p_scaled + + masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) + sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) + masked_probs_normalized = masked_probs / sum_masked_probs + + # Dumb workaround for mlx not having categorical probs sampler + next_token = mx.array( + torch.multinomial( + torch.from_numpy(np.array(masked_probs_normalized)), num_samples=1 + ), + dtype=mx.int32, + ) + + return next_token + + +def _compile_prefill( + model: TransformerLM, + logger: logging.Logger, + chunk_size: int, +): + assert chunk_size > 1 + + compile_start_time_s = time.time() + logger.info(f"Compiling prefill (chunk_size={chunk_size})") + for _ in range(5): + mx.eval( + prefill( + model, + idxs=mx.ones([1, chunk_size], dtype=mx.int32), + input_pos=mx.arange( + MAX_SEQ_LEN - (chunk_size + 1), + MAX_SEQ_LEN - 1, + dtype=mx.int32, + ), + ) + ) + + logger.info( + f"Finished compiling - took {time.time() - compile_start_time_s:.4f} seconds" + ) + + bench_start_time_s = time.time() + mx.eval( + prefill( + model, + idxs=mx.ones([1, chunk_size], dtype=mx.int32), + input_pos=mx.arange(0, chunk_size, dtype=mx.int32), + ) + ) + bench_end_time_s = time.time() + bench_ms = 1e3 * (bench_end_time_s - bench_start_time_s) + bench_its = 1000 / bench_ms + logger.info( + f"Compiled prefill benchmark: {bench_ms:.2f} ms/it ({bench_its:.2f} it/s)" + ) + + return model + + +def _compile_decode_one( + model: TransformerLM, + logger: logging.Logger, +): + # Don't need to explicitly compile with mlx, instead we are just precalculating + # the computation graphs for different shapes + compile_start_time_s = time.time() + for _ in range(5): + mx.eval( + decode_one( + model, + idxs=mx.array([[random.randint(0, 20)]], dtype=mx.int32), + input_pos=mx.array([MAX_SEQ_LEN - 1], dtype=mx.int32), + ), + ) + logger.info( + f"Finished compiling - took {time.time() - compile_start_time_s:.4f} seconds" + ) + + bench_start_time_s = time.time() + mx.eval( + decode_one( + model, + idxs=mx.array([[0]], dtype=mx.int32), + input_pos=mx.array([0], dtype=mx.int32), + ) + ) + bench_end_time_s = time.time() + bench_ms = 1e3 * (bench_end_time_s - bench_start_time_s) + bench_its = 1000 / bench_ms + logger.info( + f"Compiled decode_one benchmark: {bench_ms:.2f} ms/it ({bench_its:.2f} it/s)" + ) + + return model + + +def compile_model(model: TransformerLM): + logger = get_logger() + + model.eval() + model.setup_cache( + batch_size=1, + max_seq_len=MAX_SEQ_LEN, + dtype=DTYPE, + ) + + model = _compile_decode_one(model=model, logger=logger) + for chunk_size in list( + { + PREFILL_CHUNK_SIZE_L, + PREFILL_CHUNK_SIZE, + RECALC_DUR_PREFILL_CHUNK_SIZE, + } + ): + model = _compile_prefill( + 
model=model, logger=logger, chunk_size=chunk_size + ) + + return model + + +def load_model( + checkpoint_path: str, +): + logger = get_logger() + + tokenizer = AbsTokenizer() + model_config = ModelConfig(**load_model_config("medium-emb")) + model_config.set_vocab_size(tokenizer.vocab_size) + + logging.info(f"Loading model weights from {checkpoint_path}") + + init_start_time_s = time.time() + model = TransformerLM(model_config) + model.load_weights(checkpoint_path, strict=False) + nn.quantize(model.model, group_size=64, bits=8) + model.eval() + + logger.info( + f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" + ) + + return model + + +def _first_bad_dur_index( + tokenizer: AbsTokenizer, + priming_seq: list, + pred_ids: list, + chunk_start: int, + last_offset_ms: int, + logger: logging.Logger, +): + num_time_toks = priming_seq[:chunk_start].count(tokenizer.time_tok) + local_onset_ms = tokenizer.calc_length_ms( + priming_seq[: chunk_start + 1], onset=True + ) # chunk_start + 1 to account for possibly truncated dur token + logger.debug(f"Starting from local onset {local_onset_ms}") + + for pos, tok_id in enumerate( + pred_ids[: len(priming_seq) - chunk_start], start=chunk_start + ): + prim_tok = priming_seq[pos] # Should never error? + pred_tok = tokenizer.id_to_tok[tok_id] + logger.debug(f"prim={prim_tok}, pred={pred_tok}") + + if isinstance(prim_tok, tuple) and prim_tok[0] == "onset": + local_onset_ms = num_time_toks * 5000 + prim_tok[1] + elif prim_tok == tokenizer.time_tok: + num_time_toks += 1 + elif isinstance(prim_tok, tuple) and prim_tok[0] == "dur": + dur_true = prim_tok[1] + dur_pred = pred_tok[1] + if dur_pred > dur_true and ( + local_onset_ms + dur_true + >= last_offset_ms - RECALC_DUR_BUFFER_MS + ): + logger.info( + f"Found token to resample at {pos}: {prim_tok} -> {pred_tok}" + ) + return pos + + return None + + +def recalc_dur_tokens_chunked( + model: TransformerLM, + priming_seq: list, + enc_seq: mx.array, + tokenizer: AbsTokenizer, + start_idx: int, +): + # Speculative-decoding inspired duration re-calculation + assert start_idx > 0 + logger = get_logger("GENERATE") + + priming_len = len(priming_seq) + last_offset = tokenizer.calc_length_ms(priming_seq, onset=False) + logger.debug( + f"Using threshold for duration recalculation: {last_offset - RECALC_DUR_BUFFER_MS}" + ) + + idx = start_idx + while idx <= priming_len: + end_idx = idx + RECALC_DUR_PREFILL_CHUNK_SIZE + + window_ids = mx.array( + enc_seq[:, idx - 1 : end_idx - 1].tolist(), + dtype=mx.int32, + ) + window_pos = mx.arange(idx - 1, end_idx - 1, dtype=mx.int32) + + logger.info( + f"Recalculating chunked durations for positions: {idx-1} - {end_idx-2}" + ) + + logits = prefill(model, idxs=window_ids, input_pos=window_pos) + pred_ids = mx.argmax(logits, axis=-1).flatten().tolist() + + logger.debug(f"Inserted: {tokenizer.decode(window_ids[0].tolist())}") + logger.debug(f"Positions: {window_pos.tolist()}") + logger.debug(f"Predictions: {tokenizer.decode(pred_ids)}") + + bad_pos = _first_bad_dur_index( + tokenizer=tokenizer, + priming_seq=priming_seq, + pred_ids=pred_ids, + chunk_start=idx, + last_offset_ms=last_offset, + logger=logger, + ) + + if bad_pos is None: + idx = end_idx + else: + new_id = pred_ids[bad_pos - idx] + enc_seq[0, bad_pos] = new_id + priming_seq[bad_pos] = tokenizer.id_to_tok[new_id] + idx = bad_pos + 1 + + next_logits = logits[:, priming_len - idx] + + return enc_seq, priming_seq, next_logits + + +def decode_first_tokens( + model: TransformerLM, + first_token_logits: 
mx.array, + enc_seq: mx.array, + priming_seq: list, + tokenizer: AbsTokenizer, + generated_tokens_queue: queue.Queue, + first_on_msg_epoch_ms: int, +): + logger = get_logger("GENERATE") + + # buffer_ms determines how far in the past to start generating notes + buffer_ms = FIRST_ONSET_BUFFER_MS - HARDWARE_OUTPUT_LATENCY_MS + time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] + eos_tok_id = tokenizer.tok_to_id[tokenizer.eos_tok] + dim_tok_id = tokenizer.tok_to_id[tokenizer.dim_tok] + + logits = first_token_logits + time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms + idx = len(priming_seq) + 1 + + num_time_toks_required = (time_since_first_onset_ms + buffer_ms) // 5000 + num_time_toks_in_priming_seq = priming_seq.count(tokenizer.time_tok) + num_time_toks_to_add = num_time_toks_required - num_time_toks_in_priming_seq + + logger.info(f"Time since first onset: {time_since_first_onset_ms}ms") + logger.info(f"Using first note-onset buffer: {buffer_ms}ms") + + while num_time_toks_to_add > 0: + generated_tokens_queue.put(tokenizer.time_tok) + logits = decode_one( + model, + idxs=mx.array([[time_tok_id]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) + + logger.info(f"Inserted time_tok at position {idx-1}") + num_time_toks_to_add -= 1 + enc_seq[:, idx - 1] = time_tok_id + idx += 1 + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + + # MLX doesn't have a equivalent of torch topk + log_probs = nn.log_softmax(logits, axis=-1) + top_ids = mx.argsort(log_probs, axis=-1)[0, -BEAM_WIDTH:] + top_log_probs = log_probs[0, top_ids] + + # top_log_probs are sorted in ascending order + if time_tok_id not in top_ids.tolist(): + top_ids[0] = time_tok_id + top_log_probs[0] = log_probs[0, time_tok_id] + + _time_tok_idx = top_ids.tolist().index(time_tok_id) + top_log_probs[_time_tok_idx] += TIME_TOK_WEIGHTING + + top_toks = [tokenizer.id_to_tok[id] for id in top_ids.tolist()] + + logger.debug(f"Calculated top {BEAM_WIDTH} tokens={top_toks}") + logger.debug(f"Calculated top {BEAM_WIDTH} scores={top_log_probs.tolist()}") + + masked_onset_ids = [ + tokenizer.tok_to_id[tok] + for tok in tokenizer.onset_tokens + if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) + ] + + logger.debug( + f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + buffer_ms})" + ) + + best_score = float("-inf") + for i in range(BEAM_WIDTH): + tok = top_toks[i] + tok_id = top_ids[i].item() + tok_log_prob = top_log_probs[i] + + next_logits = decode_one( + model, + idxs=mx.array([[tok_id]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) + logger.debug( + f"Sampled logits for positions {idx} by inserting {tok} at position {idx-1}" + ) + + next_log_probs = nn.log_softmax(next_logits, axis=-1) + next_log_probs[:, masked_onset_ids] = float("-inf") + next_log_probs[:, eos_tok_id] = float("-inf") + next_log_probs[:, dim_tok_id] = float("-inf") + if tok_id == time_tok_id: + next_log_probs[:, time_tok_id] = float("-inf") + + next_tok_log_prob = mx.max(next_log_probs, axis=-1) + next_tok_id = mx.argmax(next_log_probs, axis=-1) + next_tok = tokenizer.id_to_tok[next_tok_id.item()] + score = tok_log_prob + next_tok_log_prob + + logger.info( + f"Calculated tuple {(tok, next_tok)} with scores {(tok_log_prob.item(), next_tok_log_prob.item())} (combined={score.item()})" + ) + + if score > best_score: + best_tok_id_1, best_tok_id_2 = tok_id, next_tok_id.item() + best_tok_1, 
best_tok_2 = ( + tokenizer.id_to_tok[best_tok_id_1], + tokenizer.id_to_tok[best_tok_id_2], + ) + best_score = score + + logger.info( + f"Chose tuple {(best_tok_1, best_tok_2)} with score {best_score.item()}" + ) + + enc_seq[:, idx - 1] = best_tok_id_1 + enc_seq[:, idx] = best_tok_id_2 + generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_1]) + generated_tokens_queue.put(tokenizer.id_to_tok[best_tok_id_2]) + + mx.eval( + decode_one( + model, + idxs=mx.array([[best_tok_id_1]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) + ) + + logger.info( + f"Updated KV-Cache by re-inserting {best_tok_1} at position {idx-1}" + ) + logger.info( + f"Inserted {best_tok_2} at position {idx} without updating KV-Cache" + ) + + return enc_seq, idx + 1 + + +def decode_tokens( + model: TransformerLM, + enc_seq: mx.array, + tokenizer: AbsTokenizer, + control_sentinel: threading.Event, + generated_tokens_queue: queue.Queue, + idx: int, + temperature: float, + min_p: float, + is_ending: bool, +): + logger = get_logger("GENERATE") + logger.info( + f"Using sampling parameters: temperature={temperature}, min_p={min_p}" + ) + + if control_sentinel.is_set(): + control_sentinel.clear() + + while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: + decode_one_start_time_s = time.time() + prev_tok_id = enc_seq[0, idx - 1] + prev_tok = tokenizer.id_to_tok[prev_tok_id.item()] + + logits = decode_one( + model, + idxs=mx.array([[prev_tok_id]], dtype=mx.int32), + input_pos=mx.array([idx - 1], dtype=mx.int32), + ) + + logger.debug( + f"Sampled logits for positions {idx} by inserting {prev_tok} at position {idx-1}" + ) + + logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + if is_ending is False: + logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + + for dur_ms in range(0, MIN_NOTE_LEN_MS, 10): + logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") + + if temperature > 0.0: + probs = mx.softmax(logits / temperature, axis=-1) + next_token_ids = sample_min_p(probs, min_p).flatten() + else: + next_token_ids = mx.argmax(logits, axis=-1).flatten() + + enc_seq[:, idx] = next_token_ids + next_token = tokenizer.id_to_tok[next_token_ids[0].item()] + logger.debug( + f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" + ) + + if next_token == tokenizer.eos_tok: + logger.info("EOS token produced") + generated_tokens_queue.put(next_token) + return + else: + generated_tokens_queue.put(next_token) + idx += 1 + + logger.info(f"Finished generating: {idx}") + generated_tokens_queue.put(None) + + +def generate_tokens( + priming_seq: list, + tokenizer: AbsTokenizer, + model: TransformerLM, + prev_context: list[int], + control_sentinel: threading.Event, + generated_tokens_queue: queue.Queue, + num_preceding_active_pitches: int, + first_on_msg_epoch_ms: int, + temperature: float = 0.98, + min_p: float = 0.03, + is_ending: bool = False, +): + logger = get_logger("GENERATE") + + generate_start_s = time.time() + priming_seq_len = len(priming_seq) + + start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) + enc_seq = mx.array( + [ + tokenizer.encode( + priming_seq + + [tokenizer.pad_tok] * (MAX_SEQ_LEN - len(priming_seq)) + ) + ], + dtype=mx.int32, + ) + + logger.debug(f"Priming sequence {priming_seq}") + logger.info(f"Priming sequence length: {priming_seq_len}") + logger.info(f"Prefilling up to (and including) position: {start_idx-1}") + + prefill_start_s = time.time() + chunked_prefill( + model=model, + tokenizer=tokenizer, + 
prev_context=prev_context, + curr_context=enc_seq[0, :start_idx].tolist(), + full=True, + ) + + logger.info( + f"Prefill took {(time.time() - prefill_start_s) * 1000:.2f} milliseconds" + ) + logger.info(f"Starting duration recalculation from position: {start_idx-1}") + + recalculate_dur_start_s = time.time() + enc_seq, priming_seq, next_token_logits = recalc_dur_tokens_chunked( + model=model, + priming_seq=priming_seq, + enc_seq=enc_seq, + tokenizer=tokenizer, + start_idx=start_idx, + ) + + logger.info( + f"Recalculating durations took {(time.time() - recalculate_dur_start_s) * 1000:.2f} milliseconds" + ) + + decode_first_s = time.time() + enc_seq, idx = decode_first_tokens( + model=model, + first_token_logits=next_token_logits, + enc_seq=enc_seq, + priming_seq=priming_seq, + tokenizer=tokenizer, + generated_tokens_queue=generated_tokens_queue, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + ) + + logger.info( + f"Decode first two tokens took {(time.time() - decode_first_s) * 1000:.2f} milliseconds" + ) + logger.info( + f"Time to first token took {(time.time() - generate_start_s) * 1000:.2f} milliseconds" + ) + + decode_tokens( + model=model, + enc_seq=enc_seq, + tokenizer=tokenizer, + control_sentinel=control_sentinel, + generated_tokens_queue=generated_tokens_queue, + idx=idx, + temperature=temperature, + min_p=min_p, + is_ending=is_ending, + ) + + +def decode_tokens_to_midi( + generated_tokens_queue: queue.Queue, + outbound_midi_msg_queue: queue.Queue, + tokenizer: AbsTokenizer, + first_on_msg_epoch_ms: int, + priming_seq_last_onset_ms: int, +): + logger = get_logger("DECODE") + + assert ( + first_on_msg_epoch_ms + priming_seq_last_onset_ms + < get_epoch_time_ms() + HARDWARE_INPUT_LATENCY_MS + ) + + logger.info(f"Priming sequence last onset: {priming_seq_last_onset_ms}") + logger.info( + f"Total time elapsed since first onset: {get_epoch_time_ms() - first_on_msg_epoch_ms}" + ) + + pitch_to_prev_msg = {} + note_buffer = [] + num_time_toks = priming_seq_last_onset_ms // 5000 + + while True: + while True: + tok = generated_tokens_queue.get() + if tok is tokenizer.eos_tok: + _uuid = uuid.uuid4() + end_msg = { + "pitch": -1, + "vel": -1, + "epoch_time_ms": offset_epoch_ms + 100, # Last note offset + "uuid": _uuid, + } # pitch=-1 denotes end_msg + outbound_midi_msg_queue.put(end_msg) + logger.info(f"Seen exit signal: EOS token") + logger.debug(f"Put message: {end_msg}") + return + + elif tok is None: + logger.info(f"Seen exit signal: Sentinel") + return + + logger.debug(f"Seen token: {tok}") + note_buffer.append(tok) + + if isinstance(tok, tuple) and tok[0] == "dur": + break + + while note_buffer and note_buffer[0] == tokenizer.time_tok: + logger.debug("Popping time_tok") + num_time_toks += 1 + note_buffer.pop(0) + + assert len(note_buffer) == 3 + logger.debug(f"Decoded note: {note_buffer}") + note_tok, onset_tok, dur_tok = note_buffer + _, pitch, vel = note_tok + _, onset = onset_tok + _, dur = dur_tok + + _uuid = uuid.uuid4() + onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset + offset_epoch_ms = onset_epoch_ms + dur + on_msg = { + "pitch": pitch, + "vel": vel, + "epoch_time_ms": onset_epoch_ms, + "uuid": _uuid, + } + off_msg = { + "pitch": pitch, + "vel": 0, + "epoch_time_ms": offset_epoch_ms, + "uuid": _uuid, + } + + # Not thread safe but in theory should be ok? 
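+        # If the same pitch is retriggered within MIN_NOTE_DELTA_MS, pull the
+        # previous note_off earlier (never before its own note_on) so the
+        # instrument has time to release the key before it is struck again.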
+ if pitch_to_prev_msg.get(pitch) is not None and MIN_NOTE_DELTA_MS > 0: + prev_on, prev_off = pitch_to_prev_msg.get(pitch) + adj_off_time = max( + min( + prev_off["epoch_time_ms"], + onset_epoch_ms - MIN_NOTE_DELTA_MS, + ), + prev_on["epoch_time_ms"], + ) + if adj_off_time != prev_off["epoch_time_ms"]: + logger.debug(f"Adjusting {prev_off}: t={adj_off_time}") + prev_off["epoch_time_ms"] = adj_off_time + prev_off["adjusted"] = True + + pitch_to_prev_msg[pitch] = [on_msg, off_msg] + + outbound_midi_msg_queue.put(on_msg) + outbound_midi_msg_queue.put(off_msg) + logger.debug(f"Put message: {on_msg}") + logger.debug(f"Put message: {off_msg}") + logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") + + note_buffer = [] + + +def stream_midi( + inbound_midi_msg_queue: queue.Queue, + msgs: list[mido.Message], + last_channel_msg_epoch_time_ms: float, + midi_output_port: str, + control_sentinel: threading.Event, + midi_stream_channel: int, + results_queue: queue.Queue, +): + logger = get_logger("STREAM") + logger.info( + f"Sending generated messages on MIDI port: '{midi_output_port}'" + ) + logger.info( + f"Applying hardware output latency adjustment: {HARDWARE_OUTPUT_LATENCY_MS}ms" + ) + + active_pitch_uuid = {} + is_pitch_active = {} + midi_msgs = [] + + with mido.open_output(midi_output_port) as midi_out: + while not control_sentinel.is_set(): + while True: + try: + msg = inbound_midi_msg_queue.get_nowait() + except queue.Empty: + break + else: + logger.debug(f"Received message: {msg}") + midi_msgs.append(msg) + + midi_msgs = sorted( + midi_msgs, + key=lambda msg: ( + msg["epoch_time_ms"], + msg["vel"], + ), + ) + + if control_sentinel.is_set(): + break + + while midi_msgs: + # Messages are sent HARDWARE_OUTPUT_LATENCY_MS early + latency_adjusted_epoch_time_ms = ( + get_epoch_time_ms() + HARDWARE_OUTPUT_LATENCY_MS + ) + msg = midi_msgs[0] + + if ( + 0 + < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] + <= MAX_STREAM_DELAY_MS + ): + if msg["pitch"] == -1: # End msg + control_sentinel.set() + break + + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=msg["vel"], + channel=0, + time=0, + ) + + if msg["vel"] > 0: + active_pitch_uuid[msg["pitch"]] = msg["uuid"] + should_send_midi_out = True + should_append_to_msgs = True + elif msg.get("adjusted", False) is True: + should_send_midi_out = True + should_append_to_msgs = False + else: + should_send_midi_out = ( + active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] + ) + should_append_to_msgs = should_send_midi_out + + if should_send_midi_out is True: + midi_out.send(mido_msg) + is_pitch_active[msg["pitch"]] = msg["vel"] != 0 + logger.info(f"Sent message: {mido_msg}") + if should_append_to_msgs is True: + mido_msg_with_time = copy.deepcopy(mido_msg) + mido_msg_with_time.channel = midi_stream_channel + mido_msg_with_time.time = max( + 0, + msg["epoch_time_ms"] + - last_channel_msg_epoch_time_ms, + ) + last_channel_msg_epoch_time_ms = msg["epoch_time_ms"] + msgs.append(mido_msg_with_time) + + midi_msgs.pop(0) + + elif ( + latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] + > MAX_STREAM_DELAY_MS + ): + # Message occurs too far in the past + logger.debug( + f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg['epoch_time_ms']}ms) in the past: {msg}" + ) + midi_msgs.pop(0) + else: + # Message occurs in the future + break + + time.sleep(0.005) + + remaining_note_off_messages = [ + msg + for msg in midi_msgs + if msg["vel"] == 0 + and active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] + ] + + 
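+        # On shutdown, flush note_off messages for notes whose uuid still owns
+        # their pitch, so no generated notes are left sounding once streaming
+        # stops.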
logger.info("Processing remaining note_off messages") + for msg in remaining_note_off_messages: + mido_msg = mido.Message( + "note_on", + note=msg["pitch"], + velocity=0, + channel=midi_stream_channel, + time=msg["epoch_time_ms"] - last_channel_msg_epoch_time_ms, + ) + midi_out.send(mido_msg) + last_channel_msg_epoch_time_ms = msg["epoch_time_ms"] + msgs.append(mido_msg) + + results_queue.put(msgs) + + +def stream_msgs( + model: TransformerLM, + tokenizer: AbsTokenizer, + msgs: list[mido.Message], + prev_context: list[int], + midi_output_port: str, + first_on_msg_epoch_ms: int, + control_sentinel: threading.Event, + temperature: float, + min_p: float, + num_preceding_active_pitches: int, + midi_stream_channel: int, + is_ending: bool = False, +): + midi = convert_msgs_to_midi(msgs=msgs) + midi_dict = MidiDict(**midi_to_dict(midi)) + priming_seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) + priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] + + if is_ending is True: + priming_seq.append(tokenizer.dim_tok) + + generated_tokens_queue = queue.Queue() + midi_messages_queue = queue.Queue() + + generate_tokens_thread = threading.Thread( + target=generate_tokens, + kwargs={ + "priming_seq": priming_seq, + "tokenizer": tokenizer, + "model": model, + "prev_context": prev_context, + "control_sentinel": control_sentinel, + "generated_tokens_queue": generated_tokens_queue, + "temperature": temperature, + "min_p": min_p, + "num_preceding_active_pitches": num_preceding_active_pitches, + "first_on_msg_epoch_ms": first_on_msg_epoch_ms, + "is_ending": is_ending, + }, + ) + generate_tokens_thread.start() + + decode_tokens_to_midi_thread = threading.Thread( + target=decode_tokens_to_midi, + kwargs={ + "generated_tokens_queue": generated_tokens_queue, + "outbound_midi_msg_queue": midi_messages_queue, + "tokenizer": tokenizer, + "first_on_msg_epoch_ms": first_on_msg_epoch_ms, + "priming_seq_last_onset_ms": tokenizer.calc_length_ms( + priming_seq, onset=True + ), + }, + ) + decode_tokens_to_midi_thread.start() + + # If ending==True then previous MIDI message on midi_stream_channel occurs + # at first_on_msg_epoch_ms. 
+ prev_channel_msg_epoch_time_ms = ( + first_on_msg_epoch_ms + + tokenizer.calc_length_ms(priming_seq, onset=False) + if is_ending is False + else first_on_msg_epoch_ms + ) + + stream_midi_results_queue = queue.Queue() + stream_midi_thread = threading.Thread( + target=stream_midi, + kwargs={ + "inbound_midi_msg_queue": midi_messages_queue, + "msgs": msgs, + "last_channel_msg_epoch_time_ms": prev_channel_msg_epoch_time_ms, + "midi_output_port": midi_output_port, + "control_sentinel": control_sentinel, + "midi_stream_channel": midi_stream_channel, + "results_queue": stream_midi_results_queue, + }, + daemon=True, + ) + stream_midi_thread.start() + + generate_tokens_thread.join() + decode_tokens_to_midi_thread.join() + msgs = stream_midi_results_queue.get() + + if is_ending is True: + stream_midi_thread.join() + + return msgs + + +def convert_msgs_to_midi(msgs: list[mido.Message]): + channel_to_track = { + chan: mido.MidiTrack() + for chan in list(set([msg.channel for msg in msgs])) + } + + for msg in msgs: + channel_to_track[msg.channel].append(msg) + + # Workaround for possibility that track_0 start time != first_on_msg_epoch_ms + for msg in channel_to_track[0]: + if msg.type == "note_on" and msg.velocity > 0: + msg.time = 0 + break + else: + msg.time = 0 + + mid = mido.MidiFile(type=1) + mid.ticks_per_beat = 500 + + for channel, track in channel_to_track.items(): + track.insert(0, mido.MetaMessage("set_tempo", tempo=500000, time=0)) + track.insert( + 0, + mido.Message("program_change", program=0, channel=channel, time=0), + ) + mid.tracks.append(track) + + return mid + + +def _find_divergence( + prev_context: list, + curr_context: list, + logger: logging.Logger, + tokenizer: AbsTokenizer, +): + agreement_index = 0 + for prev_val, curr_val in zip(prev_context, curr_context): + if prev_val == curr_val: + agreement_index += 1 + else: + logger.info( + f"Found divergence at idx {agreement_index}: {tokenizer.id_to_tok[curr_val]}, {tokenizer.id_to_tok[prev_val]}" + ) + break + + return agreement_index, curr_context[agreement_index:] + + +def chunked_prefill( + model: TransformerLM, + tokenizer: AbsTokenizer, + prev_context: list, + curr_context: list, + full: bool = False, +): + + assert isinstance(curr_context[0], int) + assert tokenizer.pad_id not in prev_context + assert tokenizer.pad_id not in curr_context + + logger = get_logger("PREFILL") + + while True: + prefill_idx, prefill_toks = _find_divergence( + prev_context, + curr_context, + logger=logger, + tokenizer=tokenizer, + ) + num_prefill_toks = len(prefill_toks) + logger.debug(f"Tokens to prefill: {len(prefill_toks)}") + + if num_prefill_toks > PREFILL_CHUNK_SIZE_L: + logger.debug( + f"Prefilling {PREFILL_CHUNK_SIZE_L} tokens from idx={prefill_idx}" + ) + mx.eval( + prefill( + model, + idxs=mx.array( + [prefill_toks[:PREFILL_CHUNK_SIZE_L]], + dtype=mx.int32, + ), + input_pos=mx.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE_L, + dtype=mx.int32, + ), + ) + ) + prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE_L] + + elif num_prefill_toks > PREFILL_CHUNK_SIZE: + logger.debug( + f"Prefilling {PREFILL_CHUNK_SIZE} tokens from idx={prefill_idx}" + ) + mx.eval( + prefill( + model, + idxs=mx.array( + [prefill_toks[:PREFILL_CHUNK_SIZE]], + dtype=mx.int32, + ), + input_pos=mx.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + dtype=mx.int32, + ), + ) + ) + prev_context = curr_context[: prefill_idx + PREFILL_CHUNK_SIZE] + + elif num_prefill_toks > 0 and full is True: + logger.debug( + f"Prefilling (force) 
{num_prefill_toks} tokens from idx={prefill_idx}" + ) + prefill_toks += (PREFILL_CHUNK_SIZE - len(prefill_toks)) * [ + tokenizer.pad_id + ] + mx.eval( + prefill( + model, + idxs=mx.array([prefill_toks], dtype=mx.int32), + input_pos=mx.arange( + prefill_idx, + prefill_idx + PREFILL_CHUNK_SIZE, + dtype=mx.int32, + ), + ) + ) + prev_context = curr_context + break + else: + break + + logger.info( + f"KV stored up to idx={max(0, len(prev_context)- 1)} (curr_context_len={len(curr_context)})" + ) + + return prev_context + + +def continuous_prefill( + model: TransformerLM, + msgs: list, + received_messages_queue: queue.Queue, + prev_context: list[int], +): + tokenizer = AbsTokenizer() + logger = get_logger("PREFILL") + msg_cnt = 0 + seen_sentinel = False + + while seen_sentinel is False: + while seen_sentinel is False: + try: + msg = received_messages_queue.get_nowait() + except queue.Empty: + break + else: + if msg is None: + logger.info("Seen sentinel in message received messages") + seen_sentinel = True + else: + msgs.append(msg) + msg_cnt += 1 + + if msg_cnt >= 10 or seen_sentinel: + midi = convert_msgs_to_midi(msgs=msgs) + midi_dict = MidiDict(**midi_to_dict(midi)) + + if len(midi_dict.note_msgs) > 0: + curr_context = tokenizer.encode( + tokenizer.tokenize(midi_dict, add_dim_tok=False) + ) + prev_context = chunked_prefill( + model=model, + tokenizer=tokenizer, + prev_context=prev_context, + curr_context=curr_context, + full=False, + ) + + msg_cnt = 0 + else: + time.sleep(0.01) + + return msgs, prev_context + + +def capture_and_update_kv( + model: TransformerLM, + msgs: list, + prev_context: list, + control_sentinel: threading.Event, + wait_for_close: bool, + midi_input_port: str, + midi_capture_channel: int, + midi_control_signal: int | None = None, + first_msg_epoch_time_ms: int | None = None, +): + received_messages_queue = queue.Queue() + results_queue = queue.Queue() + capture_midi_thread = threading.Thread( + target=capture_midi_input, + kwargs={ + "midi_input_port": midi_input_port, + "control_sentinel": control_sentinel, + "received_messages_queue": received_messages_queue, + "midi_capture_channel": midi_capture_channel, + "midi_control_signal": midi_control_signal, + "first_msg_epoch_time_ms": first_msg_epoch_time_ms, + "results_queue": results_queue, + "wait_for_close": wait_for_close, + }, + ) + capture_midi_thread.start() + + msgs, prev_context = continuous_prefill( + model=model, + msgs=msgs, + received_messages_queue=received_messages_queue, + prev_context=prev_context, + ) + capture_midi_thread.join() + first_on_msg_epoch_ms, num_active_pitches = results_queue.get() + + return msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches + + +def capture_midi_input( + midi_input_port: str, + control_sentinel: threading.Event, + received_messages_queue: queue.Queue, + midi_capture_channel: int, + results_queue: queue.Queue, + midi_control_signal: int | None = None, + first_msg_epoch_time_ms: int | None = None, + wait_for_close: bool = False, +): + logger = get_logger("CAPTURE") + active_pitches = set() + first_on_msg_epoch_ms = None + prev_msg_epoch_time_ms = first_msg_epoch_time_ms + + logger.info(f"Listening on MIDI port: '{midi_input_port}'") + logger.info(f"Ready to capture MIDI events") + + # Clear undesired buffered notes + with mido.open_input(midi_input_port) as midi_input: + while True: + msg = midi_input.receive(block=False) + if msg is None: + break + + with mido.open_input(midi_input_port) as midi_input: + if midi_control_signal is not None: + logger.info( + 
f"Commencing generation upon keypress or MIDI control: {midi_control_signal}" + ) + else: + logger.info(f"Commencing generation upon keypress") + + while not control_sentinel.is_set() or ( + wait_for_close and active_pitches + ): + msg = midi_input.receive(block=False) + + if msg is None: + time.sleep(0.001) + continue + + if prev_msg_epoch_time_ms is None: + msg_time_ms = 0 + else: + msg_time_ms = get_epoch_time_ms() - prev_msg_epoch_time_ms + + prev_msg_epoch_time_ms = get_epoch_time_ms() + msg.time = msg_time_ms + msg.channel = midi_capture_channel + logger.info(f"Received message: [{msg}]") + + if msg.is_meta is True or msg.type == "program_change": + continue + + if ( + msg.type == "note_on" and msg.velocity == 0 + ) or msg.type == "note_off": + active_pitches.discard(msg.note) + received_messages_queue.put(msg) + elif msg.type == "note_on" and msg.velocity > 0: + if first_on_msg_epoch_ms is None: + first_on_msg_epoch_ms = ( + get_epoch_time_ms() - HARDWARE_INPUT_LATENCY_MS + ) + + active_pitches.add(msg.note) + received_messages_queue.put(msg) + elif msg.type == "control_change" and msg.control == 64: + received_messages_queue.put(msg) + elif ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value > 0 + ): + control_sentinel.set() + logger.info("Control signal seen") + + logger.info(f"Active pitches: {active_pitches}") + num_active_pitches = len(active_pitches) + + if active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=midi_capture_channel, + time=get_epoch_time_ms() - prev_msg_epoch_time_ms, + ) + received_messages_queue.put(msg) + + while active_pitches: + pitch = active_pitches.pop() + msg = mido.Message( + type="note_on", + note=pitch, + velocity=0, + channel=midi_capture_channel, + time=0, + ) + received_messages_queue.put(msg) + + # Turn off pedal + msg = mido.Message( + type="control_change", + control=64, + value=0, + channel=midi_capture_channel, + time=0, + ) + received_messages_queue.put(msg) + received_messages_queue.put(None) + results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) + + +def play_midi_file( + midi_through_port: str, + midi_in_port: str, + midi_path: str, + currently_streaming_sentinel: threading.Event, +): + def _send_delayed_message(port, msg): + port.send(msg) + logger.debug(f"SENT: {msg}") + + logger = get_logger("FILE") + logger.info(f"Playing {midi_path} on through-port '{midi_through_port}'") + logger.info( + f"Simulating input to port '{midi_in_port}' with {HARDWARE_INPUT_LATENCY_MS}ms latency" + ) + + if MIN_NOTE_DELTA_MS > 0: + midi_dict = MidiDict.from_midi(midi_path) + midi_dict.enforce_gaps(min_gap_ms=MIN_NOTE_DELTA_MS) + mid = midi_dict.to_midi() + else: + mid = mido.MidiFile(midi_path) + + time.sleep(1) + with mido.open_output(midi_through_port) as through_port: + with mido.open_output(midi_in_port) as in_port: + for msg in mid.play(): + if currently_streaming_sentinel.is_set() is False and not ( + msg.type == "control_change" and msg.control == 64 + ): + through_port.send(msg) + + timer = threading.Timer( + interval=HARDWARE_INPUT_LATENCY_MS / 1000.0, + function=_send_delayed_message, + args=[in_port, msg], + ) + timer.start() + + +def listen_for_keypress_control_signal( + control_sentinel: threading.Event, + generate_ending_sentinel: threading.Event, +): + logger = get_logger("KEYBOARD") + while True: + time.sleep(5) + _input = input() + logger.info(f'Keypress seen "{_input}"') + if _input == "": + control_sentinel.set() + 
else: + control_sentinel.set() + generate_ending_sentinel.set() + return + + +def _listen( + midi_input_port: str, + logger: logging.Logger, + midi_control_signal: int | None = None, +): + logger.info("Listening...") + with mido.open_input(midi_input_port) as midi_input: + while True: + msg = midi_input.receive(block=False) + if msg is None: + time.sleep(0.01) + elif ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value >= 64 + ): + return + + +def listen_for_midi_control_signal( + midi_input_port: str, + control_sentinel: threading.Event, + midi_control_signal: int | None = None, +): + logger = get_logger("MIDI-CONTROL") + + while True: + _listen( + midi_input_port=midi_input_port, + midi_control_signal=midi_control_signal, + logger=logger, + ) + control_sentinel.set() + logger.info("Seen MIDI control signal") + time.sleep(5) + + +def parse_args(): + argp = argparse.ArgumentParser() + argp.add_argument("--checkpoint", help="path to model checkpoint") + argp.add_argument("--midi_in", required=False, help="MIDI input port") + argp.add_argument("--midi_out", required=True, help="MIDI output port") + argp.add_argument( + "--midi_through", + required=False, + help="MIDI through port for received input", + ) + argp.add_argument( + "--midi_path", + required=False, + help="Use MIDI file instead of MIDI input port", + ) + argp.add_argument( + "--midi_control_signal", + type=int, + help="MIDI control change message for AI takeover", + ) + argp.add_argument( + "--temp", + help="sampling temperature value", + type=float, + required=False, + default=0.95, + ) + argp.add_argument( + "--min_p", + help="sampling min_p value", + type=float, + required=False, + default=0.03, + ) + argp.add_argument( + "--wait_for_close", + help="wait for note-offs before generating", + action="store_true", + ) + argp.add_argument( + "--save_path", + type=str, + required=False, + help="Path to save complete MIDI file", + ) + + return argp.parse_args() + + +def main(args): + args = parse_args() + logger = get_logger() + tokenizer = AbsTokenizer() + model = load_model(checkpoint_path=args.checkpoint) + model = compile_model(model=model) + + assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + + control_sentinel = threading.Event() + generate_ending_sentinel = threading.Event() + currently_generating_sentinel = threading.Event() + + if args.midi_through: + close_notes(args.midi_through) + if args.midi_out: + close_notes(args.midi_out) + + if args.midi_path: + midi_input_port = "IAC Driver Bus 1" + play_file_thread = threading.Thread( + target=play_midi_file, + args=( + args.midi_through, + midi_input_port, + args.midi_path, + currently_generating_sentinel, + ), + daemon=True, + ) + else: + midi_input_port = args.midi_in + play_file_thread = None + + keypress_thread = threading.Thread( + target=listen_for_keypress_control_signal, + args=[control_sentinel, generate_ending_sentinel], + daemon=True, + ) + midi_control_thread = threading.Thread( + target=listen_for_midi_control_signal, + kwargs={ + "midi_input_port": ( + args.midi_in if args.midi_in else midi_input_port + ), + "control_sentinel": control_sentinel, + "midi_control_signal": args.midi_control_signal, + }, + daemon=True, + ) + keypress_thread.start() + midi_control_thread.start() + + if play_file_thread is not None: + play_file_thread.start() + + msgs, prev_context, first_on_msg_epoch_ms, num_active_pitches = ( + capture_and_update_kv( + model=model, + msgs=[], + prev_context=[], + 
control_sentinel=control_sentinel, + wait_for_close=args.wait_for_close, + midi_input_port=midi_input_port, + midi_control_signal=args.midi_control_signal, + midi_capture_channel=0, + ) + ) + + curr_midi_channel = 0 + while True: + control_sentinel.clear() + currently_generating_sentinel.set() + msgs = stream_msgs( + model=model, + tokenizer=tokenizer, + msgs=msgs, + prev_context=prev_context, + midi_output_port=args.midi_out, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + control_sentinel=control_sentinel, + temperature=args.temp, + min_p=args.min_p, + num_preceding_active_pitches=num_active_pitches, + midi_stream_channel=curr_midi_channel, + is_ending=False, + ) + + curr_midi_channel += 1 + if curr_midi_channel == 9: + curr_midi_channel += 1 + + control_sentinel.clear() + if generate_ending_sentinel.is_set(): + break + else: + currently_generating_sentinel.clear() + msgs, prev_context, _, num_active_pitches = capture_and_update_kv( + model=model, + msgs=msgs, + prev_context=prev_context, + control_sentinel=control_sentinel, + wait_for_close=args.wait_for_close, + midi_input_port=midi_input_port, + midi_control_signal=args.midi_control_signal, + midi_capture_channel=curr_midi_channel, + first_msg_epoch_time_ms=first_on_msg_epoch_ms, + ) + + # Generate ending + msgs = stream_msgs( + model=model, + tokenizer=tokenizer, + msgs=msgs, + prev_context=prev_context, + midi_output_port=args.midi_out, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + control_sentinel=control_sentinel, + temperature=args.temp / 2, + min_p=args.min_p, + num_preceding_active_pitches=num_active_pitches, + midi_stream_channel=curr_midi_channel, + is_ending=True, + ) + + if args.save_path: + logger.info(f"Saving result to {args.save_path}") + midi = convert_msgs_to_midi(msgs=msgs) + midi.save(args.save_path) + + +def close_notes(midi_out_port: str): + with mido.open_output(midi_out_port) as out: + out.send(mido.Message(type="control_change", control=64, value=0)) + for note in range(128): + out.send(mido.Message("note_off", note=note, velocity=0)) + + +if __name__ == "__main__": + args = parse_args() + + try: + main(args) + except KeyboardInterrupt: + close_notes(args.midi_out) diff --git a/demo/midi-tunnel-client.py b/demo/midi-tunnel-client.py new file mode 100644 index 00000000..d4c6f839 --- /dev/null +++ b/demo/midi-tunnel-client.py @@ -0,0 +1,143 @@ +import socket +import rtmidi +import time +import subprocess +import signal +import sys +import os +import argparse + +SSH_SERVER = "home-4090.remote" +def parse_arguments(): + parser = argparse.ArgumentParser(description='MIDI UDP bridge with SSH tunnel') + parser.add_argument('-p', '--port', type=int, default=5004, + help='UDP port number (default: 5004)') + return parser.parse_args() + +def kill_existing_process(port): + # Check and kill existing process on remote server + check_command = f"ssh {SSH_SERVER} 'lsof -ti :{port}'" + try: + pid = subprocess.check_output(check_command, shell=True).decode().strip() + if pid: + print(f"Found existing process {pid} on port {port}, killing it...") + kill_command = f"ssh {SSH_SERVER} 'kill -9 {pid}'" + subprocess.run(kill_command, shell=True) + # Wait a moment for the port to be freed + time.sleep(1) + except subprocess.CalledProcessError: + # No existing process found + pass + +def setup_ssh_tunnel(port): + while True: + try: + # Kill any existing process first + kill_existing_process(port) + + # Start SSH tunnel using socat + print(f"Attempting to establish SSH tunnel on port {port}...") + ssh_command = f"ssh {SSH_SERVER} 
'socat -u UDP4-RECV:{port} STDOUT'" + local_socat = f"socat -u STDIN UDP4-SEND:localhost:{port}" + + ssh_process = subprocess.Popen(ssh_command, shell=True, stdout=subprocess.PIPE) + socat_process = subprocess.Popen(local_socat, shell=True, stdin=ssh_process.stdout) + + # Check if the processes started successfully + time.sleep(1) + if ssh_process.poll() is not None: # Process terminated + raise subprocess.CalledProcessError(ssh_process.returncode, ssh_command) + + print("SSH tunnel established successfully!") + return ssh_process, socat_process + + except (subprocess.CalledProcessError, OSError) as e: + print(f"Failed to establish SSH tunnel: {str(e)}") + print("Retrying in 1 second...") + time.sleep(1) + +def create_virtual_port(port): + midi_out = rtmidi.MidiOut() + # Create a virtual MIDI port with port number in name + midi_out.open_virtual_port(f"UDP_{port}") + return midi_out + +def start_udp_listener(port): + # Create UDP socket + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind(('localhost', port)) + return sock + +def split_midi_messages(data): + """Split a byte array into individual MIDI messages.""" + messages = [] + data_list = list(data) + i = 0 + while i < len(data_list): + # Check if we have a status byte (most significant bit is 1) + if data_list[i] >= 0x80: + # Most MIDI messages are 3 bytes + if i + 2 < len(data_list): + messages.append(data_list[i:i+3]) + i += 3 + else: + # Handle incomplete message at end of buffer + break + else: + # Skip non-status bytes (shouldn't happen in properly formatted MIDI) + i += 1 + return messages + +def cleanup(ssh_process, socat_process, midi_out, sock): + print("\nCleaning up...") + # Kill the SSH and socat processes + if ssh_process: + os.killpg(os.getpgid(ssh_process.pid), signal.SIGTERM) + if socat_process: + socat_process.terminate() + # Close MIDI and socket + if midi_out: + midi_out.close_port() + if sock: + sock.close() + +def main(): + args = parse_arguments() + port = args.port + + ssh_process = None + socat_process = None + midi_out = None + sock = None + + try: + # Setup SSH tunnel first + print(f"Setting up SSH tunnel on port {port}...") + ssh_process, socat_process = setup_ssh_tunnel(port) + + # Setup MIDI and UDP + print(f"Creating virtual MIDI port UDP_{port}...") + midi_out = create_virtual_port(port) + print(f"Starting UDP listener on port {port}...") + sock = start_udp_listener(port) + + print(f"UDP MIDI Bridge started - listening on port {port}") + + while True: + data, addr = sock.recvfrom(1024) + if data: + # Split the data into individual MIDI messages + midi_messages = split_midi_messages(data) + for midi_message in midi_messages: + print(f"Sending MIDI message: {midi_message}") + midi_out.send_message(midi_message) + + except KeyboardInterrupt: + print("\nShutting down UDP MIDI Bridge...") + except Exception as e: + print(f"Error: {e}") + finally: + cleanup(ssh_process, socat_process, midi_out, sock) + +if __name__ == "__main__": + main() diff --git a/demo/midi-tunnel-server.py b/demo/midi-tunnel-server.py new file mode 100755 index 00000000..988e2004 --- /dev/null +++ b/demo/midi-tunnel-server.py @@ -0,0 +1,61 @@ +import rtmidi +import socket +import time +import struct +import argparse + +class MIDIRouter: + def __init__(self, midi_port="14:0", udp_port=5004): + self.midi_in = rtmidi.MidiIn() + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.udp_port = udp_port + + # Print available ports + ports = self.midi_in.get_ports() + print(f"Available MIDI ports: {ports}") + + # 
Find and open MIDI port + for i, port in enumerate(ports): + if midi_port in port: + print(f"Opening MIDI port {i}: {port}") + self.midi_in.open_port(i) + break + else: + print(f"Warning: Could not find port containing '{midi_port}'") + + self.midi_in.set_callback(self._midi_callback) + + def _midi_callback(self, message, timestamp): + try: + print(f"Received MIDI message: {message[0]}") + midi_data = struct.pack(f'B' * len(message[0]), *message[0]) + self.socket.sendto(midi_data, ('localhost', self.udp_port)) + print(f"Sent {len(midi_data)} bytes to localhost:{self.udp_port}") + except Exception as e: + print(f"Error in callback: {e}") + + def start(self): + print(f"Routing MIDI messages through SSH tunnel on port {self.udp_port}...") + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + self.stop() + + def stop(self): + print("Shutting down...") + self.midi_in.close_port() + self.socket.close() + +def parse_args(): + parser = argparse.ArgumentParser(description='MIDI to UDP router') + parser.add_argument('-midi_p', type=str, default="14:0", + help='MIDI port identifier (default: 14:0)') + parser.add_argument('-udp_p', type=int, default=5004, + help='UDP port for forwarding (default: 5004)') + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + router = MIDIRouter(midi_port=args.midi_p, udp_port=args.udp_p) + router.start() diff --git a/example-prompts/classical.mid b/example-prompts/classical.mid new file mode 100644 index 00000000..4d4db899 Binary files /dev/null and b/example-prompts/classical.mid differ diff --git a/example-prompts/nocturne.mid b/example-prompts/nocturne.mid new file mode 100644 index 00000000..00d41ff9 Binary files /dev/null and b/example-prompts/nocturne.mid differ diff --git a/example-prompts/pokey_jazz.mid b/example-prompts/pokey_jazz.mid new file mode 100644 index 00000000..73c4f216 Binary files /dev/null and b/example-prompts/pokey_jazz.mid differ diff --git a/example-prompts/smooth_jazz.mid b/example-prompts/smooth_jazz.mid new file mode 100644 index 00000000..73f31a92 Binary files /dev/null and b/example-prompts/smooth_jazz.mid differ diff --git a/example-prompts/yesterday.mid b/example-prompts/yesterday.mid new file mode 100644 index 00000000..513b7c76 Binary files /dev/null and b/example-prompts/yesterday.mid differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..199cfb8d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "aria" +version = "0.0.1" +description = "" +authors = [{name = "Louis Bradshaw", email = "loua19@outlook.com"}] +requires-python = ">=3.11" + +dependencies = [ + "ariautils @ git+https://github.com/EleutherAI/aria-utils.git", + "torch>=2.3", + "mlx", + "safetensors", + "jsonlines", + "tqdm", +] + +[project.optional-dependencies] +dev = ["black"] +train = ["accelerate"] +eval = ["transformers", "torchaudio", "mido"] +demo = ["python-rtmidi"] +all = ["black", "accelerate", "transformers", "torchaudio", "mido", "python-rtmidi"] + +[tool.black] +line-length = 80 +target-version = ["py311"] +include = '\.pyi?$' + +[project.scripts] +aria = "aria.run:main" + +[tool.setuptools.packages.find] +where = ["."] +include = ["aria", "aria.*"] + +[tool.setuptools.package-data] +aria = ["../config/*.json", "../config/models/*.json"] diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 28a28f79..00000000 --- a/requirements-dev.txt +++ 
/dev/null @@ -1,2 +0,0 @@ -flake8 -black diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 5d6ef6dd..00000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -ariautils @ git+https://github.com/EleutherAI/aria-utils.git -torch >= 2.3 -accelerate -jsonlines -tqdm -safetensors \ No newline at end of file diff --git a/scripts/download_data.sh b/scripts/download_data.sh deleted file mode 100644 index cba01655..00000000 --- a/scripts/download_data.sh +++ /dev/null @@ -1,3 +0,0 @@ -mkdir data -gsutil cp gs://gpt-aria/train_data/train.jsonl data/train.jsonl -gsutil cp gs://gpt-aria/train_data/val.jsonl data/val.jsonl \ No newline at end of file diff --git a/scripts/midi_to_audio.py b/scripts/midi_to_audio.py deleted file mode 100644 index 8221383a..00000000 --- a/scripts/midi_to_audio.py +++ /dev/null @@ -1,20 +0,0 @@ -import os - -from aria.utils import midi_to_audio - - -def main(): - root_dir = "/Users/louis/work/data/mid/prompts/survey" - for dirpath, dirnames, filenames in os.walk(root_dir): - for filename in filenames: - if filename.endswith(".mid"): - midi_path = os.path.join(dirpath, filename) - midi_to_audio(midi_path) - - -if __name__ == "__main__": - main() - - -if __name__ == "__main__": - main() diff --git a/scripts/upload_data.sh b/scripts/upload_data.sh deleted file mode 100644 index c037d375..00000000 --- a/scripts/upload_data.sh +++ /dev/null @@ -1,2 +0,0 @@ -gsutil cp data/train.jsonl gs://gpt-aria/train_data/train.jsonl -gsutil cp data/val.jsonl gs://gpt-aria/train_data/val.jsonl \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index bc558f2c..00000000 --- a/setup.py +++ /dev/null @@ -1,25 +0,0 @@ -import os - -import pkg_resources -from setuptools import find_packages, setup - -setup( - name="aria", - py_modules=["aria"], - version="0.0.1", - description="", - author="", - packages=find_packages() + ["config"], - include_package_data=True, - entry_points={ - "console_scripts": [ - "aria=aria.run:main", - ], - }, - install_requires=[ - str(r) - for r in pkg_resources.parse_requirements( - open(os.path.join(os.path.dirname(__file__), "requirements.txt")) - ) - ], -) diff --git a/tests/test_data.py b/tests/test_data.py deleted file mode 100644 index bd477c60..00000000 --- a/tests/test_data.py +++ /dev/null @@ -1,305 +0,0 @@ -import unittest -import os -import shutil -import logging - -from aria import tokenizer -from aria.config import load_config -from aria.data import datasets -from aria.data.datasets import _noise_midi_dict -from ariautils.midi import MidiDict - -logger = logging.getLogger(__name__) -if not os.path.isdir("tests/test_results"): - os.makedirs("tests/test_results") - - -def setup_logger(): - logger = logging.getLogger(__name__) - for h in logger.handlers[:]: - logger.removeHandler(h) - logger.propagate = False - logger.setLevel(logging.INFO) - formatter = logging.Formatter( - "[%(asctime)s] tests.test_data: [%(levelname)s] %(message)s" - ) - ch = logging.StreamHandler() - ch.setLevel(logging.INFO) - ch.setFormatter(formatter) - logger.addHandler(ch) - - -def get_short_seq(): - return [ - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - ("prefix", "composer", "bach"), - "", - ("piano", 62, 50), - ("dur", 50), - ("wait", 100), - ("drum", 50), - ("piano", 64, 70), - ("dur", 100), - ("wait", 100), - "", - ] - - -class TestMidiDict(unittest.TestCase): - def test_resolve_pedal(self): - midi_dict = MidiDict.from_midi("tests/test_data/maestro.mid") - midi_dict.resolve_pedal() 
- self.assertListEqual(midi_dict.pedal_msgs, []) - mid = midi_dict.to_midi() - mid.save("tests/test_results/maestro_npedal.mid") - - -class TestMidiDataset(unittest.TestCase): - def test_build(self): - dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - - self.assertEqual(len(dataset), 7) - self.assertEqual(type(dataset[0]), MidiDict) - - def test_save_load(self): - dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - dataset.save("tests/test_results/mididict_dataset.jsonl") - - dataset_reloaded = datasets.MidiDataset.load( - "tests/test_results/mididict_dataset.jsonl" - ) - self.assertEqual(len(dataset_reloaded), 7) - self.assertEqual(type(dataset[0]), type(dataset_reloaded[0])) - - def test_build_to_file(self): - datasets.MidiDataset.build_to_file( - dir="tests/test_data", - save_path="tests/test_results/mididict_dataset_direct.jsonl", - recur=False, - overwrite=True, - ) - - dataset_reloaded = datasets.MidiDataset.load( - load_path="tests/test_results/mididict_dataset_direct.jsonl", - ) - self.assertEqual(len(dataset_reloaded), 7) - self.assertEqual(type(dataset_reloaded[0]), MidiDict) - - def test_split_from_file(self): - datasets.MidiDataset.build_to_file( - dir="tests/test_data", - save_path="tests/test_results/mididict_dataset.jsonl", - recur=False, - overwrite=True, - ) - - datasets.MidiDataset.split_from_file( - load_path="tests/test_results/mididict_dataset.jsonl", - train_val_ratio=0.7, - repeatable=True, - overwrite=True, - ) - - self.assertTrue( - os.path.isfile("tests/test_results/mididict_dataset_train.jsonl") - ) - self.assertTrue( - os.path.isfile("tests/test_results/mididict_dataset_val.jsonl") - ) - - def test_data_hash(self): - mid_1 = MidiDict.from_midi("tests/test_data/pop.mid") - mid_2 = MidiDict.from_midi("tests/test_data/pop_copy.mid") - - self.assertEqual(mid_1.calculate_hash(), mid_2.calculate_hash()) - - def test_concat(self): - if ( - os.path.exists("tests/test_results/mididict_dataset_train.jsonl") - and os.path.exists("tests/test_results/mididict_dataset_val.jsonl") - and os.path.exists("tests/test_results/mididict_dataset.jsonl") - ): - datasets.MidiDataset.combine_datasets_from_file( - "tests/test_results/mididict_dataset_train.jsonl", - "tests/test_results/mididict_dataset_val.jsonl", - "tests/test_results/mididict_dataset.jsonl", - output_path="tests/test_results/mididict_dataset_concat.jsonl", - ) - - self.assertAlmostEqual( - len( - datasets.MidiDataset.load( - "tests/test_results/mididict_dataset_concat.jsonl" - ) - ), - len( - datasets.MidiDataset.load( - "tests/test_results/mididict_dataset.jsonl" - ) - ), - ) - - -class TestPretrainingDataset(unittest.TestCase): - def test_build(self): - MAX_SEQ_LEN = 4096 - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - mididict_dataset.save("tests/test_results/mididict_dataset.jsonl") - - if os.path.exists("tests/test_results/pretrain_dataset_buff_1"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff_1") - if os.path.exists("tests/test_results/pretrain_dataset_buff_2"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff_2") - - dataset_from_file = datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff_1", - max_seq_len=MAX_SEQ_LEN, - num_epochs=3, - midi_dataset_path="tests/test_results/mididict_dataset.jsonl", - ) - dataset_from_mdset = datasets.PretrainingDataset.build( - 
tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff_2", - max_seq_len=MAX_SEQ_LEN, - num_epochs=2, - midi_dataset=mididict_dataset, - ) - - def test_multiple_paths(self): - MAX_SEQ_LEN = 4096 - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - mididict_dataset.save("tests/test_results/mididict_dataset_1.jsonl") - - if os.path.exists("tests/test_results/pretrain_dataset_buff_1"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff_1") - if os.path.exists("tests/test_results/pretrain_dataset_buff_2"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff_2") - - datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff_1", - max_seq_len=MAX_SEQ_LEN, - num_epochs=3, - midi_dataset_path="tests/test_results/mididict_dataset.jsonl", - ) - datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff_2", - max_seq_len=MAX_SEQ_LEN, - num_epochs=5, - midi_dataset_path="tests/test_results/mididict_dataset.jsonl", - ) - - dataset = datasets.PretrainingDataset( - dir_paths=[ - "tests/test_results/pretrain_dataset_buff_1", - "tests/test_results/pretrain_dataset_buff_2", - ], - tokenizer=tknzr, - ) - - for epoch in range(11): - for idx, _ in enumerate(dataset): - pass - - print("-------------") - dataset.init_epoch() - - def test_aug(self): - MAX_SEQ_LEN = 512 - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data", - recur=False, - ) - if os.path.exists("tests/test_results/pretrain_dataset_buff"): - shutil.rmtree("tests/test_results/pretrain_dataset_buff") - pretrain_dataset = datasets.PretrainingDataset.build( - tokenizer=tknzr, - save_dir="tests/test_results/pretrain_dataset_buff", - max_seq_len=MAX_SEQ_LEN, - num_epochs=1, - midi_dataset=mididict_dataset, - ) - pretrain_dataset.set_transform(tknzr.export_data_aug()) - for idx, seq in enumerate(tknzr.decode(pretrain_dataset[0][0])): - for _idx, tok in enumerate(seq): - if tok == tknzr.unk_tok: - logger.warning(f"unk_tok seen at seq={idx}, idx={_idx}") - - logger.info(f"data_aug_1: {tknzr.decode(pretrain_dataset[0][0][:50])}") - logger.info(f"data_aug_2: {tknzr.decode(pretrain_dataset[0][0][:50])}") - - -class TestFinetuningDataset(unittest.TestCase): - def test_noise(self): - config = load_config()["data"]["finetuning"]["noising"] - midi_dict = MidiDict.from_midi("tests/test_data/clean/1.mid") - noisy_midi_dict = _noise_midi_dict(midi_dict, config) - noisy_midi = noisy_midi_dict.to_midi() - noisy_midi.save("tests/test_results/noisy.mid") - - def test_build(self): - MAX_SEQ_LEN = 4096 - tknzr = tokenizer.SeparatedAbsTokenizer(return_tensors=False) - clean_mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data/clean", - recur=True, - shuffle=False, - ) - noisy_mididict_dataset = datasets.MidiDataset.build( - dir="tests/test_data/noisy", - recur=True, - shuffle=False, - ) - if os.path.exists("tests/test_results/clean.jsonl"): - os.remove("tests/test_results/clean.jsonl") - if os.path.exists("tests/test_results/noisy.jsonl"): - os.remove("tests/test_results/noisy.jsonl") - clean_mididict_dataset.save("tests/test_results/clean.jsonl") - noisy_mididict_dataset.save("tests/test_results/noisy.jsonl") - - if os.path.exists("tests/test_results/comb"): - shutil.rmtree("tests/test_results/comb") - - finetuning_dataset = datasets.FinetuningDataset.build( - 
tokenizer=tknzr, - save_dir="tests/test_results/comb", - max_seq_len=MAX_SEQ_LEN, - num_epochs=2, - clean_dataset_path="tests/test_results/clean.jsonl", - noisy_dataset_paths=["tests/test_results/noisy.jsonl"], - ) - - finetuning_dataset.init_epoch(0) - for seq, tgt, mask in finetuning_dataset: - tokenized_seq = tknzr.decode(seq) - if ( - tknzr.inst_start_tok in tokenized_seq - and tknzr.bos_tok not in tokenized_seq - ): - detokenized_midi_dict = tknzr.detokenize(tokenized_seq) - res = detokenized_midi_dict.to_midi() - res.save(f"tests/test_results/comb.mid") - break - - -setup_logger() -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_data/arabesque.mid b/tests/test_data/arabesque.mid deleted file mode 100644 index 38a0f858..00000000 Binary files a/tests/test_data/arabesque.mid and /dev/null differ diff --git a/tests/test_data/bach.mid b/tests/test_data/bach.mid deleted file mode 100644 index 329d4599..00000000 Binary files a/tests/test_data/bach.mid and /dev/null differ diff --git a/tests/test_data/basic.mid b/tests/test_data/basic.mid deleted file mode 100644 index c44fe36a..00000000 Binary files a/tests/test_data/basic.mid and /dev/null differ diff --git a/tests/test_data/beethoven_moonlight.mid b/tests/test_data/beethoven_moonlight.mid deleted file mode 100644 index 1c1ec5d2..00000000 Binary files a/tests/test_data/beethoven_moonlight.mid and /dev/null differ diff --git a/tests/test_data/beethoven_sonata.mid b/tests/test_data/beethoven_sonata.mid deleted file mode 100644 index 923a4d56..00000000 Binary files a/tests/test_data/beethoven_sonata.mid and /dev/null differ diff --git a/tests/test_data/clean/1.mid b/tests/test_data/clean/1.mid deleted file mode 100644 index 5a221b28..00000000 Binary files a/tests/test_data/clean/1.mid and /dev/null differ diff --git a/tests/test_data/clean/2.mid b/tests/test_data/clean/2.mid deleted file mode 100644 index 2b2fa8a4..00000000 Binary files a/tests/test_data/clean/2.mid and /dev/null differ diff --git a/tests/test_data/expressive.mid b/tests/test_data/expressive.mid deleted file mode 100644 index 40d9e846..00000000 Binary files a/tests/test_data/expressive.mid and /dev/null differ diff --git a/tests/test_data/noisy/1.mid b/tests/test_data/noisy/1.mid deleted file mode 100644 index 522cb3b4..00000000 Binary files a/tests/test_data/noisy/1.mid and /dev/null differ diff --git a/tests/test_data/noisy/2.mid b/tests/test_data/noisy/2.mid deleted file mode 100644 index 621fbe37..00000000 Binary files a/tests/test_data/noisy/2.mid and /dev/null differ diff --git a/tests/test_data/pop.mid b/tests/test_data/pop.mid deleted file mode 100644 index be83c69d..00000000 Binary files a/tests/test_data/pop.mid and /dev/null differ diff --git a/tests/test_data/pop_copy.mid b/tests/test_data/pop_copy.mid deleted file mode 100644 index be83c69d..00000000 Binary files a/tests/test_data/pop_copy.mid and /dev/null differ diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py deleted file mode 100644 index f110a7f0..00000000 --- a/tests/test_tokenizers.py +++ /dev/null @@ -1,535 +0,0 @@ -import unittest -import logging -import os -import time - -from typing import Callable - -from aria import tokenizer -from aria.config import load_config -from ariautils.midi import MidiDict -from aria.data.datasets import _get_combined_mididict, _noise_midi_dict -from aria.utils import midi_to_audio - - -if not os.path.isdir("tests/test_results"): - os.makedirs("tests/test_results") - - -# TODO: Implement with tokenizer functions -def 
get_short_seq_abs(tknzr: tokenizer.AbsTokenizer): - return [ - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - "", - ("piano", 62, tknzr._quantize_velocity(45)), - ("onset", tknzr._quantize_onset(0)), - ("dur", tknzr._quantize_dur(50)), - ("drum", 50), - ("onset", tknzr._quantize_onset(100)), - ("piano", 64, tknzr._quantize_velocity(75)), - ("onset", tknzr._quantize_onset(100)), - ("dur", tknzr._quantize_dur(5000)), - "", - "", - "", - ("piano", 65, tknzr._quantize_velocity(75)), - ("onset", tknzr._quantize_onset(170)), - ("dur", tknzr._quantize_dur(100)), - "", - ("piano", 60, tknzr._quantize_velocity(45)), - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(60)), - "", - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(70)), - ("drum", 50), - ("onset", tknzr._quantize_onset(270)), - "", - ("piano", 80, tknzr._quantize_velocity(45)), - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(80)), - "", - ] - - -def get_concat_seq_abs(tknzr: tokenizer.AbsTokenizer): - return [ - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(60)), - "", - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(70)), - ("drum", 50), - ("onset", tknzr._quantize_onset(270)), - "", - ("piano", 80, tknzr._quantize_velocity(45)), - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(80)), - "", - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - "", - ("piano", 62, tknzr._quantize_velocity(45)), - ("onset", tknzr._quantize_onset(0)), - ("dur", tknzr._quantize_dur(50)), - ("drum", 50), - ("onset", tknzr._quantize_onset(100)), - ("piano", 64, tknzr._quantize_velocity(75)), - ("onset", tknzr._quantize_onset(100)), - ("dur", tknzr._quantize_dur(5000)), - "", - "", - "", - ("piano", 65, tknzr._quantize_velocity(75)), - ("onset", tknzr._quantize_onset(170)), - ("dur", tknzr._quantize_dur(100)), - "", - ("piano", 60, tknzr._quantize_velocity(45)), - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(60)), - "", - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(70)), - ("drum", 50), - ("onset", tknzr._quantize_onset(270)), - "", - ("piano", 80, tknzr._quantize_velocity(45)), - ("onset", tknzr._quantize_onset(270)), - ("dur", tknzr._quantize_dur(80)), - "", - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - "", - ("piano", 62, tknzr._quantize_velocity(45)), - ("onset", tknzr._quantize_onset(0)), - ("dur", tknzr._quantize_dur(50)), - ("drum", 50), - ("onset", tknzr._quantize_onset(100)), - ("piano", 64, tknzr._quantize_velocity(75)), - ("onset", tknzr._quantize_onset(100)), - ("dur", tknzr._quantize_dur(5000)), - "", - "", - ] - - -def get_short_seq_rel(tknzr: tokenizer.RelTokenizer): - return [ - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - ("prefix", "composer", "bach"), - "", - ("piano", 62, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(50)), - ("wait", tknzr._quantize_time(100)), - ("drum", 50), - ("piano", 64, tknzr._quantize_velocity(70)), - ("dur", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(100)), - ("piano", 65, tknzr._quantize_velocity(70)), - ("dur", tknzr._quantize_time(100)), - ("wait", tknzr._quantize_time(100)), - ("piano", 60, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(60)), - ("piano", 
70, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(70)), - ("drum", 50), - ("piano", 80, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(80)), - ("wait", tknzr._quantize_time(100)), - "", - ] - - -def get_concat_seq_rel(tknzr: tokenizer.RelTokenizer): - return [ - ("dur", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(100)), - ("piano", 65, tknzr._quantize_velocity(70)), - ("dur", tknzr._quantize_time(100)), - ("wait", tknzr._quantize_time(100)), - ("piano", 60, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(60)), - ("piano", 70, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(70)), - ("drum", 50), - ("piano", 80, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(80)), - ("wait", tknzr._quantize_time(100)), - "", - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - ("prefix", "composer", "bach"), - "", - ("piano", 62, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(50)), - ("wait", tknzr._quantize_time(100)), - ("drum", tknzr._quantize_time(50)), - ("piano", 64, tknzr._quantize_velocity(70)), - ("dur", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(1000000)), - ("wait", tknzr._quantize_time(100)), - ("piano", 65, tknzr._quantize_velocity(70)), - ("dur", tknzr._quantize_time(100)), - ("wait", tknzr._quantize_time(100)), - ("piano", 60, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(60)), - ("piano", 70, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(70)), - ("drum", 50), - ("piano", 80, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(80)), - ("wait", tknzr._quantize_time(100)), - "", - ("prefix", "instrument", "piano"), - ("prefix", "instrument", "drum"), - ("prefix", "composer", "bach"), - "", - ("piano", 62, tknzr._quantize_velocity(50)), - ("dur", tknzr._quantize_time(50)), - ("wait", tknzr._quantize_time(100)), - ("drum", tknzr._quantize_time(50)), - ("piano", 64, tknzr._quantize_velocity(70)), - ] - - -class TestAbsTokenizer(unittest.TestCase): - def test_tokenize_detokenize_mididict(self): - def tokenize_detokenize(file_name: str): - mid_path = f"tests/test_data/{file_name}" - midi_dict = MidiDict.from_midi(mid_path=mid_path) - tokenized_seq = tknzr.tokenize(midi_dict) - detokenized_midi_dict = tknzr.detokenize(tokenized_seq) - res = detokenized_midi_dict.to_midi() - res.save(f"tests/test_results/{file_name}") - - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - tokenize_detokenize("basic.mid") - tokenize_detokenize("arabesque.mid") - tokenize_detokenize("beethoven_sonata.mid") - tokenize_detokenize("bach.mid") - tokenize_detokenize("expressive.mid") - tokenize_detokenize("pop.mid") - tokenize_detokenize("beethoven_moonlight.mid") - tokenize_detokenize("maestro.mid") - - def test_aug(self): - def tokenize_aug_detokenize( - file_name: str, - aug_fn: Callable, - aug_name: str, - audio=False, - ): - mid_path = f"tests/test_data/{file_name}" - midi_dict = MidiDict.from_midi(mid_path=mid_path) - tokenized_seq = tknzr.tokenize(midi_dict) - tokenized_seq_aug = aug_fn(tokenized_seq) - detokenized_midi_dict = tknzr.detokenize(tokenized_seq_aug) - res = detokenized_midi_dict.to_midi() - save_path = f"tests/test_results/abs_{aug_name}_{file_name}" - res.save(save_path) - if audio is True: - 
midi_to_audio(save_path) - - tknzr = tokenizer.AbsTokenizer(return_tensors=False) - seq = get_short_seq_abs(tknzr) - seq_concat = get_concat_seq_abs(tknzr) - pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) - velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) - tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5, mixup=True) - - # Pitch augmentation - seq_pitch_augmented = pitch_aug_fn(get_short_seq_abs(tknzr)) - logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n") - tokenize_aug_detokenize("basic.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("arabesque.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("beethoven_sonata.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("bach.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("expressive.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize("pop.mid", pitch_aug_fn, "pitch") - tokenize_aug_detokenize( - "beethoven_moonlight.mid", pitch_aug_fn, "pitch" - ) - - # Velocity augmentation - seq_velocity_augmented = velocity_aug_fn(get_short_seq_abs(tknzr)) - logging.info( - f"velocity_aug_fn:\n{seq} ->\n\n{seq_velocity_augmented}\n" - ) - tokenize_aug_detokenize("basic.mid", velocity_aug_fn, "velocity") - tokenize_aug_detokenize("arabesque.mid", velocity_aug_fn, "velocity") - tokenize_aug_detokenize( - "beethoven_sonata.mid", velocity_aug_fn, "velocity" - ) - tokenize_aug_detokenize("bach.mid", velocity_aug_fn, "velocity") - tokenize_aug_detokenize("expressive.mid", velocity_aug_fn, "velocity") - tokenize_aug_detokenize("pop.mid", velocity_aug_fn, "velocity") - tokenize_aug_detokenize( - "beethoven_moonlight.mid", velocity_aug_fn, "velocity" - ) - - # Tempo augmentation - seq_tempo_augmented = tempo_aug_fn(get_short_seq_abs(tknzr)) - logging.info(f"tempo_aug_fn:\n{seq} ->\n\n{seq_tempo_augmented}\n") - - seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq_abs(tknzr)) - logging.info( - f"tempo_aug_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" - ) - - tokenize_aug_detokenize("basic.mid", tempo_aug_fn, "tempo") - tokenize_aug_detokenize("arabesque.mid", tempo_aug_fn, "tempo") - tokenize_aug_detokenize("beethoven_sonata.mid", tempo_aug_fn, "tempo") - tokenize_aug_detokenize("bach.mid", tempo_aug_fn, "tempo") - tokenize_aug_detokenize("expressive.mid", tempo_aug_fn, "tempo") - tokenize_aug_detokenize("pop.mid", tempo_aug_fn, "tempo") - tokenize_aug_detokenize( - "beethoven_moonlight.mid", tempo_aug_fn, "tempo" - ) - - def test_aug_time(self): - tknzr = tokenizer.AbsTokenizer() - mid_dict = MidiDict.from_midi("tests/test_data/beethoven_sonata.mid") - tokenized_seq = tknzr.tokenize(mid_dict)[:4096] - pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) - velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) - tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5, mixup=True) - - # Pitch augmentation - t_start = time.perf_counter() - pitch_aug_fn(tokenized_seq) - t_pitch_aug = (time.perf_counter() - t_start) * 1e3 - logging.info(f"pitch_aug_fn took {int(t_pitch_aug)}ms") - self.assertLessEqual(t_pitch_aug, 50) - - # Velocity augmentation - t_start = time.perf_counter() - velocity_aug_fn(tokenized_seq) - t_vel_aug = (time.perf_counter() - t_start) * 1e3 - logging.info(f"velocity_aug_fn took {int(t_vel_aug)}ms") - self.assertLessEqual(t_vel_aug, 50) - - # Tempo augmentation - t_start = time.perf_counter() - tempo_aug_fn(tokenized_seq) - t_tempo_aug = (time.perf_counter() - t_start) * 1e3 - logging.info(f"tempo_aug_fn took {int(t_tempo_aug)}ms") - 
self.assertLessEqual(t_tempo_aug, 50) - - def test_no_unk_token(self): - def _test_no_unk_token(file_name: str): - mid_path = f"tests/test_data/{file_name}" - midi_dict = MidiDict.from_midi(mid_path=mid_path) - seq = tknzr.tokenize(midi_dict) - enc_dec_seq = tknzr.decode(tknzr.encode(seq)) - for tok in enc_dec_seq: - self.assertTrue(tok != tknzr.unk_tok) - - tknzr = tokenizer.AbsTokenizer() - _test_no_unk_token("basic.mid") - _test_no_unk_token("arabesque.mid") - _test_no_unk_token("bach.mid") - _test_no_unk_token("expressive.mid") - _test_no_unk_token("pop.mid") - _test_no_unk_token("beethoven_moonlight.mid") - - -# TODO: This example is not working, I'm pretty sure the issue is in _get_combined_mididict somewhere -# Fix this!! -class TestSeparatedTokenizer(unittest.TestCase): - def test_tokenize_detokenize_mididict(self): - def _find_inst_onsets(_seq: list): - curr_time_ms = 0 - time_toks = 0 - for tok in _seq: - if tok == "": - time_toks += 1 - elif isinstance(tok, tuple) and tok[0] == "onset": - curr_time_ms = 5000 * time_toks + tok[1] - elif tok == "": - print("Seen at", curr_time_ms) - - tknzr = tokenizer.SeparatedAbsTokenizer() - - clean_midi_dict = MidiDict.from_midi( - mid_path="/mnt/ssd1/data/mp3/raw/maestro-mp3/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi" - ) - noisy_midi_dict = MidiDict.from_midi( - mid_path="/mnt/ssd1/data/mp3/raw/maestro-mp3/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi" - # mid_path="/mnt/ssd1/amt/transcribed_data/noisy_maestro/small-long-e7/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.mid" - ) - - noisy_midi_dict = _noise_midi_dict( - noisy_midi_dict, load_config()["data"]["finetuning"]["noising"] - ) - - clean_mid = clean_midi_dict.to_midi() - clean_mid.save(f"tests/test_results/combined_clean.mid") - noisy_mid = noisy_midi_dict.to_midi() - noisy_mid.save(f"tests/test_results/combined_noisy.mid") - - comb_midi_dict = _get_combined_mididict( - clean_midi_dict, - noisy_midi_dict, - min_noisy_ms=10000, - max_noisy_ms=25000, - min_clean_ms=30000, - max_clean_ms=60000, - ) - - comb_midi = comb_midi_dict.to_midi() - comb_midi.save(f"tests/test_results/combined_raw.mid") - tokenized_seq = tknzr.tokenize(comb_midi_dict) - detokenized_midi_dict = tknzr.detokenize(tokenized_seq) - res = detokenized_midi_dict.to_midi() - res.save(f"tests/test_results/combined.mid") - - for idx, sub_seq in enumerate(tknzr.split(tokenized_seq, 4096)): - if idx == 3: - _find_inst_onsets(sub_seq) - print(idx) - print(sub_seq) - detokenized_midi_dict = tknzr.detokenize(sub_seq) - res = detokenized_midi_dict.to_midi() - res.save(f"tests/test_results/combined{idx}.mid") - - -class TestRelTokenizer(unittest.TestCase): - def test_tokenize_detokenize_mididict(self): - def tokenize_detokenize(file_name: str): - mid_path = f"tests/test_data/{file_name}" - midi_dict = MidiDict.from_midi(mid_path=mid_path) - tokenized_seq = tknzr.tokenize(midi_dict) - detokenized_midi_dict = tknzr.detokenize(tokenized_seq) - res = detokenized_midi_dict.to_midi() - res.save(f"tests/test_results/{file_name}") - - tknzr = tokenizer.RelTokenizer(return_tensors=False) - - tokenize_detokenize("basic.mid") - tokenize_detokenize("arabesque.mid") - tokenize_detokenize("beethoven_sonata.mid") - tokenize_detokenize("bach.mid") - tokenize_detokenize("expressive.mid") - tokenize_detokenize("pop.mid") - tokenize_detokenize("beethoven_moonlight.mid") - - def test_aug(self): - tknzr = 
tokenizer.RelTokenizer(return_tensors=False) - seq = get_short_seq_rel(tknzr) - seq_concat = get_concat_seq_rel(tknzr) - pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) - velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) - tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.8) - chord_mixup_fn = tknzr.export_chord_mixup() - - # Pitch augmentation - seq_pitch_augmented = pitch_aug_fn(get_short_seq_rel(tknzr)) - logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n") - self.assertEqual( - seq_pitch_augmented[4][1] - seq[4][1], - seq_pitch_augmented[8][1] - seq[8][1], - ) - - # Velocity augmentation - seq_velocity_augmented = velocity_aug_fn(get_short_seq_rel(tknzr)) - logging.info( - f"velocity_aug_fn:\n{seq} ->\n\n{seq_velocity_augmented}\n" - ) - self.assertEqual( - seq_velocity_augmented[4][2] - seq[4][2], - seq_velocity_augmented[8][2] - seq[8][2], - ) - - # Tempo augmentation - seq_tempo_augmented = tempo_aug_fn(get_short_seq_rel(tknzr)) - logging.info(f"tempo_aug_fn:\n{seq} ->\n\n{seq_tempo_augmented}\n") - - seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq_rel(tknzr)) - logging.info( - f"tempo_aug_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" - ) - - # Chord mix-up augmentation - seq_mixup_augmented = chord_mixup_fn(get_short_seq_rel(tknzr)) - logging.info(f"chord_mixup_fn:\n{seq} ->\n\n{seq_mixup_augmented}\n") - - seq_concat_tempo_augmented = chord_mixup_fn(get_concat_seq_rel(tknzr)) - logging.info( - f"chord_mixup_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" - ) - - def test_aug_time(self): - tknzr = tokenizer.RelTokenizer() - mid_dict = MidiDict.from_midi("tests/test_data/beethoven_sonata.mid") - tokenized_seq = tknzr.tokenize(mid_dict)[:4096] - - pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) - velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) - tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5) - chord_mixup_fn = tknzr.export_chord_mixup() - - # Pitch augmentation - t_start = time.perf_counter() - pitch_aug_fn(tokenized_seq) - t_pitch_aug = (time.perf_counter() - t_start) * 1e3 - logging.info(f"pitch_aug_fn took {int(t_pitch_aug)}ms") - self.assertLessEqual(t_pitch_aug, 50) - - # Velocity augmentation - t_start = time.perf_counter() - velocity_aug_fn(tokenized_seq) - t_vel_aug = (time.perf_counter() - t_start) * 1e3 - logging.info(f"velocity_aug_fn took {int(t_vel_aug)}ms") - self.assertLessEqual(t_vel_aug, 50) - - # Tempo augmentation - t_start = time.perf_counter() - tempo_aug_fn(tokenized_seq) - t_tempo_aug = (time.perf_counter() - t_start) * 1e3 - logging.info(f"tempo_aug_fn took {int(t_tempo_aug)}ms") - self.assertLessEqual(t_tempo_aug, 50) - - # Chord mixup augmentation - t_start = time.perf_counter() - chord_mixup_fn(tokenized_seq) - t_mixup_aug = (time.perf_counter() - t_start) * 1e3 - logging.info(f"mixup_aug_fn took {int(t_mixup_aug)}ms") - self.assertLessEqual(t_mixup_aug, 50) - - def test_encode_decode(self): - tknzr = tokenizer.RelTokenizer(return_tensors=True) - seq = get_short_seq_rel(tknzr) - enc_dec_seq = tknzr.decode(tknzr.encode(seq)) - for x, y in zip(seq, enc_dec_seq): - self.assertEqual(x, y) - - tknzr = tokenizer.RelTokenizer(return_tensors=False) - seq = get_short_seq_rel(tknzr) - enc_dec_seq = tknzr.decode(tknzr.encode(seq)) - for x, y in zip(seq, enc_dec_seq): - self.assertEqual(x, y) - - def test_no_unk_token(self): - tknzr = tokenizer.RelTokenizer() - seq = get_short_seq_rel(tknzr) - enc_dec_seq = tknzr.decode(tknzr.encode(seq)) - for tok in enc_dec_seq: - 
self.assertTrue(tok != tknzr.unk_tok) - - -if __name__ == "__main__": - if os.path.isdir("tests/test_results") is False: - os.mkdir("tests/test_results") - - logging.basicConfig(level=logging.INFO) - unittest.main()
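
A note on the timing convention used by `convert_msgs_to_midi` in the demo script above: captured `mido` messages carry their delta times in milliseconds, and the file is written with `ticks_per_beat=500` plus a `set_tempo` meta message of 500000 µs per beat. The minimal arithmetic sketch below shows why, under those two values, one tick is exactly one millisecond, so the millisecond deltas stored on the messages can be written out as tick values unchanged.

```python
# Timing convention reflected in convert_msgs_to_midi: 500 ticks per beat
# with a tempo of 500,000 microseconds per beat gives a tick length of
# exactly 1 ms, so millisecond deltas double as MIDI tick values.
ticks_per_beat = 500
tempo_us_per_beat = 500_000

us_per_tick = tempo_us_per_beat / ticks_per_beat  # 1000.0 µs
ms_per_tick = us_per_tick / 1000                  # 1.0 ms
assert ms_per_tick == 1.0
```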
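
The KV-cache bookkeeping in `_find_divergence` / `chunked_prefill` can be hard to follow interleaved with the MLX calls. The sketch below is an illustrative, model-free rendering of the same planning logic: keep the longest common prefix between the cached context and the live context, then prefill the divergent suffix in fixed-size chunks, only flushing a final partial chunk when `full=True`. The chunk sizes here are placeholder values, not the constants defined in the demo script, and the real code pads that final partial chunk with `pad_id` before sending it to the model.

```python
# Model-free sketch of the chunk planning performed by chunked_prefill.
# PREFILL_CHUNK_SIZE / PREFILL_CHUNK_SIZE_L are placeholders here.
PREFILL_CHUNK_SIZE = 8
PREFILL_CHUNK_SIZE_L = 32


def find_divergence(prev: list[int], curr: list[int]) -> int:
    """Length of the shared prefix between the cached and the live context."""
    idx = 0
    for p, c in zip(prev, curr):
        if p != c:
            break
        idx += 1
    return idx


def plan_prefill(prev: list[int], curr: list[int], full: bool) -> list[tuple[int, int]]:
    """Return the (start_idx, length) chunks that would be prefilled."""
    chunks = []
    while True:
        start = find_divergence(prev, curr)
        remaining = len(curr) - start
        if remaining > PREFILL_CHUNK_SIZE_L:
            size = PREFILL_CHUNK_SIZE_L
        elif remaining > PREFILL_CHUNK_SIZE:
            size = PREFILL_CHUNK_SIZE
        elif remaining > 0 and full:
            size = remaining  # padded up to PREFILL_CHUNK_SIZE in the real code
        else:
            return chunks
        chunks.append((start, size))
        prev = curr[: start + size]  # KV cache now covers this prefix


cached = list(range(20))            # tokens whose KV states are already cached
live = list(range(15)) + [99] * 30  # live context diverges at index 15

print(plan_prefill(cached, live, full=False))  # [(15, 8), (23, 8), (31, 8)]
print(plan_prefill(cached, live, full=True))   # same, plus a final (39, 6) chunk
```

With `full=False` a short tail is deliberately left unprefilled and picked up on a later call; `full=True` forces it through so generation can start immediately.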
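
The two tunnel scripts share an implicit wire format: `midi-tunnel-server.py` packs each raw MIDI message as unsigned bytes into a UDP datagram, and `midi-tunnel-client.py` splits a received datagram back into messages by treating every status byte (MSB set) as the start of a 3-byte channel message. The self-contained round trip below is a sketch of that contract; note that, as written, 2-byte messages (e.g. program change) and SysEx would be dropped by the splitter, which is fine for the note/pedal traffic the demo produces.

```python
import struct


def pack_midi(msg_bytes: list[int]) -> bytes:
    # Same packing as the server-side rtmidi callback: one unsigned byte per value.
    return struct.pack("B" * len(msg_bytes), *msg_bytes)


def split_midi_messages(data: bytes) -> list[list[int]]:
    # Mirrors the client-side splitter: a status byte (>= 0x80) starts a
    # 3-byte channel message; any other byte is skipped.
    messages, data_list, i = [], list(data), 0
    while i < len(data_list):
        if data_list[i] >= 0x80:
            if i + 2 < len(data_list):
                messages.append(data_list[i : i + 3])
                i += 3
            else:
                break  # incomplete message at the end of the buffer
        else:
            i += 1
    return messages


# Two note_on events (status 0x90) coalesced into one datagram survive the round trip;
# velocity 0 is the note-off convention used throughout the demo.
datagram = pack_midi([0x90, 60, 100]) + pack_midi([0x90, 64, 0])
assert split_midi_messages(datagram) == [[0x90, 60, 100], [0x90, 64, 0]]
```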