
mlx_lm: Add Streaming Capability to Generate Function #807

Merged: 2 commits into ml-explore:main on Jun 3, 2024

Conversation

kurcontko (Contributor)

Add streaming feature to text generation function generate

  • Implemented an internal generator _generate to handle both streaming and non-streaming text generation modes.
  • Updated the generate function to return a generator when streaming is enabled, and the full generated text when streaming is disabled (a rough sketch follows).
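
Based on that description, the intended interface is roughly the following. This is a minimal, non-authoritative sketch of the idea; argument names and details may differ from the actual diff:

def generate(model, tokenizer, prompt, max_tokens=100, streaming=False, **kwargs):
    # _generate is the internal generator that yields detokenized text segments.
    gen = _generate(model, tokenizer, prompt, max_tokens=max_tokens, **kwargs)
    if streaming:
        return gen          # streaming=True: caller iterates over text segments
    return "".join(gen)     # streaming=False: return the full generated text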

kurcontko changed the title from "mlx_lm: Add Streaming Capability to Text Generation Function" to "mlx_lm: Add Streaming Capability to Generate Function" on May 31, 2024
awni (Member) commented Jun 1, 2024

Just curious about this. If you want a streaming generator you could use generate_step directly. It does almost everything generate does but without the logging and detokenization. Did you consider that? If not, why not?

Rather than having two generators, I think it might be better to figure out the right interface for the generator we already have (generate_step, which we could rename to something more useful) and then use it in the main generate.

dfl commented Jun 3, 2024

@awni can you recommend a way to support streaming without needing to patch mlx_lm? I started a PR that substitutes MLX for bitsandbytes, but it is not working with streaming: lllyasviel/Omost#54

kurcontko (Contributor, Author) commented Jun 3, 2024

Just curious about this. If you want a streaming generator you could use generate_step directly. It does almost everything generate does but without the logging and detokenization. Did you consider that? If not, why not?

Rather than having two generators, I think it might be better to figure out the right interface for the generator we already have (generate_step, which we could rename to something more useful) and then use it in the main generate.

Yes, I considered this. However, generate_step is currently not public in mlx_lm. You have to patch the __init__.py or use from mlx_lm.utils import generate_step.

The second issue, as you already mentioned, is that you then have to handle tokenization and detokenization yourself. You also need to import mlx.core as mx, and you lose the logging.

In my opinion, it’s a bit too much overhead on the end user’s side to enable streaming this way. It should be easily enabled like in the OpenAI API by a parameter passed to the method or just another separate function.

I know that my solution is not perfect and it might be refactored.

Here is a comparison, from the end user's side, between using generate_step directly and this PR's approach:

import mlx.core as mx
from mlx_lm.utils import generate_step

prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
detokenizer.reset()

# generate_step yields (token, prob) pairs indefinitely, so cap it at max_tokens.
for (token, _), _ in zip(
    generate_step(
        prompt_tokens,
        model,
        temp,
        repetition_penalty,
        repetition_context_size,
        top_p,
        logit_bias,
    ),
    range(max_tokens),
):
    if token == tokenizer.eos_token_id:
        break
    detokenizer.add_token(token)

    yield detokenizer.last_segment

vs.

response_generator = generate(
    model, 
    tokenizer, 
    prompt=prompt, 
    max_tokens=max_tokens, 
    temp=temp, 
    streaming=True
)

for token in response_generator:
    yield token

awni (Member) commented Jun 3, 2024

Yes this makes sense. However, the generate function is becoming unreadable and too complex (no fault of yours, it's just hard to pack so many conditions into one function). I am going to try refactoring into two separate functions both of which use generate_step under the hood to see if that is workable. A little code duplication may be worth it if it makes the functions easier to follow.
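
For context, here is a rough, non-authoritative sketch of what a stream_generate built on generate_step could look like, assuming generate_step yields (token, prob) pairs as in the snippet above; the actual merged implementation may differ in its details:

import mlx.core as mx
from mlx_lm.utils import generate_step

def stream_generate(model, tokenizer, prompt, max_tokens=100, **kwargs):
    # Encode the prompt and reset the incremental detokenizer.
    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer
    detokenizer.reset()

    # generate_step yields indefinitely, so bound the loop with max_tokens.
    for (token, _), _ in zip(generate_step(prompt_tokens, model, **kwargs), range(max_tokens)):
        if token == tokenizer.eos_token_id:
            break
        detokenizer.add_token(token)
        yield detokenizer.last_segment  # emit only the newly decoded text

    # Flush any characters still buffered in the detokenizer.
    detokenizer.finalize()
    yield detokenizer.last_segment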

awni (Member) commented Jun 3, 2024

I modified the code to have two functions:

  • generate (as it was)
  • stream_generate: a simplified version which generates a stream of output text.

I also updated the documentation. So this works now:

from mlx_lm import load, stream_generate

repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
model, tokenizer = load(repo)

prompt = "Write a story about Einstein"

for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
    print(t, end="", flush=True)
print()

I think there is some opportunity to refactor generate to use stream_generate but I will hold off on that until we understand the usage patterns a bit better.
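
As a non-authoritative sketch of that possible refactor (it omits the timing and logging that the real generate does), generate could simply accumulate the segments yielded by stream_generate:

def generate(model, tokenizer, prompt, max_tokens=100, verbose=False, **kwargs):
    # Hypothetical refactor: build the full response from the streamed segments.
    text = ""
    for segment in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens, **kwargs):
        if verbose:
            print(segment, end="", flush=True)  # live output, mirroring verbose mode
        text += segment
    if verbose:
        print()
    return text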

awni (Member) commented Jun 3, 2024

@dfl the stream_generate function should work for you. If it doesn't please let us know what's missing.

awni (Member) left a comment

Thanks for the addition!

awni merged commit 43d6deb into ml-explore:main on Jun 3, 2024
4 checks passed
dfl commented Jun 3, 2024

@awni I am trying to match the transformers API, which uses TextIteratorStreamer (along with AutoModelForCausalLM and AutoTokenizer). I guess the APIs are too different and it's not going to work so easily. AutoModelForCausalLM.generate and mlx_lm's generate do not have matching arguments, so I need some wrapper function... and also a lambda to run it in a Thread for Gradio -- that seems beyond my current Python abilities 🫤
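
In case it helps, a hedged sketch of one way to get TextIteratorStreamer-like behavior from mlx_lm: since stream_generate is itself a generator, no background Thread or streamer object is needed. The function name stream_chat and the accumulate-and-yield pattern are illustrative assumptions, not part of either API, and this assumes mlx-lm >= 0.14.3 so that stream_generate is importable:

from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

def stream_chat(prompt, max_tokens=512):
    # Yield the accumulated response so far, which is what a Gradio
    # generator callback typically expects on each step.
    text = ""
    for segment in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
        text += segment
        yield text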

dfl commented Jun 3, 2024

Sorry, I don't understand... how do I get these latest features in my pip install?
I ran pip install mlx-lm --upgrade, which gave 0.14.2, but I still get this error:
ImportError: cannot import name 'stream_generate' from 'mlx_lm' (/opt/homebrew/Caskroom/miniconda/base/envs/omost/lib/python3.10/site-packages/mlx_lm/__init__.py)

awni (Member) commented Jun 3, 2024

@dfl I just did a patch release, so you can get 0.14.3, which should have stream_generate.
