diff --git a/README.md b/README.md index 6a0e4b9..e2a6318 100644 --- a/README.md +++ b/README.md @@ -93,8 +93,8 @@ python ./demo/demo_mlx.py \ --hardware ./demo/hardware/c4dm-disklavier.json \ --midi_control_signal 67 \ --midi_reset_control_signal 66 \ - --temp 0.9 \ - --min_p 0.03 + --temp 0.85 \ + --min_p 0.05 ``` A MIDI input device is not strictly required to play around with the demo: By using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required are MIDI drivers (e.g., CoreMIDI) and a virtual software instrument (e.g., Fluidsynth, Pianoteq) to render the output. In this mode, you can initiate the model takeover by pressing the enter key. @@ -107,7 +107,7 @@ python ./demo/demo_mlx.py \ --midi_path ${MIDI_PATH} \ --midi_through \ --midi_out \ - --temp 0.9 \ + --temp 0.85 \ --min_p 0.05 ``` diff --git a/aria/inference/model_cuda.py b/aria/inference/model_cuda.py index 8dbfbd4..1819681 100644 --- a/aria/inference/model_cuda.py +++ b/aria/inference/model_cuda.py @@ -109,7 +109,6 @@ def _att_block( freqs_cis: torch.Tensor, mask: torch.Tensor, ): - q, k, v = self.mixed_qkv(x).split( [self.d_model, self.d_model, self.d_model], dim=-1 ) @@ -166,7 +165,7 @@ def fill_condition_kv(self, emb: torch.Tensor): assert self.model_config.emb_size is not None input_pos = torch.tensor([0], device=emb.device) - mask = self.causal_mask[None, None, input_pos] + mask = self.causal_mask[input_pos].unsqueeze(0).unsqueeze(0) freqs_cis = self.freqs_cis[input_pos] x = emb.unsqueeze(dim=1) @@ -182,7 +181,7 @@ def forward( ): assert self.freqs_cis is not None, "Caches must be initialized first" - mask = self.causal_mask[None, None, input_pos] + mask = self.causal_mask[input_pos].unsqueeze(0).unsqueeze(0) if pad_idxs is not None: mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1)) diff --git a/demo/demo_mlx.py b/demo/demo_mlx.py index 9e08dbf..37f4f2f 100644 --- a/demo/demo_mlx.py +++ b/demo/demo_mlx.py @@ -696,7 +696,7 @@ def decode_tokens( logits[:, tokenizer.tok_to_id[tokenizer.eos_tok]] = float("-inf") if temperature > 0.0: - next_token_ids = sample_min_p(logits, min_p).flatten() + next_token_ids = sample_min_p(logits / temperature, min_p).flatten() else: next_token_ids = mx.argmax(logits, axis=-1).flatten() diff --git a/example-prompts/smooth_jazz.mid b/example-prompts/smooth_jazz.mid index 73f31a9..36aaa3c 100644 Binary files a/example-prompts/smooth_jazz.mid and b/example-prompts/smooth_jazz.mid differ