147 changes: 111 additions & 36 deletions backend/python/mlx/backend.py
@@ -14,11 +14,13 @@
import grpc
from mlx_lm import load, generate, stream_generate
from mlx_lm.sample_utils import make_sampler
-from mlx_lm.models.cache import make_prompt_cache
+from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
import mlx.core as mx
import base64
import io

from mlx_cache import ThreadSafeLRUPromptCache

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
@@ -118,10 +120,16 @@ async def LoadModel(self, request, context):
self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
else:
self.model, self.tokenizer = load(request.Model)

-# Initialize prompt cache for efficient generation
-max_kv_size = self.options.get("max_kv_size", None)
-self.prompt_cache = make_prompt_cache(self.model, max_kv_size)
+# Initialize thread-safe LRU prompt cache for efficient generation
+max_cache_entries = self.options.get("max_cache_entries", 10)
+self.max_kv_size = self.options.get("max_kv_size", None)
+self.model_key = request.Model
+self.lru_cache = ThreadSafeLRUPromptCache(
+max_size=max_cache_entries,
+can_trim_fn=can_trim_prompt_cache,
+trim_fn=trim_prompt_cache,
+)

except Exception as err:
print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
@@ -134,38 +142,57 @@ async def Predict(self, request, context):
"""
Generates text based on the given prompt and sampling parameters using MLX.

Uses a thread-safe LRU prompt cache for efficient prefix reuse across requests.

Args:
request: The predict request.
context: The gRPC context.

Returns:
backend_pb2.Reply: The predict result.
"""
prompt_cache = None
cache_key = None

try:
-# Prepare the prompt
-prompt = self._prepare_prompt(request)
+# Prepare the prompt and tokenize for cache key
+prompt_text = self._prepare_prompt(request)
+cache_key = self._get_tokens_from_prompt(prompt_text)

# Fetch the nearest cache (exact match or shorter prefix); a new cache is created below on a miss
prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
self.model_key, cache_key
)
if prompt_cache is None:
prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
remaining_tokens = cache_key

# Build generation parameters using request attributes and options
max_tokens, sampler_params = self._build_generation_params(request)
print(f"Generating text with MLX - max_tokens: {max_tokens}, sampler_params: {sampler_params}", file=sys.stderr)

print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)

# Create sampler with parameters
sampler = make_sampler(**sampler_params)

-# Generate text using MLX with proper parameters
-response = generate(
+# Use stream_generate to track generated tokens for the cache key
+generated_text = []
+for response in stream_generate(
self.model,
self.tokenizer,
-prompt=prompt,
+prompt=remaining_tokens if remaining_tokens else cache_key,
max_tokens=max_tokens,
sampler=sampler,
-prompt_cache=self.prompt_cache,
-verbose=False
-)
-
-return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
+prompt_cache=prompt_cache,
+):
+generated_text.append(response.text)
+cache_key.append(response.token)
+
+# Insert the completed cache so later requests can reuse this prefix
+self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
+
+return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))

except Exception as e:
print(f"Error in MLX Predict: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
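For intuition, a hypothetical walk-through of this prefix reuse (token IDs are invented for illustration):

# Suppose a previous request stored a cache under these prompt tokens:
cached_key = [101, 7592, 2088]
# A new request arrives whose prompt shares that prefix:
cache_key = [101, 7592, 2088, 102, 2054]
# fetch_nearest_cache returns the stored cache plus the unseen suffix,
# so only two tokens need prefilling instead of five:
remaining_tokens = cache_key[len(cached_key):]   # -> [102, 2054]
# After generation, the sampled token ids are appended to cache_key, so a
# follow-up turn in the same conversation matches an even longer prefix.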
@@ -194,42 +221,65 @@ async def PredictStream(self, request, context):
"""
Generates text based on the given prompt and sampling parameters, and streams the results using MLX.

Uses a thread-safe LRU prompt cache for efficient prefix reuse across requests.

Args:
request: The predict stream request.
context: The gRPC context.

Yields:
backend_pb2.Reply: Streaming predict results.
"""
prompt_cache = None
cache_key = None

try:
-# Prepare the prompt
-prompt = self._prepare_prompt(request)
+# Prepare the prompt and tokenize for cache key
+prompt_text = self._prepare_prompt(request)
+cache_key = self._get_tokens_from_prompt(prompt_text)

# Fetch the nearest cache (exact match or shorter prefix); a new cache is created below on a miss
prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
self.model_key, cache_key
)
if prompt_cache is None:
prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
remaining_tokens = cache_key

# Build generation parameters using request attributes and options
max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
print(f"Streaming text with MLX - max_tokens: {max_tokens}, sampler_params: {sampler_params}", file=sys.stderr)

print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)

# Create sampler with parameters
sampler = make_sampler(**sampler_params)

# Stream text generation using MLX with proper parameters
for response in stream_generate(
self.model,
self.tokenizer,
-prompt=prompt,
+prompt=remaining_tokens if remaining_tokens else cache_key,
max_tokens=max_tokens,
sampler=sampler,
-prompt_cache=self.prompt_cache,
+prompt_cache=prompt_cache,
):
cache_key.append(response.token)
yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))

except Exception as e:
print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"Streaming generation failed: {str(e)}")
yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))

finally:
# Always insert cache, even on interruption
if prompt_cache is not None and cache_key is not None:
try:
self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
except Exception as e:
print(f"Error inserting cache: {e}", file=sys.stderr)

def _prepare_prompt(self, request):
"""
Prepare the prompt for MLX generation, handling chat templates if needed.
@@ -246,16 +296,31 @@ def _prepare_prompt(self, request):
messages = []
for msg in request.Messages:
messages.append({"role": msg.role, "content": msg.content})

prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
return prompt
else:
return request.Prompt
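For illustration, a chat-style request flows through apply_chat_template roughly as below. The rendered string depends entirely on the loaded tokenizer's template; the output shown is a made-up ChatML-style example, not output from this backend.

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# A ChatML-style tokenizer might render something like:
# "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
#  <|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"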

def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
"""
Tokenize prompt text for cache key generation.

Args:
prompt_text: The prompt string to tokenize.

Returns:
List[int]: List of token IDs.
"""
tokens = self.tokenizer.encode(prompt_text)
if hasattr(tokens, 'tolist'):
return tokens.tolist()
return list(tokens)

@@ -284,11 +349,19 @@ def _build_generation_params(self, request, default_max_tokens=200):
top_p = getattr(request, 'TopP', 0.0)
if top_p == 0.0:
top_p = 1.0 # Default top_p

min_p = getattr(request, 'MinP', 0.0)
# min_p default of 0.0 means disabled (no filtering)

top_k = getattr(request, 'TopK', 0)
# top_k default of 0 means disabled (no filtering)

# Initialize sampler parameters
sampler_params = {
'temp': temp,
'top_p': top_p,
'min_p': min_p,
'top_k': top_k,
'xtc_threshold': 0.0,
'xtc_probability': 0.0,
}
Expand All @@ -308,7 +381,9 @@ def _build_generation_params(self, request, default_max_tokens=200):
sampler_option_mapping = {
'temp': 'temp',
'temperature': 'temp', # alias
'top_p': 'top_p',
'min_p': 'min_p',
'top_k': 'top_k',
'xtc_threshold': 'xtc_threshold',
'xtc_probability': 'xtc_probability',
}
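The assembled sampler_params dict is splatted into make_sampler. As a usage illustration (parameter values invented; make_sampler itself is the real mlx_lm.sample_utils helper imported at the top of this file):

from mlx_lm.sample_utils import make_sampler

sampler = make_sampler(
    temp=0.7,           # softmax temperature; 0.0 would be greedy decoding
    top_p=0.95,         # nucleus sampling: keep the smallest set with 95% mass
    min_p=0.05,         # drop tokens below 5% of the top token's probability
    top_k=40,           # keep only the 40 most likely tokens (0 disables)
    xtc_threshold=0.0,  # XTC sampling disabled
    xtc_probability=0.0,
)
# The returned callable picks a token id from the logits and is what
# Predict/PredictStream pass to stream_generate(..., sampler=sampler).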