diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py
index 5f34820..e28a56a 100644
--- a/demo/demo_mlx.py
+++ b/demo/demo_mlx.py
@@ -32,7 +32,7 @@
 BEAM_WIDTH = 3
 TIME_TOK_WEIGHTING = -5
-FIRST_ONSET_BUFFER_MS = -100  # Controls onset timing for first generated note
+FIRST_ONSET_BUFFER_MS = -100  # Controls onset timing for first generated note
 
 # HARDWARE: Decoded logits are masked for durations < MIN_NOTE_LEN_MS
 # HARDWARE: Sends early off-msg if pitch is on MIN_NOTE_DELTA_MS before on-msg
@@ -40,17 +40,17 @@
 # C4DM Disklavier:
 # MIN_NOTE_DELTA_MS = 40
-# MIN_NOTE_LEN_MS = 50
+# MIN_NOTE_LEN_MS = 100
 # HARDWARE_INPUT_LATENCY_MS = 50
-# HARDWARE_OUTPUT_LATENCY_MS = 150
+# HARDWARE_OUTPUT_LATENCY_MS = 120
 
 # Pianoteq
 MIN_NOTE_DELTA_MS = 0
-MIN_NOTE_LEN_MS = 0
+MIN_NOTE_LEN_MS = 30
 HARDWARE_INPUT_LATENCY_MS = 0
 HARDWARE_OUTPUT_LATENCY_MS = 0
 
-MAX_STREAM_DELAY_MS = 25
+MAX_STREAM_DELAY_MS = 50
 
 file_handler = logging.FileHandler("./demo.log", mode="w")
 file_handler.setLevel(logging.DEBUG)
@@ -158,14 +158,15 @@ def _compile_prefill(
     compile_start_time_s = time.time()
     logger.info(f"Compiling prefill (chunk_size={chunk_size})")
-    for _ in range(5):
+    for idx in range(8):
+        start = idx * (MAX_SEQ_LEN - chunk_size) // 7
         mx.eval(
             prefill(
                 model,
                 idxs=mx.ones([1, chunk_size], dtype=mx.int32),
                 input_pos=mx.arange(
-                    MAX_SEQ_LEN - (chunk_size + 1),
-                    MAX_SEQ_LEN - 1,
+                    start,
+                    start + chunk_size,
                     dtype=mx.int32,
                 ),
             )
         )
@@ -243,8 +244,8 @@ def compile_model(model: TransformerLM):
     model = _compile_decode_one(model=model, logger=logger)
     for chunk_size in list(
         {
-            PREFILL_CHUNK_SIZE_L,
-            PREFILL_CHUNK_SIZE,
+            # PREFILL_CHUNK_SIZE_L,
+            # PREFILL_CHUNK_SIZE,
             RECALC_DUR_PREFILL_CHUNK_SIZE,
         }
     ):
@@ -269,9 +270,11 @@ def load_model(
     init_start_time_s = time.time()
    model = TransformerLM(model_config)
     model.load_weights(checkpoint_path, strict=False)
-    nn.quantize(model.model, group_size=64, bits=8)
     model.eval()
 
+    if args.quantize:
+        nn.quantize(model.model, group_size=64, bits=8)
+
     logger.info(
         f"Finished initializing model - took {time.time() - init_start_time_s:.4f} seconds"
     )
@@ -1172,7 +1175,7 @@ def continuous_prefill(
             msgs.append(msg)
             msg_cnt += 1
 
-            if msg_cnt >= 10 or seen_sentinel:
+            if msg_cnt >= 10:
                midi = convert_msgs_to_midi(msgs=msgs)
                 midi_dict = MidiDict(**midi_to_dict(midi))
@@ -1484,6 +1487,11 @@ def parse_args():
         help="wait for note-offs before generating",
         action="store_true",
     )
+    argp.add_argument(
+        "--quantize",
+        help="apply model quantize",
+        action="store_true",
+    )
     argp.add_argument(
         "--save_path",
         type=str,