35 changes: 24 additions & 11 deletions README.md
Expand Up @@ -79,27 +79,40 @@ Our embedding model was trained to capture composition-level and performance-lev

## Real-time demo

In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface.
In `demo/` we provide an MLX (Apple Silicon) implementation of the real-time interactive piano-continuation demo showcased in our release blog post. To use the demo, you must download the demo-specific model checkpoint, which extends the model to also control the sustain pedal ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model-demo.safetensors?download=true)).
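
If you prefer the command line, the checkpoint can be fetched directly from Hugging Face. The snippet below is only a sketch: the output filename and the use of `curl` are arbitrary choices, not repository conventions.

```bash
# Download the demo checkpoint (output path is an example).
curl -L -o model-demo.safetensors \
  "https://huggingface.co/loubb/aria-medium-base/resolve/main/model-demo.safetensors?download=true"
```

The downloaded file is what you pass as `<checkpoint-path>` in the commands below.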

❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, e.g., GPU FLOPS and memory bandwidth.
For our demonstration, we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface. We disabled the built-in Disklavier playback mode, instead manually calibrating key-velocity latency to enhance responsiveness. You may recreate this in your own environment with our acoustic calibration settings, using the following script:

A MIDI input device is not strictly required to play around with the demo: By using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI, ALSA) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output.
❗**NOTE**: It is vital that you use the `latency=off`/`realtime` Disklavier playback setting when using the provided configuration for `--hardware`.

Example usage (MLX):
```bash
python ./demo/demo_mlx.py \
--checkpoint <checkpoint-path> \
--midi_in <midi-in-port> \
--midi_out <midi-out-port> \
--hardware ./demo/hardware/c4dm-disklavier.json \
--midi_control_signal 67 \
--midi_reset_control_signal 66 \
--temp 0.9 \
--min_p 0.03
```

A MIDI input device is not strictly required to play around with the demo: by using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing back a MIDI file. All that is required is a set of MIDI drivers (e.g., CoreMIDI) and a virtual software instrument (e.g., FluidSynth, Pianoteq) to render the output. In this mode, you can initiate the model takeover by pressing the enter key.

```bash
MIDI_PATH="example-prompts/pokey_jazz.mid"
MIDI_PATH="./example-prompts/smooth_jazz.mid"

python demo/demo_mlx.py \
python ./demo/demo_mlx.py \
--checkpoint <checkpoint-path> \
--midi_path ${MIDI_PATH} \
--midi_through <port-to-stream-midi-file-through> \
--midi_out <port-to-stream-generation-over> \
--save_path <path-to-save-result> \
--temp 0.98 \
--min_p 0.035
--midi_through <midi-playback-port> \
--midi_out <midi-playback-port> \
--temp 0.9 \
--min_p 0.03
```
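
If no hardware instrument is attached, a virtual instrument can render the generated MIDI. The snippet below is a sketch for a typical macOS setup; the FluidSynth driver names and the soundfont path are assumptions, not project requirements.

```bash
# Run FluidSynth as a software synthesizer: CoreMIDI input port, CoreAudio output.
# The soundfont path is an example; any General MIDI soundfont will do.
fluidsynth -a coreaudio -m coremidi ~/soundfonts/FluidR3_GM.sf2
```

The virtual MIDI port that FluidSynth registers can then be used as `<midi-playback-port>`.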

❗**NOTE**: Responsiveness of the real-time demo is dependent on your system configuration, specifically GPU memory bandwidth.

## Evaluation

We provide the specific files/splits we used for Aria-MIDI derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/eval-splits.tar.gz?download=true)). Class labels are provided in `metadata.json` with the schema:
14 changes: 12 additions & 2 deletions aria/inference/__init__.py
Expand Up @@ -48,11 +48,21 @@ def get_inference_prompt(
for msg in midi_dict.note_msgs
if midi_dict.tick_to_ms(msg["data"]["start"]) <= prompt_len_ms
]
midi_dict.pedal_msgs = [
msg
for msg in midi_dict.pedal_msgs
if midi_dict.tick_to_ms(msg["tick"]) <= prompt_len_ms
]
if midi_dict.pedal_msgs and midi_dict.pedal_msgs[-1]["data"] == 1:
midi_dict.pedal_msgs.pop()

if len(midi_dict.note_msgs) == 0:
return [("prefix", "instrument", "piano"), tokenizer.bos_tok]

seq = tokenizer.tokenize(midi_dict=midi_dict, add_dim_tok=False)
seq.remove(tokenizer.eos_tok)
seq = tokenizer.tokenize(
midi_dict=midi_dict,
add_dim_tok=False,
add_eos_tok=False,
)

return seq
56 changes: 50 additions & 6 deletions aria/inference/model_mlx.py
Expand Up @@ -84,30 +84,41 @@ def __call__(
self,
x: mx.array,
input_pos: mx.array,
max_kv_pos: int | None,
offset: int,
mask: mx.array,
):
assert self.kv_cache is not None, "Cache not initialized"

x += self._att_block(
x=self.norm1(x),
input_pos=input_pos,
max_kv_pos=max_kv_pos,
offset=offset,
mask=mask,
)
x = x + self._ff_block(self.norm2(x))

return x

def get_kv(self, k: mx.array, v: mx.array, input_pos: mx.array):
def get_kv(
self,
k: mx.array,
v: mx.array,
input_pos: mx.array,
max_kv_pos: int | None,
):
k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos)

return k, v
if max_kv_pos is not None:
return k[:, :, : max_kv_pos + 1, :], v[:, :, : max_kv_pos + 1, :]
else:
return k, v

def _att_block(
self,
x: mx.array,
input_pos: mx.array,
max_kv_pos: int | None,
offset: int,
mask: mx.array,
):
Expand All @@ -124,7 +135,8 @@ def _att_block(
k = apply_rotary_emb_mlx(k, offset=offset)
q, k, v = map(lambda x: x.transpose(0, 2, 1, 3), (q, k, v))

k, v = self.get_kv(k, v, input_pos=input_pos)
k, v = self.get_kv(k, v, input_pos=input_pos, max_kv_pos=max_kv_pos)

wv = mx.fast.scaled_dot_product_attention(
q=q,
k=k,
Expand Down Expand Up @@ -159,6 +171,7 @@ def __init__(self, model_config: ModelConfig):
TransformerBlock(model_config) for _ in range(model_config.n_layers)
]
self.out_layer_norm = nn.LayerNorm(model_config.d_model)
self.kv_ctx = None

def fill_condition_kv(self, emb: mx.array):
assert self.causal_mask is not None, "Caches must be initialized first"
Expand All @@ -177,20 +190,30 @@ def __call__(
self,
idxs: mx.array,
input_pos: mx.array,
max_kv_pos: int,
offset: int,
pad_idxs: mx.array | None = None,
_debug_track_kv: bool = False,
):
assert self.causal_mask is not None, "Caches must be initialized first"

mask = self.causal_mask[None, None, input_pos]
if self.kv_ctx is None:
self.kv_ctx = mx.full(
self.model_config.max_seq_len, 3
) # unk_tok id

if _debug_track_kv is True:
self.kv_ctx[input_pos] = idxs
self.kv_ctx[input_pos[-1].item() + 1 :] = 3

mask = self.causal_mask[None, None, input_pos, : max_kv_pos + 1]
if pad_idxs is not None:
pad_mask = mx.expand_dims(mx.expand_dims(pad_idxs, axis=1), axis=1)
mask = mask & ~pad_mask

x = self.tok_embeddings(idxs)
for layer in self.encode_layers:
x = layer(x, input_pos, offset, mask)
x = layer(x, input_pos, max_kv_pos, offset, mask)

x = self.out_layer_norm(x)

Expand All @@ -217,11 +240,13 @@ def __call__(
idxs: mx.array,
input_pos: mx.array,
offset: int,
max_kv_pos: int | None = None,
pad_idxs: mx.array | None = None,
):
hidden_states = self.model(
idxs=idxs,
input_pos=input_pos,
max_kv_pos=max_kv_pos,
offset=offset,
pad_idxs=pad_idxs,
)
Expand All @@ -235,6 +260,25 @@ def fill_condition_kv(self, cond_emb: mx.array):
adapted_emb = self.embedding_adapter(cond_emb)
self.model.fill_condition_kv(emb=adapted_emb)

def reset_kv_ctx(self):
self.model.kv_ctx = None

def get_kv_ctx(self):
# Used for debugging kv-cache validation
_kv_ctx = self.model.kv_ctx

match self.model.kv_ctx:
case None:
return None
case mx.array():
_kv_ctx = self.model.kv_ctx.tolist()
if 3 in _kv_ctx:
return _kv_ctx[: _kv_ctx.index(3)]
else:
return _kv_ctx
case _:
raise ValueError

def setup_cache(
self,
batch_size,
9 changes: 0 additions & 9 deletions aria/inference/sample_cuda.py
Expand Up @@ -16,15 +16,6 @@
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def get_cfg_prompt(prompts: list):
cfg_prompts = []
for prompt in prompts:
cfg_prompts.append(prompt)
cfg_prompts.append(prompt)

return cfg_prompts


@torch.inference_mode()
def decode_one(
model: TransformerLM,
15 changes: 7 additions & 8 deletions aria/model.py
Expand Up @@ -178,7 +178,6 @@ def forward(
seq_len=self.model_config.max_seq_len,
n_elem=self.model_config.d_model // self.model_config.n_heads,
base=500000,
dtype=hidden_states.dtype,
).to(src.device)
freqs_cis = self.freqs_cis[: src.shape[1]]

Expand Down Expand Up @@ -379,7 +378,6 @@ def precompute_freqs_cis(
seq_len: int,
n_elem: int,
base: int = 500000,
dtype: torch.dtype = torch.bfloat16,
):
freqs = 1.0 / (
base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
Expand All @@ -389,22 +387,23 @@ def precompute_freqs_cis(
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)

return cache.to(dtype=dtype)
return cache


@torch.jit.script
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""
In-place RoPE. Credits to Katherine Crowson:
x shape (b_sz, s_len, n_head, d_head).
cos, sin shape (s_len, d_head // 2).
freqs_cis shape (s_len, d_head // 2, 2) and is float32.
"""

d = x.shape[-1] // 2
x_float = x.float()
freqs_cis = freqs_cis.detach()
d = x_float.shape[-1] // 2
cos = freqs_cis[..., 0][None, :, None]
sin = freqs_cis[..., 1][None, :, None]
x1, x2 = x[..., :d], x[..., d : d * 2]
x1, x2 = x_float[..., :d], x_float[..., d : d * 2]
tmp = x1.clone()
x1.mul_(cos).addcmul_(x2, sin, value=-1)
x2.mul_(cos).addcmul_(tmp, sin, value=1)
return x
return x.copy_(x_float)
7 changes: 4 additions & 3 deletions aria/run.py
Expand Up @@ -263,6 +263,7 @@ def generate(args):
args.prompt_midi_path,
prompt_duration_s=prompt_duration_s,
)
print(prompt)
max_new_tokens = min(8096 - len(prompt), max_new_tokens)

if backend == "torch_cuda":
Expand Down Expand Up @@ -317,13 +318,13 @@ def generate(args):


def _get_embedding(
embedding_model_checkpoints_path: str,
embedding_model_checkpoint_path: str,
embedding_midi_path: str,
):
from aria.embedding import get_global_embedding_from_midi

model = _load_embedding_model(
checkpoint_path=embedding_model_checkpoints_path
checkpoint_path=embedding_model_checkpoint_path
).cpu()
global_embedding = get_global_embedding_from_midi(
model=model,
Expand Down Expand Up @@ -353,7 +354,7 @@ def conditioned_generate(args):
prompt_duration_s=prompt_duration_s,
)
embedding = _get_embedding(
embedding_model_checkpoints_path=args.embedding_model_checkpoint_path,
embedding_model_checkpoint_path=args.embedding_model_checkpoint_path,
embedding_midi_path=args.embedding_midi_path,
)
max_new_tokens = min(8096 - len(prompt), max_new_tokens)
4 changes: 2 additions & 2 deletions aria/training/train.py
Expand Up @@ -586,7 +586,7 @@ def resume_train(
optimizer, scheduler = get_optim(
model,
num_epochs=epochs,
steps_per_epoch=len(train_dataloader),
steps_per_epoch=len(train_dataloader) // grad_acc_steps,
)

(
Expand Down Expand Up @@ -731,7 +731,7 @@ def train(
optimizer, scheduler = get_optim(
model,
num_epochs=epochs,
steps_per_epoch=len(train_dataloader),
steps_per_epoch=len(train_dataloader) // grad_acc_steps,
)

(
1 change: 1 addition & 0 deletions config/models/medium.json
Expand Up @@ -5,5 +5,6 @@
"ff_mult": 4,
"drop_p": 0.0,
"max_seq_len": 8192,
"vocab_size": 17727,
"grad_checkpoint": true
}