Merged · Changes from all commits · 75 commits
5a1d899
demo
loubbrad Dec 31, 2024
cfae8ee
demo fix
loubbrad Jan 4, 2025
ea68e75
mess it all up agian
loubbrad Jan 4, 2025
d6a865b
demo finished
loubbrad Jan 5, 2025
977f54b
undo mistake
loubbrad Jan 6, 2025
877d6e0
update demo
loubbrad Jan 7, 2025
9a5c011
add prefill compile
loubbrad Jan 7, 2025
3dd44bf
add class finetuning
loubbrad Feb 20, 2025
7d99bad
add seq sep option to PretrainingDataset
loubbrad Feb 21, 2025
014f0b9
change from genre to composer
loubbrad Feb 21, 2025
6b53ba4
update emb eval scripts
loubbrad Feb 24, 2025
e6c1e2a
add explore script
loubbrad Feb 24, 2025
1a94cb4
add contrastive ft
loubbrad Feb 24, 2025
3dc6953
add missing changes
loubbrad Feb 24, 2025
32b832d
add loop
loubbrad Feb 24, 2025
68fe378
fix arg bug
loubbrad Feb 24, 2025
dae0b03
update eval
loubbrad Feb 27, 2025
06ef338
fix eval hang
loubbrad Feb 28, 2025
19f9795
add data aug
loubbrad Mar 5, 2025
2c086e1
fix data aug
loubbrad Mar 5, 2025
b524ab3
formalize eval
loubbrad Mar 8, 2025
87c82ac
eval scripts
loubbrad Mar 10, 2025
58d439f
fix range bug
loubbrad Mar 10, 2025
b238edc
add m3 only embeddings
loubbrad Mar 11, 2025
b485c04
update script for m3 embeddings
loubbrad Mar 11, 2025
bda70ac
update for pianist eval
loubbrad Mar 11, 2025
3fa2b29
add pianist8 dataset script
loubbrad Mar 11, 2025
4a7427e
adjust per file emb logic and update scripts
loubbrad Mar 12, 2025
c8cc7b8
update datasets/training/model scripts to support embedding conditioning
loubbrad Mar 14, 2025
626b5b4
add ft-dataset script
loubbrad Mar 14, 2025
93bcb22
change use embeddings train logic
loubbrad Mar 15, 2025
72160f7
fix model ft loading
loubbrad Mar 15, 2025
50b27d3
fix arg
loubbrad Mar 15, 2025
f9e15de
fix ddp model error
loubbrad Mar 15, 2025
43689b4
add pca
loubbrad Mar 18, 2025
46b9daf
keshav
loubbrad Mar 20, 2025
9aeafd2
keshav add args
loubbrad Mar 20, 2025
9ef4db9
fix keshav
loubbrad Mar 21, 2025
67d86bb
update sampling and demo
loubbrad May 22, 2025
654d9de
add looping and ending to demo
loubbrad May 23, 2025
77d27b5
push mlx imp for test
loubbrad May 25, 2025
445d484
fix sample script
loubbrad May 26, 2025
dc9fdcb
add continuous prefill and speculative duration calculation
loubbrad May 27, 2025
4f54e41
add off-msg streaming and fix timing alignment
loubbrad May 28, 2025
b16394c
fix early-off logic with dumb hack
loubbrad May 28, 2025
fc43f70
fix stream_midi logic
loubbrad May 29, 2025
ba835ff
port demo to mlx
loubbrad May 29, 2025
571c0a6
add script
loubbrad May 29, 2025
a0944a4
update mlx demo
loubbrad Jun 2, 2025
6e8aeab
partial tree refactor for release
loubbrad Jun 3, 2025
a37ba9c
add resid dropout to model
loubbrad Jun 3, 2025
91802df
import fix
loubbrad Jun 3, 2025
f689029
inference tree skeleton
loubbrad Jun 3, 2025
3491ae4
fix tree
loubbrad Jun 3, 2025
3614a8b
rm scripts
loubbrad Jun 3, 2025
97e2a5c
refactor entrypoint for generate
loubbrad Jun 3, 2025
1daac44
cfg conditioned generation refactored for torch_cuda
loubbrad Jun 3, 2025
479edc1
add mlx backend for conditioned generation
loubbrad Jun 4, 2025
9ba3a00
fix mlx backend for conditioned gen
loubbrad Jun 4, 2025
2252039
update cli flags to standard unix format
loubbrad Jun 4, 2025
5c0a435
migrate to pyproject.toml
loubbrad Jun 4, 2025
0878402
add toml
loubbrad Jun 4, 2025
f86435f
remove old plan
loubbrad Jun 4, 2025
d5f46b9
add README draft
loubbrad Jun 9, 2025
23343cc
update README
loubbrad Jun 9, 2025
e5b8777
rmv test_dataset
loubbrad Jun 10, 2025
38d0ff2
update README
loubbrad Jun 10, 2025
442cc7a
demo adjustments
loubbrad Jun 16, 2025
893495d
add input delay correction
loubbrad Jun 16, 2025
6832c02
update README
loubbrad Jun 19, 2025
4474f6a
Merge branch 'dev' of github.com:loubbrad/aria into dev
loubbrad Jun 19, 2025
abf8e21
add quantize option to demo
loubbrad Jun 30, 2025
a378720
delete
loubbrad Jun 30, 2025
a0fcd8b
Merge branch 'dev' of github.com:loubbrad/aria into dev
loubbrad Jun 30, 2025
6a237a8
merge upstream
loubbrad Jun 30, 2025
32 changes: 20 additions & 12 deletions demo/demo_mlx.py
@@ -32,25 +32,25 @@
 
 BEAM_WIDTH = 3
 TIME_TOK_WEIGHTING = -5
-FIRST_ONSET_BUFFER_MS = -100  # Controls onset timing for first generated note
+FIRST_ONSET_BUFFER_MS = -100  # Controls onset timing for first generated not
 
 # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS
 # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg
 # HARDWARE: All messages are sent HARDWARE_OUTPUT_LATENCY_MS early
 
 # C4DM Disklavier:
 # MIN_NOTE_DELTA_MS = 40
-# MIN_NOTE_LEN_MS = 50
+# MIN_NOTE_LEN_MS = 100
 # HARDWARE_INPUT_LATENCY_MS = 50
-# HARDWARE_OUTPUT_LATENCY_MS = 150
+# HARDWARE_OUTPUT_LATENCY_MS = 120
 
 # Pianoteq
 MIN_NOTE_DELTA_MS = 0
-MIN_NOTE_LEN_MS = 0
+MIN_NOTE_LEN_MS = 30
 HARDWARE_INPUT_LATENCY_MS = 0
 HARDWARE_OUTPUT_LATENCY_MS = 0
 
-MAX_STREAM_DELAY_MS = 25
+MAX_STREAM_DELAY_MS = 50
 
 file_handler = logging.FileHandler("./demo.log", mode="w")
 file_handler.setLevel(logging.DEBUG)
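
Note: the HARDWARE comments above describe the timing model these constants feed. A minimal sketch of that compensation, under assumed semantics (the demo's actual scheduling code lives elsewhere in the file, and this helper is hypothetical):

    def scheduled_send_ms(onset_ms: float, first_note: bool = False) -> float:
        # Assumed semantics: every message is sent HARDWARE_OUTPUT_LATENCY_MS
        # early so it sounds on time on the instrument; the negative
        # FIRST_ONSET_BUFFER_MS pulls the first generated onset earlier still.
        t = onset_ms - HARDWARE_OUTPUT_LATENCY_MS
        if first_note:
            t += FIRST_ONSET_BUFFER_MS
        return t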
@@ -158,14 +158,15 @@ def _compile_prefill(
 
     compile_start_time_s = time.time()
     logger.info(f"Compiling prefill (chunk_size={chunk_size})")
-    for _ in range(5):
+    for idx in range(8):
+        start = idx * (MAX_SEQ_LEN - chunk_size) // 7
         mx.eval(
             prefill(
                 model,
                 idxs=mx.ones([1, chunk_size], dtype=mx.int32),
                 input_pos=mx.arange(
-                    MAX_SEQ_LEN - (chunk_size + 1),
-                    MAX_SEQ_LEN - 1,
+                    start,
+                    start + chunk_size,
                     dtype=mx.int32,
                 ),
             )
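
Note: the warm-up now runs eight prefill calls whose input_pos windows are spread evenly across the context, instead of five identical calls pinned to the end of it, so the compiled kernels are exercised over the full position range. A sketch of the schedule, with hypothetical sizes:

    MAX_SEQ_LEN, chunk_size = 4096, 256  # hypothetical values
    starts = [idx * (MAX_SEQ_LEN - chunk_size) // 7 for idx in range(8)]
    # starts == [0, 548, 1097, 1645, 2194, 2742, 3291, 3840]
    # i.e. from position 0 up to MAX_SEQ_LEN - chunk_size, in eight even steps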
@@ -243,8 +244,8 @@ def compile_model(model: TransformerLM):
     model = _compile_decode_one(model=model, logger=logger)
     for chunk_size in list(
         {
-            PREFILL_CHUNK_SIZE_L,
-            PREFILL_CHUNK_SIZE,
+            # PREFILL_CHUNK_SIZE_L,
+            # PREFILL_CHUNK_SIZE,
             RECALC_DUR_PREFILL_CHUNK_SIZE,
         }
     ):
@@ -269,9 +270,11 @@ def load_model(
     init_start_time_s = time.time()
     model = TransformerLM(model_config)
     model.load_weights(checkpoint_path, strict=False)
-    nn.quantize(model.model, group_size=64, bits=8)
     model.eval()
 
+    if args.quantize:
+        nn.quantize(model.model, group_size=64, bits=8)
+
     logger.info(
         f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds"
     )
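
Note: nn.quantize rewrites an MLX module tree in place, swapping layers such as nn.Linear for quantized equivalents, so gating it on a flag keeps full precision as the default. A self-contained sketch on a toy module (not the demo's model):

    import mlx.nn as nn

    toy = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
    nn.quantize(toy, group_size=64, bits=8)  # Linear -> QuantizedLinear, in place
    print(type(toy.layers[0]).__name__)  # QuantizedLinear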
@@ -1172,7 +1175,7 @@ def continuous_prefill(
             msgs.append(msg)
             msg_cnt += 1
 
-            if msg_cnt >= 10 or seen_sentinel:
+            if msg_cnt >= 10:
                 midi = convert_msgs_to_midi(msgs=msgs)
                 midi_dict = MidiDict(**midi_to_dict(midi))
 
Expand Down Expand Up @@ -1484,6 +1487,11 @@ def parse_args():
help="wait for note-offs before generating",
action="store_true",
)
argp.add_argument(
"--quantize",
help="apply model quantize",
action="store_true",
)
argp.add_argument(
"--save_path",
type=str,
Expand Down
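
Note: the new flag is a standard store-true argument, so quantization stays off unless requested. A minimal wiring check:

    import argparse

    argp = argparse.ArgumentParser()
    argp.add_argument("--quantize", help="apply model quantize", action="store_true")
    assert argp.parse_args(["--quantize"]).quantize is True
    assert argp.parse_args([]).quantize is False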