diff --git a/README.md b/README.md index c50a923..d9f497c 100644 --- a/README.md +++ b/README.md @@ -79,27 +79,40 @@ Our embedding model was trained to capture composition-level and performance-lev ## Real-time demo -In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface. +In `demo/` we provide an MLX (Apple Silicon) implementation of the real-time interactive piano-continuation demo showcased in our release blog post. In order to use the demo, you must download the demo-specific model checkpoint which enhances the model to additionally control the sustain pedal ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model-demo.safetensors?download=true)). -❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, e.g., GPU FLOPS and memory bandwidth. +For our demonstration, we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface. We disabled the built-in Disklavier playback mode, instead manually calibrating key-velocity latency to enhance responsiveness. You may recreate this in your own environment with our acoustic calibration settings, using the following script: -A MIDI input device is not strictly required to play around with the demo: By using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI, ALSA) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output. +❗**NOTE**: It is vital that you use the `latency=off`/`realtime` Disklavier playback setting when using the provided configuration for `--hardware`. -Example usage (MLX): +```bash +python ./demo/demo_mlx.py \ + --checkpoint \ + --midi_in \ + --midi_out \ + --hardware ./demo/hardware/c4dm-disklavier.json \ + --midi_control_signal 67 \ + --midi_reset_control_signal 66 \ + --temp 0.9 \ + --min_p 0.03 +``` + +A MIDI input device is not strictly required to play around with the demo: By using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output. In this mode, you can initiate the model takeover by pressing the enter key. ```bash -MIDI_PATH="example-prompts/pokey_jazz.mid" +MIDI_PATH="./example-prompts/smooth_jazz.mid" -python demo/demo_mlx.py \ +python ./demo/demo_mlx.py \ --checkpoint \ --midi_path ${MIDI_PATH} \ - --midi_through \ - --midi_out \ - --save_path \ - --temp 0.98 \ - --min_p 0.035 + --midi_through \ + --midi_out \ + --temp 0.9 \ + --min_p 0.03 ``` +❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, specifically GPU memory bandwidth. + ## Evaluation We provide the specific files/splits we used for Aria-MIDI derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/eval-splits.tar.gz?download=true)). Class labels are provided in `metadata.json` with the schema: diff --git a/aria/inference/__init__.py b/aria/inference/__init__.py index ceac4b4..4cb243e 100644 --- a/aria/inference/__init__.py +++ b/aria/inference/__init__.py @@ -48,11 +48,21 @@ def get_inference_prompt( for msg in midi_dict.note_msgs if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms ] + midi_dict.pedal_msgs = [ + msg + for msg in midi_dict.pedal_msgs + if midi_dict.tick_to_ms(msg["tick"]) <= prompt_len_ms + ] + if midi_dict.pedal_msgs and midi_dict.pedal_msgs[-1]["data"] == 1: + midi_dict.pedal_msgs.pop() if len(midi_dict.note_msgs) == 0: return [("prefix", "instrument", "piano"), tokenizer.bos_tok] - seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) - seq.remove(tokenizer.eos_tok) + seq = tokenizer.tokenize( + midi_dict=midi_dict, + add_dim_tok=False, + add_eos_tok=False, + ) return seq diff --git a/aria/inference/model_mlx.py b/aria/inference/model_mlx.py index 169b30b..51ad32e 100644 --- a/aria/inference/model_mlx.py +++ b/aria/inference/model_mlx.py @@ -84,14 +84,15 @@ def __call__( self, x: mx.array, input_pos: mx.array, + max_kv_pos: int | None, offset: int, mask: mx.array, ): assert self.kv_cache is not None, "Cache not initialized" - x += self._att_block( x=self.norm1(x), input_pos=input_pos, + max_kv_pos=max_kv_pos, offset=offset, mask=mask, ) @@ -99,15 +100,25 @@ def __call__( return x - def get_kv(self, k: mx.array, v: mx.array, input_pos: mx.array): + def get_kv( + self, + k: mx.array, + v: mx.array, + input_pos: mx.array, + max_kv_pos: int | None, + ): k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos) - return k, v + if max_kv_pos is not None: + return k[:, :, : max_kv_pos + 1, :], v[:, :, : max_kv_pos + 1, :] + else: + return k, v def _att_block( self, x: mx.array, input_pos: mx.array, + max_kv_pos: int | None, offset: int, mask: mx.array, ): @@ -124,7 +135,8 @@ def _att_block( k = apply_rotary_emb_mlx(k, offset=offset) q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v)) - k, v = self.get_kv(k, v, input_pos=input_pos) + k, v = self.get_kv(k, v, input_pos=input_pos, max_kv_pos=max_kv_pos) + wv = mx.fast.scaled_dot_product_attention( q=q, k=k, @@ -159,6 +171,7 @@ def __init__(self, model_config: ModelConfig): TransformerBlock(model_config) for _ in range(model_config.n_layers) ] self.out_layer_norm = nn.LayerNorm(model_config.d_model) + self.kv_ctx = None def fill_condition_kv(self, emb: mx.array): assert self.causal_mask is not None, "Caches must be initialized first" @@ -177,20 +190,30 @@ def __call__( self, idxs: mx.array, input_pos: mx.array, + max_kv_pos: int, offset: int, pad_idxs: mx.array | None = None, + _debug_track_kv: bool = False, ): assert self.causal_mask is not None, "Caches must be initialized first" - mask = self.causal_mask[None, None, input_pos] + if self.kv_ctx is None: + self.kv_ctx = mx.full( + self.model_config.max_seq_len, 3 + ) # unk_tok id + + if _debug_track_kv is True: + self.kv_ctx[input_pos] = idxs + self.kv_ctx[input_pos[-1].item() + 1 :] = 3 + mask = self.causal_mask[None, None, input_pos, : max_kv_pos + 1] if pad_idxs is not None: pad_mask = mx.expand_dims(mx.expand_dims(pad_idxs, axis=1), axis=1) mask = mask & ~pad_mask x = self.tok_embeddings(idxs) for layer in self.encode_layers: - x = layer(x, input_pos, offset, mask) + x = layer(x, input_pos, max_kv_pos, offset, mask) x = self.out_layer_norm(x) @@ -217,11 +240,13 @@ def __call__( idxs: mx.array, input_pos: mx.array, offset: int, + max_kv_pos: int | None = None, pad_idxs: mx.array | None = None, ): hidden_states = self.model( idxs=idxs, input_pos=input_pos, + max_kv_pos=max_kv_pos, offset=offset, pad_idxs=pad_idxs, ) @@ -235,6 +260,25 @@ def fill_condition_kv(self, cond_emb: mx.array): adapted_emb = self.embedding_adapter(cond_emb) self.model.fill_condition_kv(emb=adapted_emb) + def reset_kv_ctx(self): + self.model.kv_ctx = None + + def get_kv_ctx(self): + # Used for debugging kv-cache validation + _kv_ctx = self.model.kv_ctx + + match self.model.kv_ctx: + case None: + return None + case mx.array(): + _kv_ctx = self.model.kv_ctx.tolist() + if 3 in _kv_ctx: + return _kv_ctx[: _kv_ctx.index(3)] + else: + return _kv_ctx + case _: + raise ValueError + def setup_cache( self, batch_size, diff --git a/aria/inference/sample_cuda.py b/aria/inference/sample_cuda.py index 909bd8d..5a0f793 100644 --- a/aria/inference/sample_cuda.py +++ b/aria/inference/sample_cuda.py @@ -16,15 +16,6 @@ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 -def get_cfg_prompt(prompts: list): - cfg_prompts = [] - for prompt in prompts: - cfg_prompts.append(prompt) - cfg_prompts.append(prompt) - - return cfg_prompts - - @torch.inference_mode() def decode_one( model: TransformerLM, diff --git a/aria/model.py b/aria/model.py index 573f548..25b892b 100644 --- a/aria/model.py +++ b/aria/model.py @@ -178,7 +178,6 @@ def forward( seq_len=self.model_config.max_seq_len, n_elem=self.model_config.d_model // self.model_config.n_heads, base=500000, - dtype=hidden_states.dtype, ).to(src.device) freqs_cis = self.freqs_cis[: src.shape[1]] @@ -379,7 +378,6 @@ def precompute_freqs_cis( seq_len: int, n_elem: int, base: int = 500000, - dtype: torch.dtype = torch.bfloat16, ): freqs = 1.0 / ( base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) @@ -389,7 +387,7 @@ def precompute_freqs_cis( freqs_cis = torch.polar(torch.ones_like(freqs), freqs) cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=dtype) + return cache @torch.jit.script @@ -397,14 +395,15 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: """ In-place RoPE. Credits to Katherine Crowson: x shape (b_sz, s_len, n_head, d_head). - cos, sin shape (s_len, d_head // 2). + freqs_cis shape (s_len, d_head // 2, 2) and is float32. """ - - d = x.shape[-1] // 2 + x_float = x.float() + freqs_cis = freqs_cis.detach() + d = x_float.shape[-1] // 2 cos = freqs_cis[..., 0][None, :, None] sin = freqs_cis[..., 1][None, :, None] - x1, x2 = x[..., :d], x[..., d : d * 2] + x1, x2 = x_float[..., :d], x_float[..., d : d * 2] tmp = x1.clone() x1.mul_(cos).addcmul_(x2, sin, value=-1) x2.mul_(cos).addcmul_(tmp, sin, value=1) - return x + return x.copy_(x_float) diff --git a/aria/run.py b/aria/run.py index 71c319b..6304430 100644 --- a/aria/run.py +++ b/aria/run.py @@ -263,6 +263,7 @@ def generate(args): args.prompt_midi_path, prompt_duration_s=prompt_duration_s, ) + print(prompt) max_new_tokens = min(8096 - len(prompt), max_new_tokens) if backend == "torch_cuda": @@ -317,13 +318,13 @@ def generate(args): def _get_embedding( - embedding_model_checkpoints_path: str, + embedding_model_checkpoint_path: str, embedding_midi_path: str, ): from aria.embedding import get_global_embedding_from_midi model = _load_embedding_model( - checkpoint_path=embedding_model_checkpoints_path + checkpoint_path=embedding_model_checkpoint_path ).cpu() global_embedding = get_global_embedding_from_midi( model=model, @@ -353,7 +354,7 @@ def conditioned_generate(args): prompt_duration_s=prompt_duration_s, ) embedding = _get_embedding( - embedding_model_checkpoints_path=args.embedding_model_checkpoint_path, + embedding_model_checkpoint_path=args.embedding_model_checkpoint_path, embedding_midi_path=args.embedding_midi_path, ) max_new_tokens = min(8096 - len(prompt), max_new_tokens) diff --git a/aria/training/train.py b/aria/training/train.py index 67001a2..6521272 100644 --- a/aria/training/train.py +++ b/aria/training/train.py @@ -586,7 +586,7 @@ def resume_train( optimizer, scheduler = get_optim( model, num_epochs=epochs, - steps_per_epoch=len(train_dataloader), + steps_per_epoch=len(train_dataloader) // grad_acc_steps, ) ( @@ -731,7 +731,7 @@ def train( optimizer, scheduler = get_optim( model, num_epochs=epochs, - steps_per_epoch=len(train_dataloader), + steps_per_epoch=len(train_dataloader) // grad_acc_steps, ) ( diff --git a/config/models/medium.json b/config/models/medium.json index a1df8a6..40384d2 100644 --- a/config/models/medium.json +++ b/config/models/medium.json @@ -5,5 +5,6 @@ "ff_mult": 4, "drop_p": 0.0, "max_seq_len": 8192, + "vocab_size": 17727, "grad_checkpoint": true } diff --git a/demo/calibrate.py b/demo/calibrate.py index 74d0720..73a9120 100644 --- a/demo/calibrate.py +++ b/demo/calibrate.py @@ -6,7 +6,7 @@ import mido MIDDLE_C = 60 -C_MAJOR_CHORD = [MIDDLE_C, 64, 67, 72] # C4, E4, G4, C5 +C_MAJOR_CHORD = [MIDDLE_C - 12, 64 - 12, 67 - 12, 72 - 12] # C4, E4, G4, C5 def schedule_note_off(port: mido.ports.BaseOutput, note: int, delay: float): @@ -89,6 +89,41 @@ def note_repetition_trial( print("...loop finished.\n") +def velocity_strike_pair( + port: mido.ports.BaseOutput, + high_velocity: int, + low_velocity: int, + delay_ms: int, +): + """ + Sends a low-velocity C5, waits, then sends a high-velocity C4. + The goal is to adjust the delay until they sound simultaneous. + """ + print("Playing velocity pair (C4 high-vel, C5 low-vel)...") + delay_sec = delay_ms / 1000.0 + note_duration_sec = 1.0 # Audible duration of the notes + note_high_vel = MIDDLE_C # C4 + note_low_vel = MIDDLE_C + 1 # D4 + + # Send the low velocity note (C5) first + port.send(mido.Message("note_on", note=note_low_vel, velocity=low_velocity)) + schedule_note_off(port, note_low_vel, delay=delay_sec + note_duration_sec) + + # Wait for the specified delay + if delay_sec > 0: + time.sleep(delay_sec) + + # Send the high velocity note (C4) + port.send( + mido.Message("note_on", note=note_high_vel, velocity=high_velocity) + ) + schedule_note_off(port, note_high_vel, delay=note_duration_sec) + + # Give the user time to hear the result + time.sleep(note_duration_sec + 0.5) + print("...done.\n") + + def calibrate_output_latency( port_name: str, velocity: int, @@ -161,6 +196,52 @@ def calibrate_note_timing( print(f"\nAn error occurred: {e}") +def calibrate_velocity_latency( + port_name: str, + velocity: int, + low_velocity: int, + step_ms: int, + initial_delay_ms: int, + chord_mode: bool, +): + """ + Interactive loop to find the latency difference between velocities. + This mode uses fixed notes (C4 and C5) and ignores the --chord flag. + """ + delay_ms = initial_delay_ms + + try: + with mido.open_output(port_name) as port: + print(f"Opened MIDI output: {port_name}\n") + print( + f"High-velocity note (C4): {velocity}\n" + f"Low-velocity note (C5): {low_velocity}\n" + ) + while True: + velocity_strike_pair( + port, + high_velocity=velocity, + low_velocity=low_velocity, + delay_ms=delay_ms, + ) + print(f"Current low-velocity pre-send delay: {delay_ms} ms") + cmd = ( + input("[u]p / [d]own / [r]epeat / [q]uit: ").strip().lower() + ) + + if cmd == "u": + delay_ms += step_ms + elif cmd == "d": + delay_ms = max(0, delay_ms - step_ms) + elif cmd == "q": + break + print() + except (KeyboardInterrupt, SystemExit): + print("\nInterrupted — exiting.") + except Exception as e: + print(f"\nAn error occurred: {e}") + + def measure_input_latency(port_name: str, timeout_sec: float = 2.0): """ 3-2-1-GO countdown → you strike a key on GO. @@ -240,21 +321,21 @@ def parse_args(): "--velocity", "-v", type=int, - default=80, - help="Note-on velocity (1-127).", + default=120, + help="Note-on velocity (1-127). For 'velocity' mode, this is the HIGH velocity.", ) parent.add_argument( "--step", "-s", type=int, default=10, - help="Adjustment step in ms (latency/timing modes).", + help="Adjustment step in ms (for latency/timing/velocity modes).", ) parent.add_argument( "--chord", "-c", action="store_true", - help="Use a C-major chord instead of single note.", + help="Use a C-major chord instead of single note (ignored in 'velocity' mode).", ) sub = parser.add_subparsers(dest="command", help="Available commands.") @@ -296,7 +377,7 @@ def parse_args(): help="Initial gap between notes in ms.", ) - # ── input-latency measurement (new) ─────────────────────────────────── + # ── input-latency measurement ───────────────────────────────────────── p_in = sub.add_parser( "input", parents=[parent], @@ -311,6 +392,27 @@ def parse_args(): help="Seconds to wait for a key press before retry.", ) + # ── velocity-latency calibration (NEW) ──────────────────────────────── + p_vel = sub.add_parser( + "velocity", + parents=[parent], + help="Calibrate additional latency of low-velocity notes.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p_vel.add_argument( + "--low-velocity", + "-lv", + type=int, + default=20, + help="The low velocity to compare against (1-127).", + ) + p_vel.add_argument( + "--delay", + type=int, + default=50, + help="Initial pre-send delay for the low-velocity note in ms.", + ) + args = parser.parse_args() # global flag handler @@ -320,7 +422,7 @@ def parse_args(): if not args.command: parser.error( - "A command is required: choose 'output', 'timing', or 'input'." + "A command is required: choose 'output', 'timing', 'input', or 'velocity'." ) return args @@ -355,6 +457,16 @@ def main(): timeout_sec=args.timeout, ) + elif args.command == "velocity": + calibrate_velocity_latency( + port_name=args.port, + velocity=args.velocity, + low_velocity=args.low_velocity, + step_ms=args.step, + initial_delay_ms=args.delay, + chord_mode=args.chord, + ) + if __name__ == "__main__": main() diff --git a/demo/config.json b/demo/config.json new file mode 100644 index 0000000..c648810 --- /dev/null +++ b/demo/config.json @@ -0,0 +1,51 @@ +{ + "tokenizer": { + "abs": { + "ignore_instruments": { + "piano": false, + "chromatic": true, + "organ": false, + "guitar": false, + "bass": false, + "strings": false, + "ensemble": false, + "brass": false, + "reed": false, + "pipe": false, + "synth_lead": false, + "synth_pad": true, + "synth_effect": true, + "ethnic": true, + "percussive": true, + "sfx": true + }, + "instrument_programs": { + "piano": 0, + "chromatic": 13, + "organ": 16, + "guitar": 24, + "bass": 32, + "strings": 40, + "ensemble": 48, + "brass": 56, + "reed": 64, + "pipe": 73, + "synth_lead": 80, + "synth_pad": 88, + "synth_effect": 96, + "ethnic": 104, + "percussive": 112, + "sfx": 120 + }, + "drum_velocity": 60, + "velocity_quantization_step": 10, + "abs_time_step_ms": 5000, + "max_dur_ms": 5000, + "time_step_ms": 10, + "include_pedal": true, + "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], + "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], + "genre_names": ["jazz", "classical"] + } + } +} diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index e28a56a..9e08dbf 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -4,54 +4,50 @@ import os import time import uuid -import copy import random import logging import threading import queue -import copy +import math +import sys +import pathlib +import select +import json import mido -import torch import mlx.core as mx import mlx.nn as nn -import numpy as np from ariautils.midi import MidiDict, midi_to_dict from ariautils.tokenizer import AbsTokenizer from aria.inference.model_mlx import TransformerLM from aria.model import ModelConfig from aria.config import load_model_config - -DTYPE = mx.float32 -MAX_SEQ_LEN = 2048 -PREFILL_CHUNK_SIZE_L = 128 -PREFILL_CHUNK_SIZE = 16 -RECALC_DUR_PREFILL_CHUNK_SIZE = 8 -RECALC_DUR_BUFFER_MS = 100 - -BEAM_WIDTH = 3 -TIME_TOK_WEIGHTING = -5 -FIRST_ONSET_BUFFER_MS = -100 # Controls onset timing for first generated not - -# HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS -# HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg -# HARDWARE: All messages are sent HARDWARE_OUTPUT_LATENCY_MS early - -# C4DM Disklavier: -# MIN_NOTE_DELTA_MS = 40 -# MIN_NOTE_LEN_MS = 100 -# HARDWARE_INPUT_LATENCY_MS = 50 -# HARDWARE_OUTPUT_LATENCY_MS = 120 - -# Pianoteq -MIN_NOTE_DELTA_MS = 0 -MIN_NOTE_LEN_MS = 30 -HARDWARE_INPUT_LATENCY_MS = 0 -HARDWARE_OUTPUT_LATENCY_MS = 0 - -MAX_STREAM_DELAY_MS = 50 - +from aria.run import _get_embedding + +EMBEDDING_OFFSET: int = 0 +DTYPE = mx.bfloat16 +MAX_SEQ_LEN: int = 4096 +KV_CHUNK_SIZE: int = 256 +PREFILL_CHUNK_SIZE_L: int = 128 +PREFILL_CHUNK_SIZE: int = 16 +RECALC_DUR_PREFILL_CHUNK_SIZE: int = 8 +RECALC_DUR_BUFFER_MS: int = 100 + +BEAM_WIDTH: int = 3 +TIME_TOK_WEIGHTING: int = -5 +FIRST_ONSET_BUFFER_MS: int = -200 +MAX_STREAM_DELAY_MS: int = 100 + +MIN_NOTE_DELTA_MS: int = 0 +MIN_PEDAL_DELTA_MS: int = 0 +MIN_NOTE_LENGTH_MS: int = 10 +HARDWARE_INPUT_LATENCY_MS: int = 0 +BASE_OUTPUT_LATENCY_MS: int = 0 +VELOCITY_OUTPUT_LATENCY_MS: dict[int, int] = {v: 0 for v in range(0, 127, 10)} + + +config_path = pathlib.Path(__file__).parent.resolve().joinpath("config.json") file_handler = logging.FileHandler("./demo.log", mode="w") file_handler.setLevel(logging.DEBUG) @@ -87,6 +83,120 @@ def formatTime(self, record, datefmt=None): return logger +def parse_args(): + argp = argparse.ArgumentParser() + argp.add_argument("--checkpoint", help="path to model checkpoint") + argp.add_argument("--midi_in", required=False, help="MIDI input port") + argp.add_argument("--midi_out", required=True, help="MIDI output port") + argp.add_argument( + "--midi_through", + required=False, + help="MIDI through port for received input", + ) + argp.add_argument( + "--midi_path", + required=False, + help="Use MIDI file instead of MIDI input port", + ) + argp.add_argument( + "--midi_control_signal", + type=int, + help="MIDI control change message for AI takeover", + ) + argp.add_argument( + "--midi_reset_control_signal", + type=int, + help="MIDI control change message context reset", + ) + argp.add_argument( + "--back_and_forth", + action="store_true", + help="Enable toggling between human and AI. If not set, the control signal will reset the session.", + required=False, + ) + argp.add_argument( + "--temp", + help="sampling temperature value", + type=float, + required=False, + default=0.95, + ) + argp.add_argument( + "--min_p", + help="sampling min_p value", + type=float, + required=False, + default=0.03, + ) + argp.add_argument( + "--wait_for_close", + help="wait for note-offs before generating", + action="store_true", + ) + argp.add_argument( + "--quantize", + help="apply model quantize", + action="store_true", + ) + argp.add_argument( + "--save_path", + type=str, + required=False, + help="path to save complete MIDI file", + ) + argp.add_argument( + "--hardware", + type=str, + required=False, + help="path to json file containing hardware calibration settings", + ) + argp.add_argument( + "--embedding_checkpoint", + type=str, + help="path to embedding model checkpoint for conditioned generation", + required=False, + ) + argp.add_argument( + "--embedding_midi_path", + type=str, + help="path to embedding MIDI file for conditioned generation", + required=False, + ) + argp.add_argument( + "--playback", + action="store_true", + help="playback file at midi_path through output_port", + required=False, + ) + + return argp.parse_args() + + +def set_calibration_settings(load_path: str): + with open(load_path, "r") as f: + _settings = json.load(f) + + global MIN_NOTE_DELTA_MS + global MIN_PEDAL_DELTA_MS + global MIN_NOTE_LENGTH_MS + global HARDWARE_INPUT_LATENCY_MS + global BASE_OUTPUT_LATENCY_MS + global VELOCITY_OUTPUT_LATENCY_MS + + MIN_NOTE_DELTA_MS = _settings["MIN_NOTE_DELTA_MS"] + MIN_PEDAL_DELTA_MS = _settings["MIN_PEDAL_DELTA_MS"] + MIN_NOTE_LENGTH_MS = _settings["MIN_NOTE_LENGTH_MS"] + HARDWARE_INPUT_LATENCY_MS = _settings["HARDWARE_INPUT_LATENCY_MS"] + BASE_OUTPUT_LATENCY_MS = _settings["BASE_OUTPUT_LATENCY_MS"] + VELOCITY_OUTPUT_LATENCY_MS = { + int(k): v for k, v in _settings["VELOCITY_OUTPUT_LATENCY_MS"].items() + } + + +def _get_input_latency_ms(velocity: int): + return BASE_OUTPUT_LATENCY_MS + VELOCITY_OUTPUT_LATENCY_MS[velocity] + + def get_epoch_time_ms() -> int: return round(time.time() * 1000) @@ -95,14 +205,14 @@ def prefill( model: TransformerLM, idxs: mx.array, input_pos: mx.array, - pad_idxs: mx.array | None = None, ) -> mx.array: # pad_idxs is only needed for prepended pad tokens logits = model( idxs=idxs, - input_pos=input_pos, - offset=input_pos[0], - pad_idxs=pad_idxs, + input_pos=input_pos + EMBEDDING_OFFSET, + max_kv_pos=math.ceil(input_pos[-1].item() / KV_CHUNK_SIZE) + * KV_CHUNK_SIZE, + offset=input_pos[0] + EMBEDDING_OFFSET, ) return logits @@ -112,44 +222,37 @@ def decode_one( model: TransformerLM, idxs: mx.array, input_pos: mx.array, - pad_idxs: mx.array | None = None, ) -> mx.array: - # pad_idxs is only needed for prepended pad tokens assert input_pos.shape[-1] == 1 logits = model( idxs=idxs, - input_pos=input_pos, - offset=input_pos[0], - pad_idxs=pad_idxs, + input_pos=input_pos + EMBEDDING_OFFSET, + max_kv_pos=math.ceil(input_pos[-1].item() / KV_CHUNK_SIZE) + * KV_CHUNK_SIZE, + offset=input_pos[0] + EMBEDDING_OFFSET, )[:, -1] return logits -def sample_min_p(probs: mx.array, p_base: float): - """See - https://arxiv.org/pdf/2407.01082""" - - p_max = mx.max(probs, axis=-1, keepdims=True) - p_scaled = p_base * p_max - mask = probs >= p_scaled - - masked_probs = mx.where(~mask, mx.zeros_like(probs), probs) - sum_masked_probs = mx.sum(masked_probs, axis=-1, keepdims=True) - masked_probs_normalized = masked_probs / sum_masked_probs +def sample_min_p(logits: mx.array, p_base: float): + """Min_p sampler in logit space, see - https://arxiv.org/pdf/2407.01082""" + if p_base <= 0.0: + return mx.argmax(logits, axis=-1, keepdims=True) + if p_base >= 1.0: + return mx.random.categorical(logits, num_samples=1) - # Dumb workaround for mlx not having categorical probs sampler - next_token = mx.array( - torch.multinomial( - torch.from_numpy(np.array(masked_probs_normalized)), num_samples=1 - ), - dtype=mx.int32, - ) + log_p_max = mx.max(logits, axis=-1, keepdims=True) + log_p_scaled = mx.log(p_base) + log_p_max + mask = logits >= log_p_scaled + masked_logits = mx.where(~mask, -mx.inf, logits) + next_token = mx.random.categorical(masked_logits, num_samples=1) return next_token -def _compile_prefill( +def _warmup_prefill( model: TransformerLM, logger: logging.Logger, chunk_size: int, @@ -194,7 +297,7 @@ def _compile_prefill( return model -def _compile_decode_one( +def _warmup_decode_one( model: TransformerLM, logger: logging.Logger, ): @@ -231,7 +334,7 @@ def _compile_decode_one( return model -def compile_model(model: TransformerLM): +def warmup_model(model: TransformerLM): logger = get_logger() model.eval() @@ -241,39 +344,46 @@ def compile_model(model: TransformerLM): dtype=DTYPE, ) - model = _compile_decode_one(model=model, logger=logger) + model = _warmup_decode_one(model=model, logger=logger) for chunk_size in list( { - # PREFILL_CHUNK_SIZE_L, - # PREFILL_CHUNK_SIZE, + PREFILL_CHUNK_SIZE, RECALC_DUR_PREFILL_CHUNK_SIZE, } ): - model = _compile_prefill( + model = _warmup_prefill( model=model, logger=logger, chunk_size=chunk_size ) return model -def load_model( - checkpoint_path: str, -): +def load_model(checkpoint_path: str): logger = get_logger() - tokenizer = AbsTokenizer() + tokenizer = AbsTokenizer(config_path=config_path) model_config = ModelConfig(**load_model_config("medium-emb")) model_config.set_vocab_size(tokenizer.vocab_size) + weights = mx.load(checkpoint_path) + for key, weight in weights.items(): + if weight.dtype != DTYPE: + weights[key] = weight.astype(DTYPE) + logging.info(f"Loading model weights from {checkpoint_path}") init_start_time_s = time.time() model = TransformerLM(model_config) - model.load_weights(checkpoint_path, strict=False) + + assert ( + tokenizer.vocab_size == weights["model.tok_embeddings.weight"].shape[0] + ), "Embedding shape mismatch. Ensure that you are loading the demo-specific checkpoint." + + model.load_weights(list(weights.items()), strict=False) model.eval() if args.quantize: - nn.quantize(model.model, group_size=64, bits=8) + nn.quantize(model.model, group_size=32, bits=8) logger.info( f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds" @@ -379,6 +489,8 @@ def recalc_dur_tokens_chunked( next_logits = logits[:, priming_len - idx] + logger.debug(f"Internal KV-state: {tokenizer.decode(model.get_kv_ctx())}") + return enc_seq, priming_seq, next_logits @@ -393,11 +505,12 @@ def decode_first_tokens( ): logger = get_logger("GENERATE") - # buffer_ms determines how far in the past to start generating notes - buffer_ms = FIRST_ONSET_BUFFER_MS - HARDWARE_OUTPUT_LATENCY_MS + # buffer_ms determines how far in the past to start generating notes. + buffer_ms = FIRST_ONSET_BUFFER_MS time_tok_id = tokenizer.tok_to_id[tokenizer.time_tok] eos_tok_id = tokenizer.tok_to_id[tokenizer.eos_tok] dim_tok_id = tokenizer.tok_to_id[tokenizer.dim_tok] + ped_off_id = tokenizer.tok_to_id[tokenizer.ped_off_tok] logits = first_token_logits time_since_first_onset_ms = get_epoch_time_ms() - first_on_msg_epoch_ms @@ -425,6 +538,7 @@ def decode_first_tokens( logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") + logits[:, tokenizer.tok_to_id[tokenizer.ped_off_tok]] = float("-inf") # MLX doesn't have a equivalent of torch topk log_probs = nn.log_softmax(logits, axis=-1) @@ -444,11 +558,19 @@ def decode_first_tokens( logger.debug(f"Calculated top {BEAM_WIDTH} tokens={top_toks}") logger.debug(f"Calculated top {BEAM_WIDTH} scores={top_log_probs.tolist()}") - masked_onset_ids = [ - tokenizer.tok_to_id[tok] - for tok in tokenizer.onset_tokens - if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) - ] + priming_seq_last_onset_ms = tokenizer.calc_length_ms( + priming_seq, onset=True + ) + + if priming_seq_last_onset_ms < time_since_first_onset_ms + buffer_ms: + masked_onset_ids = [ + tokenizer.tok_to_id[tok] + for tok in tokenizer.onset_tokens + if tok[1] < ((time_since_first_onset_ms + buffer_ms) % 5000) + ] + + else: + masked_onset_ids = [] logger.debug( f"Masking onsets for {len(masked_onset_ids)} tokens ({time_since_first_onset_ms + buffer_ms})" @@ -470,9 +592,13 @@ def decode_first_tokens( ) next_log_probs = nn.log_softmax(next_logits, axis=-1) - next_log_probs[:, masked_onset_ids] = float("-inf") + next_log_probs[:, eos_tok_id] = float("-inf") next_log_probs[:, dim_tok_id] = float("-inf") + next_log_probs[:, ped_off_id] = float("-inf") + + if masked_onset_ids: + next_log_probs[:, masked_onset_ids] = float("-inf") if tok_id == time_tok_id: next_log_probs[:, time_tok_id] = float("-inf") @@ -513,9 +639,7 @@ def decode_first_tokens( logger.info( f"Updated KV-Cache by re-inserting {best_tok_1} at position {idx-1}" ) - logger.info( - f"Inserted {best_tok_2} at position {idx} without updating KV-Cache" - ) + logger.debug(f"Internal KV-state: {tokenizer.decode(model.get_kv_ctx())}") return enc_seq, idx + 1 @@ -539,6 +663,13 @@ def decode_tokens( if control_sentinel.is_set(): control_sentinel.clear() + last_tok_is_pedal = False + dur_ids = [tokenizer.tok_to_id[idx] for idx in tokenizer.dur_tokens] + dur_mask_ids = [ + tokenizer.tok_to_id[("dur", dur_ms)] + for dur_ms in range(0, MIN_NOTE_LENGTH_MS, 10) + ] + while (not control_sentinel.is_set()) and idx < MAX_SEQ_LEN: decode_one_start_time_s = time.time() prev_tok_id = enc_seq[0, idx - 1] @@ -554,16 +685,18 @@ def decode_tokens( f"Sampled logits for positions {idx} by inserting {prev_tok} at position {idx-1}" ) + logits[:, tokenizer.tok_to_id[tokenizer.ped_off_tok]] += 3 # Manual adj logits[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") + + logits[:, dur_mask_ids] = float("-inf") + if last_tok_is_pedal is True: + logits[:, dur_ids] = float("-inf") + if is_ending is False: logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") - for dur_ms in range(0, MIN_NOTE_LEN_MS, 10): - logits[:, tokenizer.tok_to_id[("dur", dur_ms)]] = float("-inf") - if temperature > 0.0: - probs = mx.softmax(logits / temperature, axis=-1) - next_token_ids = sample_min_p(probs, min_p).flatten() + next_token_ids = sample_min_p(logits, min_p).flatten() else: next_token_ids = mx.argmax(logits, axis=-1).flatten() @@ -573,6 +706,11 @@ def decode_tokens( f"({(time.time() - decode_one_start_time_s)*1000:.2f}ms) {idx}: {next_token}" ) + if next_token in {tokenizer.ped_on_tok, tokenizer.ped_off_tok}: + last_tok_is_pedal = True + elif isinstance(next_token, tuple) and next_token[0] == "piano": + last_tok_is_pedal = False + if next_token == tokenizer.eos_tok: logger.info("EOS token produced") generated_tokens_queue.put(next_token) @@ -603,7 +741,9 @@ def generate_tokens( generate_start_s = time.time() priming_seq_len = len(priming_seq) - start_idx = max(2, priming_seq_len - 4 * num_preceding_active_pitches - 1) + start_idx = max( + 2, priming_seq_len - 3 * (num_preceding_active_pitches + 2) - 1 + ) enc_seq = mx.array( [ tokenizer.encode( @@ -676,6 +816,130 @@ def generate_tokens( ) +def _adjust_previous_off_time( + pitch_to_prev_msg: dict, + key: str | int, + new_on_send_time: int, + min_delta_ms: int, + logger: logging.Logger, +): + prev_on, prev_off = pitch_to_prev_msg.get(key, (None, None)) + + if prev_on is not None and prev_off is not None and min_delta_ms > 0: + adj_send_off_time = max( + min( + prev_off["send_epoch_time_ms"], + new_on_send_time - min_delta_ms, + ), + prev_on[ + "send_epoch_time_ms" + ], # Don't move prev_off before prev_on + ) + if adj_send_off_time != prev_off["send_epoch_time_ms"]: + logger.debug(f"Adjusting {prev_off}: t={adj_send_off_time}") + prev_off["send_epoch_time_ms"] = adj_send_off_time + prev_off["adjusted"] = True + + +# TODO: Verify that only ON -> OFF sequences are possible in tokenizer +def _decode_pedal_double( + note_buffer: list, + first_on_msg_epoch_ms: int, + num_time_toks: int, + pitch_to_prev_msg: dict, + outbound_midi_msg_queue: queue.Queue, + logger: logging.Logger, + tokenizer: AbsTokenizer, +): + pedal_tok, onset_tok = note_buffer + velocity = 127 if pedal_tok == tokenizer.ped_on_tok else 0 + _, onset = onset_tok + + onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset + send_onset_epoch_ms = onset_epoch_ms - BASE_OUTPUT_LATENCY_MS + pedal_msg = { + "pitch": "pedal", + "vel": velocity, + "epoch_time_ms": onset_epoch_ms, + "send_epoch_time_ms": send_onset_epoch_ms, + "uuid": "pedal", # All pedals have the same id + } + + if pedal_tok == tokenizer.ped_on_tok: + _adjust_previous_off_time( + pitch_to_prev_msg=pitch_to_prev_msg, + key="pedal", + new_on_send_time=send_onset_epoch_ms, + min_delta_ms=MIN_PEDAL_DELTA_MS, + logger=logger, + ) + pitch_to_prev_msg["pedal"] = (pedal_msg, None) + + elif pedal_tok == tokenizer.ped_off_tok: + prev_on, _ = pitch_to_prev_msg.get("pedal", (None, None)) + pitch_to_prev_msg["pedal"] = (prev_on, pedal_msg) + + outbound_midi_msg_queue.put(pedal_msg) + logger.debug(f"Put message: {pedal_msg}") + logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") + + return onset_epoch_ms + + +def _decode_note_triple( + note_buffer: list, + first_on_msg_epoch_ms: int, + num_time_toks: int, + pitch_to_prev_msg: dict, + outbound_midi_msg_queue: queue.Queue, + logger: logging.Logger, +): + note_tok, onset_tok, dur_tok = note_buffer + _, pitch, vel = note_tok + _, onset = onset_tok + _, dur = dur_tok + + _uuid = uuid.uuid4() + onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset + offset_epoch_ms = onset_epoch_ms + dur + send_onset_epoch_ms = onset_epoch_ms - _get_input_latency_ms(vel) + send_offset_epoch_ms = offset_epoch_ms - BASE_OUTPUT_LATENCY_MS + + on_msg = { + "pitch": pitch, + "vel": vel, + "epoch_time_ms": onset_epoch_ms, + "send_epoch_time_ms": send_onset_epoch_ms, + "uuid": _uuid, + } + off_msg = { + "pitch": pitch, + "vel": 0, + "epoch_time_ms": offset_epoch_ms, + "send_epoch_time_ms": send_offset_epoch_ms, + "uuid": _uuid, + } + + _adjust_previous_off_time( + pitch_to_prev_msg=pitch_to_prev_msg, + key=pitch, + new_on_send_time=send_onset_epoch_ms, + min_delta_ms=MIN_NOTE_DELTA_MS, + logger=logger, + ) + + pitch_to_prev_msg[pitch] = (on_msg, off_msg) + + outbound_midi_msg_queue.put(on_msg) + outbound_midi_msg_queue.put(off_msg) + logger.debug(f"Put message: {on_msg}") + logger.debug(f"Put message: {off_msg}") + logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") + + return offset_epoch_ms + + +# TODO: Refactor this method to prettify it def decode_tokens_to_midi( generated_tokens_queue: queue.Queue, outbound_midi_msg_queue: queue.Queue, @@ -703,13 +967,15 @@ def decode_tokens_to_midi( while True: tok = generated_tokens_queue.get() if tok is tokenizer.eos_tok: + # pitch=-1 is interpreted as the end message by stream_midi _uuid = uuid.uuid4() end_msg = { "pitch": -1, "vel": -1, - "epoch_time_ms": offset_epoch_ms + 100, # Last note offset + "epoch_time_ms": offset_epoch_ms + 100, + "send_epoch_time_ms": offset_epoch_ms + 100, "uuid": _uuid, - } # pitch=-1 denotes end_msg + } outbound_midi_msg_queue.put(end_msg) logger.info(f"Seen exit signal: EOS token") logger.debug(f"Put message: {end_msg}") @@ -723,6 +989,15 @@ def decode_tokens_to_midi( note_buffer.append(tok) if isinstance(tok, tuple) and tok[0] == "dur": + msg_type = "note" + break + elif ( + isinstance(tok, tuple) + and tok[0] == "onset" + and note_buffer[-2] + in {tokenizer.ped_on_tok, tokenizer.ped_off_tok} + ): + msg_type = "pedal" break while note_buffer and note_buffer[0] == tokenizer.time_tok: @@ -730,53 +1005,57 @@ def decode_tokens_to_midi( num_time_toks += 1 note_buffer.pop(0) - assert len(note_buffer) == 3 + assert len(note_buffer) in {2, 3}, f"Generation error: buffer={note_buffer}" # fmt: skip + logger.debug(f"Decoded note: {note_buffer}") - note_tok, onset_tok, dur_tok = note_buffer - _, pitch, vel = note_tok - _, onset = onset_tok - _, dur = dur_tok - - _uuid = uuid.uuid4() - onset_epoch_ms = first_on_msg_epoch_ms + (num_time_toks * 5000) + onset - offset_epoch_ms = onset_epoch_ms + dur - on_msg = { - "pitch": pitch, - "vel": vel, - "epoch_time_ms": onset_epoch_ms, - "uuid": _uuid, - } - off_msg = { - "pitch": pitch, - "vel": 0, - "epoch_time_ms": offset_epoch_ms, - "uuid": _uuid, - } - # Not thread safe but in theory should be ok? - if pitch_to_prev_msg.get(pitch) is not None and MIN_NOTE_DELTA_MS > 0: - prev_on, prev_off = pitch_to_prev_msg.get(pitch) - adj_off_time = max( - min( - prev_off["epoch_time_ms"], - onset_epoch_ms - MIN_NOTE_DELTA_MS, - ), - prev_on["epoch_time_ms"], + if msg_type == "note": + offset_epoch_ms = _decode_note_triple( + note_buffer=note_buffer, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + num_time_toks=num_time_toks, + pitch_to_prev_msg=pitch_to_prev_msg, + outbound_midi_msg_queue=outbound_midi_msg_queue, + logger=logger, + ) + elif msg_type == "pedal": + offset_epoch_ms = _decode_pedal_double( + note_buffer=note_buffer, + first_on_msg_epoch_ms=first_on_msg_epoch_ms, + num_time_toks=num_time_toks, + pitch_to_prev_msg=pitch_to_prev_msg, + outbound_midi_msg_queue=outbound_midi_msg_queue, + logger=logger, + tokenizer=tokenizer, ) - if adj_off_time != prev_off["epoch_time_ms"]: - logger.debug(f"Adjusting {prev_off}: t={adj_off_time}") - prev_off["epoch_time_ms"] = adj_off_time - prev_off["adjusted"] = True + else: + raise ValueError - pitch_to_prev_msg[pitch] = [on_msg, off_msg] + note_buffer = [] - outbound_midi_msg_queue.put(on_msg) - outbound_midi_msg_queue.put(off_msg) - logger.debug(f"Put message: {on_msg}") - logger.debug(f"Put message: {off_msg}") - logger.debug(f"Ahead by {onset_epoch_ms - get_epoch_time_ms()}ms") - note_buffer = [] +def _create_mido_message( + msg_dict: dict, + channel: int, + time_delta_ms: int, +) -> mido.Message: + if msg_dict["pitch"] == "pedal": + return mido.Message( + "control_change", + control=64, + value=msg_dict["vel"], + channel=channel, + time=time_delta_ms, + ) + else: + # note-on or note-off + return mido.Message( + "note_on", + note=msg_dict["pitch"], + velocity=msg_dict["vel"], + channel=channel, + time=time_delta_ms, + ) def stream_midi( @@ -789,129 +1068,122 @@ def stream_midi( results_queue: queue.Queue, ): logger = get_logger("STREAM") - logger.info( - f"Sending generated messages on MIDI port: '{midi_output_port}'" - ) - logger.info( - f"Applying hardware output latency adjustment: {HARDWARE_OUTPUT_LATENCY_MS}ms" - ) + logger.info(f"Sending generated messages on port: '{midi_output_port}'") active_pitch_uuid = {} - is_pitch_active = {} - midi_msgs = [] + pending_msgs = [] + msgs_to_archive = [] with mido.open_output(midi_output_port) as midi_out: while not control_sentinel.is_set(): - while True: + while not inbound_midi_msg_queue.empty(): try: msg = inbound_midi_msg_queue.get_nowait() + if msg: + pending_msgs.append(msg) except queue.Empty: break - else: - logger.debug(f"Received message: {msg}") - midi_msgs.append(msg) - - midi_msgs = sorted( - midi_msgs, - key=lambda msg: ( - msg["epoch_time_ms"], - msg["vel"], - ), - ) - - if control_sentinel.is_set(): - break - while midi_msgs: - # Messages are sent HARDWARE_OUTPUT_LATENCY_MS early - latency_adjusted_epoch_time_ms = ( - get_epoch_time_ms() + HARDWARE_OUTPUT_LATENCY_MS - ) - msg = midi_msgs[0] + pending_msgs.sort(key=lambda m: (m["send_epoch_time_ms"], m["vel"])) - if ( - 0 - < latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] - <= MAX_STREAM_DELAY_MS - ): - if msg["pitch"] == -1: # End msg - control_sentinel.set() - break - - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=msg["vel"], - channel=0, - time=0, - ) - - if msg["vel"] > 0: - active_pitch_uuid[msg["pitch"]] = msg["uuid"] - should_send_midi_out = True - should_append_to_msgs = True - elif msg.get("adjusted", False) is True: - should_send_midi_out = True - should_append_to_msgs = False - else: - should_send_midi_out = ( - active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] - ) - should_append_to_msgs = should_send_midi_out - - if should_send_midi_out is True: - midi_out.send(mido_msg) - is_pitch_active[msg["pitch"]] = msg["vel"] != 0 - logger.info(f"Sent message: {mido_msg}") - if should_append_to_msgs is True: - mido_msg_with_time = copy.deepcopy(mido_msg) - mido_msg_with_time.channel = midi_stream_channel - mido_msg_with_time.time = max( - 0, - msg["epoch_time_ms"] - - last_channel_msg_epoch_time_ms, - ) - last_channel_msg_epoch_time_ms = msg["epoch_time_ms"] - msgs.append(mido_msg_with_time) - - midi_msgs.pop(0) + while pending_msgs: + curr_epoch_time_ms = get_epoch_time_ms() + msg = pending_msgs[0] + if msg["send_epoch_time_ms"] > curr_epoch_time_ms: + break elif ( - latency_adjusted_epoch_time_ms - msg["epoch_time_ms"] + curr_epoch_time_ms - msg["send_epoch_time_ms"] > MAX_STREAM_DELAY_MS ): - # Message occurs too far in the past - logger.debug( - f"Skipping message occurring too far ({latency_adjusted_epoch_time_ms - msg['epoch_time_ms']}ms) in the past: {msg}" - ) - midi_msgs.pop(0) - else: - # Message occurs in the future + logger.debug(f"Skipping stale message: {msg}") + pending_msgs.pop(0) + continue + + logger.debug(f"Processing: {msg}") + + # End signal + if msg["pitch"] == -1: + control_sentinel.set() break + should_send = False + should_archive = False + if msg["vel"] > 0: # note-on or pedal-on + active_pitch_uuid[msg["pitch"]] = msg["uuid"] + should_send = True + should_archive = True + else: # note-off or pedal-off (vel == 0) + if msg.get("adjusted", False): + should_send = True + should_archive = msg["pitch"] == "pedal" + elif active_pitch_uuid.get(msg["pitch"]) == msg["uuid"]: + should_send = True + should_archive = True + active_pitch_uuid.pop(msg["pitch"], None) + + if should_send: + mido_msg = _create_mido_message( + msg_dict=msg, channel=0, time_delta_ms=0 + ) + midi_out.send(mido_msg) + logger.info(f"Sent message: {mido_msg}") + + if should_archive: + msgs_to_archive.append(msg) + + pending_msgs.pop(0) + + if control_sentinel.is_set(): + break + time.sleep(0.005) - remaining_note_off_messages = [ + last_archive_time_ms = last_channel_msg_epoch_time_ms + msgs_to_archive.sort(key=lambda m: (m["epoch_time_ms"], m["vel"])) + + for msg in msgs_to_archive: + time_delta_ms = round(msg["epoch_time_ms"] - last_archive_time_ms) + mido_msg = _create_mido_message( + msg_dict=msg, + channel=midi_stream_channel, + time_delta_ms=time_delta_ms, + ) + msgs.append(mido_msg) + last_archive_time_ms = msg["epoch_time_ms"] + + logger.info("Sending final note-off messages for cleanup.") + remaining_off_msgs = [ msg - for msg in midi_msgs + for msg in pending_msgs if msg["vel"] == 0 + and msg["pitch"] != "pedal" and active_pitch_uuid.get(msg["pitch"]) == msg["uuid"] ] + remaining_off_msgs.sort(key=lambda m: (m["epoch_time_ms"])) - logger.info("Processing remaining note_off messages") - for msg in remaining_note_off_messages: - mido_msg = mido.Message( - "note_on", - note=msg["pitch"], - velocity=0, - channel=midi_stream_channel, - time=msg["epoch_time_ms"] - last_channel_msg_epoch_time_ms, + for msg in remaining_off_msgs: + mido_msg = _create_mido_message( + msg_dict=msg, channel=0, time_delta_ms=0 ) midi_out.send(mido_msg) - last_channel_msg_epoch_time_ms = msg["epoch_time_ms"] - msgs.append(mido_msg) - results_queue.put(msgs) + time_delta_ms = round(msg["epoch_time_ms"] - last_archive_time_ms) + archived_msg = _create_mido_message( + msg_dict=msg, + channel=midi_stream_channel, + time_delta_ms=time_delta_ms, + ) + msgs.append(archived_msg) + last_archive_time_ms = msg["epoch_time_ms"] + + midi_out.send( + mido.Message( + "control_change", control=64, value=0, channel=0, time=0 + ) + ) + + results_queue.put(msgs) def stream_msgs( @@ -928,11 +1200,19 @@ def stream_msgs( midi_stream_channel: int, is_ending: bool = False, ): + + logger = get_logger("STREAM") midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) + midi_dict.remove_redundant_pedals() priming_seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False) priming_seq = priming_seq[: priming_seq.index(tokenizer.eos_tok)] + if priming_seq[-2] == tokenizer.ped_off_tok: + # Final pedal-off is needed for tokenizer, but unneeded in tokenized sequence + logger.info("Removing final pedal_off from tokenized sequence") + priming_seq = priming_seq[:-2] + if is_ending is True: priming_seq.append(tokenizer.dim_tok) @@ -992,17 +1272,14 @@ def stream_msgs( "midi_stream_channel": midi_stream_channel, "results_queue": stream_midi_results_queue, }, - daemon=True, ) stream_midi_thread.start() generate_tokens_thread.join() decode_tokens_to_midi_thread.join() + stream_midi_thread.join() msgs = stream_midi_results_queue.get() - if is_ending is True: - stream_midi_thread.join() - return msgs @@ -1156,7 +1433,7 @@ def continuous_prefill( received_messages_queue: queue.Queue, prev_context: list[int], ): - tokenizer = AbsTokenizer() + tokenizer = AbsTokenizer(config_path=config_path) logger = get_logger("PREFILL") msg_cnt = 0 seen_sentinel = False @@ -1178,6 +1455,7 @@ def continuous_prefill( if msg_cnt >= 10: midi = convert_msgs_to_midi(msgs=msgs) midi_dict = MidiDict(**midi_to_dict(midi)) + midi_dict.remove_redundant_pedals() if len(midi_dict.note_msgs) > 0: curr_context = tokenizer.encode( @@ -1203,10 +1481,10 @@ def capture_and_update_kv( msgs: list, prev_context: list, control_sentinel: threading.Event, + reset_sentinel: threading.Event, wait_for_close: bool, - midi_input_port: str, + midi_performance_queue: queue.Queue, midi_capture_channel: int, - midi_control_signal: int | None = None, first_msg_epoch_time_ms: int | None = None, ): received_messages_queue = queue.Queue() @@ -1214,11 +1492,11 @@ def capture_and_update_kv( capture_midi_thread = threading.Thread( target=capture_midi_input, kwargs={ - "midi_input_port": midi_input_port, + "midi_performance_queue": midi_performance_queue, "control_sentinel": control_sentinel, + "reset_sentinel": reset_sentinel, "received_messages_queue": received_messages_queue, "midi_capture_channel": midi_capture_channel, - "midi_control_signal": midi_control_signal, "first_msg_epoch_time_ms": first_msg_epoch_time_ms, "results_queue": results_queue, "wait_for_close": wait_for_close, @@ -1239,139 +1517,129 @@ def capture_and_update_kv( def capture_midi_input( - midi_input_port: str, + midi_performance_queue: queue.Queue, control_sentinel: threading.Event, + reset_sentinel: threading.Event, received_messages_queue: queue.Queue, midi_capture_channel: int, results_queue: queue.Queue, - midi_control_signal: int | None = None, first_msg_epoch_time_ms: int | None = None, wait_for_close: bool = False, ): logger = get_logger("CAPTURE") - active_pitches = set() first_on_msg_epoch_ms = None prev_msg_epoch_time_ms = first_msg_epoch_time_ms + pedal_down = False + pitches_held_down = set() + pitches_sustained_by_pedal = set() + + while not midi_performance_queue.empty(): + try: + midi_performance_queue.get_nowait() + except queue.Empty: + break - logger.info(f"Listening on MIDI port: '{midi_input_port}'") - logger.info(f"Ready to capture MIDI events") - - # Clear undesired buffered notes - with mido.open_input(midi_input_port) as midi_input: - while True: - msg = midi_input.receive(block=False) - if msg is None: - break - - with mido.open_input(midi_input_port) as midi_input: - if midi_control_signal is not None: - logger.info( - f"Commencing generation upon keypress or MIDI control: {midi_control_signal}" - ) - else: - logger.info(f"Commencing generation upon keypress") + logger.info("Listening for input") + logger.info("Commencing generation upon keypress or control signal") - while not control_sentinel.is_set() or ( - wait_for_close and active_pitches + while True: + epoch_time_ms = get_epoch_time_ms() + active_notes = pitches_held_down.union(pitches_sustained_by_pedal) + should_stop = not wait_for_close or not active_notes + if reset_sentinel.is_set() or ( + control_sentinel.is_set() and should_stop ): - msg = midi_input.receive(block=False) + break - if msg is None: - time.sleep(0.001) - continue + try: + msg = midi_performance_queue.get(block=True, timeout=0.01) + except queue.Empty: + continue - if prev_msg_epoch_time_ms is None: - msg_time_ms = 0 - else: - msg_time_ms = get_epoch_time_ms() - prev_msg_epoch_time_ms + if msg.is_meta or msg.type == "program_change": + continue - prev_msg_epoch_time_ms = get_epoch_time_ms() - msg.time = msg_time_ms - msg.channel = midi_capture_channel - logger.info(f"Received message: [{msg}]") + msg.channel = midi_capture_channel + if prev_msg_epoch_time_ms is None: + msg.time = 0 + else: + msg.time = epoch_time_ms - prev_msg_epoch_time_ms - if msg.is_meta is True or msg.type == "program_change": - continue + prev_msg_epoch_time_ms = epoch_time_ms + logger.info(f"Received message: [{msg}]") - if ( - msg.type == "note_on" and msg.velocity == 0 - ) or msg.type == "note_off": - active_pitches.discard(msg.note) - received_messages_queue.put(msg) - elif msg.type == "note_on" and msg.velocity > 0: + match msg.type: + case "note_on" if msg.velocity > 0: if first_on_msg_epoch_ms is None: first_on_msg_epoch_ms = ( get_epoch_time_ms() - HARDWARE_INPUT_LATENCY_MS ) - - active_pitches.add(msg.note) + pitches_held_down.add(msg.note) + if pedal_down: + pitches_sustained_by_pedal.add(msg.note) received_messages_queue.put(msg) - elif msg.type == "control_change" and msg.control == 64: + + case "note_off" | "note_on": + # Note-off + pitches_held_down.discard(msg.note) received_messages_queue.put(msg) - elif ( - msg.type == "control_change" - and msg.control == midi_control_signal - and msg.value > 0 - ): - control_sentinel.set() - logger.info("Control signal seen") - - logger.info(f"Active pitches: {active_pitches}") - num_active_pitches = len(active_pitches) - - if active_pitches: - pitch = active_pitches.pop() - msg = mido.Message( - type="note_on", - note=pitch, - velocity=0, - channel=midi_capture_channel, - time=get_epoch_time_ms() - prev_msg_epoch_time_ms, - ) - received_messages_queue.put(msg) - - while active_pitches: - pitch = active_pitches.pop() - msg = mido.Message( - type="note_on", - note=pitch, - velocity=0, - channel=midi_capture_channel, - time=0, - ) + + case "control_change" if msg.control == 64: + if msg.value >= 64: + pedal_down = True + pitches_sustained_by_pedal.update(pitches_held_down) + else: + pedal_down = False + pitches_sustained_by_pedal.clear() received_messages_queue.put(msg) - # Turn off pedal - msg = mido.Message( - type="control_change", + active_pitches = pitches_held_down.union(pitches_sustained_by_pedal) + num_active_pitches = len(active_pitches) + logger.info(f"Active pitches ({num_active_pitches}): {active_pitches}") + + time_offset = get_epoch_time_ms() - prev_msg_epoch_time_ms + for pitch in pitches_held_down: + note_off_msg = mido.Message( + "note_off", + note=pitch, + channel=midi_capture_channel, + time=time_offset, + ) + received_messages_queue.put(note_off_msg) + time_offset = 0 + + received_messages_queue.put( + mido.Message( + "control_change", control=64, value=0, channel=midi_capture_channel, time=0, ) - received_messages_queue.put(msg) - received_messages_queue.put(None) - results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) + ) + + received_messages_queue.put(None) + results_queue.put((first_on_msg_epoch_ms, num_active_pitches)) def play_midi_file( midi_through_port: str, - midi_in_port: str, + midi_performance_queue: queue.Queue, midi_path: str, - currently_streaming_sentinel: threading.Event, + currently_generating_sentinel: threading.Event, + reset_sentinel: threading.Event, ): - def _send_delayed_message(port, msg): - port.send(msg) + def _send_delayed_message(_midi_performance_queue: queue.Queue, msg): + _midi_performance_queue.put(msg) logger.debug(f"SENT: {msg}") logger = get_logger("FILE") logger.info(f"Playing {midi_path} on through-port '{midi_through_port}'") - logger.info( - f"Simulating input to port '{midi_in_port}' with {HARDWARE_INPUT_LATENCY_MS}ms latency" - ) + logger.info(f"Simulating input with {HARDWARE_INPUT_LATENCY_MS}ms latency") - if MIN_NOTE_DELTA_MS > 0: + if BASE_OUTPUT_LATENCY_MS > 0: midi_dict = MidiDict.from_midi(midi_path) + midi_dict.remove_redundant_pedals() midi_dict.enforce_gaps(min_gap_ms=MIN_NOTE_DELTA_MS) mid = midi_dict.to_midi() else: @@ -1379,178 +1647,207 @@ def _send_delayed_message(port, msg): time.sleep(1) with mido.open_output(midi_through_port) as through_port: - with mido.open_output(midi_in_port) as in_port: - for msg in mid.play(): - if currently_streaming_sentinel.is_set() is False and not ( - msg.type == "control_change" and msg.control == 64 - ): - through_port.send(msg) + for msg in mid.play(): + if reset_sentinel.is_set(): + logger.debug("Exiting") + return - timer = threading.Timer( - interval=HARDWARE_INPUT_LATENCY_MS / 1000.0, - function=_send_delayed_message, - args=[in_port, msg], - ) - timer.start() + if currently_generating_sentinel.is_set() is False: + through_port.send(msg) + + timer = threading.Timer( + interval=HARDWARE_INPUT_LATENCY_MS / 1000.0, + function=_send_delayed_message, + args=[midi_performance_queue, msg], + ) + timer.start() def listen_for_keypress_control_signal( control_sentinel: threading.Event, - generate_ending_sentinel: threading.Event, + reset_sentinel: threading.Event, + currently_generating_sentinel: threading.Event, + back_and_forth: bool = False, ): logger = get_logger("KEYBOARD") - while True: - time.sleep(5) - _input = input() - logger.info(f'Keypress seen "{_input}"') - if _input == "": - control_sentinel.set() - else: - control_sentinel.set() - generate_ending_sentinel.set() - return + logger.info( + "Listening for keyboard input (Enter to start AI, any other key + Enter to reset)." + ) + + while not reset_sentinel.is_set(): + rlist, _, _ = select.select([sys.stdin], [], [], 0.01) + + if rlist: + _input = sys.stdin.readline().strip() + logger.info(f'Keypress seen "{_input}"') + + if _input == "": + if ( + currently_generating_sentinel.is_set() + and back_and_forth is False + ): + logger.info("Resetting (control)") + reset_sentinel.set() + control_sentinel.set() + else: + logger.info("Resetting (reset)") + reset_sentinel.set() + control_sentinel.set() + + logger.debug( + "Exiting keypress listener because reset_sentinel was set by another thread." + ) def _listen( - midi_input_port: str, + midi_control_queue: queue.Queue, + reset_sentinel: threading.Event, + currently_generating_sentinel: threading.Event, logger: logging.Logger, midi_control_signal: int | None = None, + midi_reset_control_signal: int | None = None, ): - logger.info("Listening...") - with mido.open_input(midi_input_port) as midi_input: - while True: - msg = midi_input.receive(block=False) - if msg is None: - time.sleep(0.01) - elif ( - msg.type == "control_change" - and msg.control == midi_control_signal - and msg.value >= 64 - ): - return + while not midi_control_queue.empty(): + try: + midi_control_queue.get_nowait() + except queue.Empty: + break + + logger.info( + f"Listening for takeover signal ({midi_control_signal}) and reset signal ({midi_reset_control_signal}) on control queue." + ) + seen_note_on = False + while not reset_sentinel.is_set(): + try: + msg = midi_control_queue.get(block=True, timeout=0.01) + except queue.Empty: + continue + + if msg.type == "note_on" and msg.velocity > 0: + seen_note_on = True + + should_return_signal = ( + seen_note_on or currently_generating_sentinel.is_set() + ) + if ( + msg.type == "control_change" + and msg.control == midi_control_signal + and msg.value >= 64 + and should_return_signal + ): + return midi_control_signal + elif ( + msg.type == "control_change" + and msg.control == midi_reset_control_signal + and msg.value >= 64 + and should_return_signal + ): + return midi_reset_control_signal def listen_for_midi_control_signal( - midi_input_port: str, + midi_control_queue: queue.Queue, control_sentinel: threading.Event, + reset_sentinel: threading.Event, + currently_generating_sentinel: threading.Event, midi_control_signal: int | None = None, + midi_reset_control_signal: int | None = None, + back_and_forth: bool = False, ): logger = get_logger("MIDI-CONTROL") - while True: - _listen( - midi_input_port=midi_input_port, + while not reset_sentinel.is_set(): + time.sleep(1) + signal_received = _listen( + midi_control_queue=midi_control_queue, + reset_sentinel=reset_sentinel, + currently_generating_sentinel=currently_generating_sentinel, midi_control_signal=midi_control_signal, + midi_reset_control_signal=midi_reset_control_signal, logger=logger, ) - control_sentinel.set() - logger.info("Seen MIDI control signal") - time.sleep(5) + if signal_received is not None: + logger.info(f"Seen MIDI control signal ({signal_received})") -def parse_args(): - argp = argparse.ArgumentParser() - argp.add_argument("--checkpoint", help="path to model checkpoint") - argp.add_argument("--midi_in", required=False, help="MIDI input port") - argp.add_argument("--midi_out", required=True, help="MIDI output port") - argp.add_argument( - "--midi_through", - required=False, - help="MIDI through port for received input", - ) - argp.add_argument( - "--midi_path", - required=False, - help="Use MIDI file instead of MIDI input port", - ) - argp.add_argument( - "--midi_control_signal", - type=int, - help="MIDI control change message for AI takeover", - ) - argp.add_argument( - "--temp", - help="sampling temperature value", - type=float, - required=False, - default=0.95, - ) - argp.add_argument( - "--min_p", - help="sampling min_p value", - type=float, - required=False, - default=0.03, - ) - argp.add_argument( - "--wait_for_close", - help="wait for note-offs before generating", - action="store_true", - ) - argp.add_argument( - "--quantize", - help="apply model quantize", - action="store_true", - ) - argp.add_argument( - "--save_path", - type=str, - required=False, - help="Path to save complete MIDI file", - ) + if signal_received == midi_reset_control_signal: + logger.info("Resetting (reset)") + reset_sentinel.set() + control_sentinel.set() + elif signal_received == midi_control_signal: + if ( + currently_generating_sentinel.is_set() + and back_and_forth is False + ): + logger.info("Resetting (control)") + reset_sentinel.set() + control_sentinel.set() - return argp.parse_args() + logger.debug("Exiting MIDI control listener") -def main(args): - args = parse_args() +# TODO: Debug, fix, and perhaps refactor the functionality for going back and forth +# - One idea is on resume, to wait to start the clock until the user plays. +def run( + model: TransformerLM, + midi_performance_queue: queue.Queue, + midi_control_queue: queue.Queue, + midi_through_port: str | None, + midi_out_port: str | None, + midi_path: str | None, + midi_save_path: str | None, + midi_control_signal: int, + midi_reset_control_signal: int, + reset_sentinel: threading.Event, + wait_for_close: bool, + temperature: float, + min_p: float, + back_and_forth: bool, +): logger = get_logger() - tokenizer = AbsTokenizer() - model = load_model(checkpoint_path=args.checkpoint) - model = compile_model(model=model) - - assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in - + tokenizer = AbsTokenizer(config_path=config_path) control_sentinel = threading.Event() - generate_ending_sentinel = threading.Event() currently_generating_sentinel = threading.Event() - if args.midi_through: - close_notes(args.midi_through) - if args.midi_out: - close_notes(args.midi_out) + if midi_through_port: + close_notes(midi_through_port) + if midi_out_port: + close_notes(midi_out_port) - if args.midi_path: - midi_input_port = "IAC Driver Bus 1" + if midi_path: play_file_thread = threading.Thread( target=play_midi_file, - args=( - args.midi_through, - midi_input_port, - args.midi_path, - currently_generating_sentinel, - ), - daemon=True, + kwargs={ + "midi_through_port": midi_through_port, + "midi_performance_queue": midi_performance_queue, + "midi_path": midi_path, + "currently_generating_sentinel": currently_generating_sentinel, + "reset_sentinel": reset_sentinel, + }, ) else: - midi_input_port = args.midi_in play_file_thread = None keypress_thread = threading.Thread( target=listen_for_keypress_control_signal, - args=[control_sentinel, generate_ending_sentinel], - daemon=True, + kwargs={ + "control_sentinel": control_sentinel, + "reset_sentinel": reset_sentinel, + "currently_generating_sentinel": currently_generating_sentinel, + "back_and_forth": back_and_forth, + }, ) midi_control_thread = threading.Thread( target=listen_for_midi_control_signal, kwargs={ - "midi_input_port": ( - args.midi_in if args.midi_in else midi_input_port - ), + "midi_control_queue": midi_control_queue, "control_sentinel": control_sentinel, - "midi_control_signal": args.midi_control_signal, + "reset_sentinel": reset_sentinel, + "currently_generating_sentinel": currently_generating_sentinel, + "midi_control_signal": midi_control_signal, + "midi_reset_control_signal": midi_reset_control_signal, + "back_and_forth": back_and_forth, }, - daemon=True, ) keypress_thread.start() midi_control_thread.start() @@ -1564,15 +1861,15 @@ def main(args): msgs=[], prev_context=[], control_sentinel=control_sentinel, - wait_for_close=args.wait_for_close, - midi_input_port=midi_input_port, - midi_control_signal=args.midi_control_signal, + reset_sentinel=reset_sentinel, + wait_for_close=wait_for_close, + midi_performance_queue=midi_performance_queue, midi_capture_channel=0, ) ) curr_midi_channel = 0 - while True: + while not reset_sentinel.is_set(): control_sentinel.clear() currently_generating_sentinel.set() msgs = stream_msgs( @@ -1580,23 +1877,28 @@ def main(args): tokenizer=tokenizer, msgs=msgs, prev_context=prev_context, - midi_output_port=args.midi_out, + midi_output_port=midi_out_port, first_on_msg_epoch_ms=first_on_msg_epoch_ms, control_sentinel=control_sentinel, - temperature=args.temp, - min_p=args.min_p, + temperature=temperature, + min_p=min_p, num_preceding_active_pitches=num_active_pitches, midi_stream_channel=curr_midi_channel, is_ending=False, ) + if midi_save_path: + logger.info(f"Saving result to {midi_save_path}") + midi = convert_msgs_to_midi(msgs=msgs) + midi.save(midi_save_path) + curr_midi_channel += 1 - if curr_midi_channel == 9: + if curr_midi_channel == 9: # Skip drum channel curr_midi_channel += 1 control_sentinel.clear() - if generate_ending_sentinel.is_set(): - break + if reset_sentinel.is_set(): + return else: currently_generating_sentinel.clear() msgs, prev_context, _, num_active_pitches = capture_and_update_kv( @@ -1604,33 +1906,178 @@ def main(args): msgs=msgs, prev_context=prev_context, control_sentinel=control_sentinel, - wait_for_close=args.wait_for_close, - midi_input_port=midi_input_port, - midi_control_signal=args.midi_control_signal, + reset_sentinel=reset_sentinel, + wait_for_close=wait_for_close, + midi_performance_queue=midi_performance_queue, midi_capture_channel=curr_midi_channel, first_msg_epoch_time_ms=first_on_msg_epoch_ms, ) - # Generate ending - msgs = stream_msgs( - model=model, - tokenizer=tokenizer, - msgs=msgs, - prev_context=prev_context, - midi_output_port=args.midi_out, - first_on_msg_epoch_ms=first_on_msg_epoch_ms, - control_sentinel=control_sentinel, - temperature=args.temp / 2, - min_p=args.min_p, - num_preceding_active_pitches=num_active_pitches, - midi_stream_channel=curr_midi_channel, - is_ending=True, + keypress_thread.join() + midi_control_thread.join() + if play_file_thread: + play_file_thread.join() + + +def insert_embedding( + model: TransformerLM, + embedding_model_checkpoint_path: str, + embedding_midi_path: str, +): + logger = get_logger() + logger.info(f"Loading embedding from {embedding_midi_path}") + emb = _get_embedding( + embedding_model_checkpoint_path=embedding_model_checkpoint_path, + embedding_midi_path=embedding_midi_path, ) + logger.info(f"Inserting embedding into context") + model.fill_condition_kv(mx.array([emb], dtype=DTYPE)) + + global EMBEDDING_OFFSET + EMBEDDING_OFFSET = 1 - if args.save_path: - logger.info(f"Saving result to {args.save_path}") - midi = convert_msgs_to_midi(msgs=msgs) - midi.save(args.save_path) + +def forward_midi_input_port( + midi_input_port: str, + midi_control_queue: queue.Queue, + midi_performance_queue: queue.Queue | None, +): + logger = get_logger("MIDI-FORWARD") + logger.info(f"Forwarding MIDI from port: '{midi_input_port}'") + + if midi_performance_queue is None: + logger.info( + f"MIDI file provided - only forwarding {midi_input_port} to control queue" + ) + + try: + with mido.open_input(midi_input_port) as midi_in: + while True: + msg = midi_in.receive(block=True) + if msg: + midi_control_queue.put(msg) + if midi_performance_queue is not None: + midi_performance_queue.put(msg) + + except (Exception, KeyboardInterrupt) as e: + logger.error(f"Error in MIDI forwarder: {e}") + finally: + logger.info("MIDI forwarder has shut down.") + + +def main(args): + logger = get_logger() + model = load_model(checkpoint_path=args.checkpoint) + model = warmup_model(model=model) + if args.embedding_checkpoint and args.embedding_midi_path: + insert_embedding( + model=model, + embedding_model_checkpoint_path=args.embedding_checkpoint, + embedding_midi_path=args.embedding_midi_path, + ) + + assert (args.midi_path and os.path.isfile(args.midi_path)) or args.midi_in + + logger.info(f"Available MIDI ports: {mido.get_output_names()}") + midi_performance_queue = queue.Queue() + midi_control_queue = queue.Queue() + + if args.midi_in: + forwarder_thread = threading.Thread( + target=forward_midi_input_port, + kwargs={ + "midi_input_port": args.midi_in, + "midi_control_queue": midi_control_queue, + "midi_performance_queue": ( + midi_performance_queue if args.midi_path is None else None + ), + }, + daemon=True, + ) + forwarder_thread.start() + + reset_sentinel = threading.Event() + while True: + run( + model=model, + midi_performance_queue=midi_performance_queue, + midi_control_queue=midi_control_queue, + midi_through_port=args.midi_through, + midi_out_port=args.midi_out, + midi_path=args.midi_path, + midi_save_path=args.save_path, + midi_control_signal=args.midi_control_signal, + midi_reset_control_signal=args.midi_reset_control_signal, + reset_sentinel=reset_sentinel, + wait_for_close=args.wait_for_close, + temperature=args.temp, + min_p=args.min_p, + back_and_forth=args.back_and_forth, + ) + reset_sentinel = threading.Event() + + +def playback(midi_path: str, midi_out: str, save_path: str | None = None): + # Mocks generated playback by streaming from a real MIDI file + + close_notes(midi_out) + starting_epoch_time_ms = get_epoch_time_ms() + tokenizer = AbsTokenizer(config_path=config_path) + tokens_queue = queue.Queue() + midi_messages_queue = queue.Queue() + stream_midi_results_queue = queue.Queue() + control_sentinel = threading.Event() + + midi_dict = MidiDict.from_midi(midi_path) + midi_dict.remove_redundant_pedals() + tokenized_sequence = tokenizer.tokenize( + midi_dict, + add_dim_tok=False, + remove_preceding_silence=False, + ) + tokenized_sequence = tokenized_sequence[ + tokenized_sequence.index(tokenizer.bos_tok) + 1 : + ] + + # Populate token queue synthetically + for tok in tokenized_sequence: + tokens_queue.put(tok) + + decode_tokens_to_midi_thread = threading.Thread( + target=decode_tokens_to_midi, + kwargs={ + "generated_tokens_queue": tokens_queue, + "outbound_midi_msg_queue": midi_messages_queue, + "tokenizer": tokenizer, + "first_on_msg_epoch_ms": starting_epoch_time_ms, + "priming_seq_last_onset_ms": 0, + }, + ) + decode_tokens_to_midi_thread.start() + + stream_midi_thread = threading.Thread( + target=stream_midi, + kwargs={ + "inbound_midi_msg_queue": midi_messages_queue, + "msgs": [], + "last_channel_msg_epoch_time_ms": starting_epoch_time_ms, + "midi_output_port": midi_out, + "control_sentinel": control_sentinel, + "midi_stream_channel": 0, + "results_queue": stream_midi_results_queue, + }, + ) + stream_midi_thread.start() + + decode_tokens_to_midi_thread.join() + stream_midi_thread.join() + msgs = stream_midi_results_queue.get() + mid = convert_msgs_to_midi(msgs) + + if save_path is not None: + mid.save(save_path) + + return msgs def close_notes(midi_out_port: str): @@ -1643,7 +2090,22 @@ def close_notes(midi_out_port: str): if __name__ == "__main__": args = parse_args() - try: - main(args) - except KeyboardInterrupt: - close_notes(args.midi_out) + if args.hardware: + set_calibration_settings(args.hardware) + + if args.playback is True: + # Playback only mode for testing + assert args.midi_path is not None, "Must provide midi_path" + try: + playback( + midi_path=args.midi_path, + midi_out=args.midi_out, + save_path=args.save_path, + ) + except KeyboardInterrupt: + close_notes(args.midi_out) + else: + try: + main(args) + except KeyboardInterrupt: + close_notes(args.midi_out) diff --git a/demo/hardware/c4dm-disklavier.json b/demo/hardware/c4dm-disklavier.json new file mode 100644 index 0000000..f7382fe --- /dev/null +++ b/demo/hardware/c4dm-disklavier.json @@ -0,0 +1,22 @@ +{ + "MIN_NOTE_DELTA_MS": 100, + "MIN_PEDAL_DELTA_MS": 100, + "MIN_NOTE_LENGTH_MS": 100, + "HARDWARE_INPUT_LATENCY_MS": 50, + "BASE_OUTPUT_LATENCY_MS": 50, + "VELOCITY_OUTPUT_LATENCY_MS": { + "120": 0, + "110": 0, + "100": 0, + "90": 4, + "80": 10, + "70": 30, + "60": 60, + "50": 85, + "40": 105, + "30": 130, + "20": 140, + "10": 155, + "0": 155 + } +}