Add a transformers client in contrib.#89
Conversation
| def _ensure_inference_task_started(self) -> None: | ||
| """Lazy-start the background inference task if not already running.""" | ||
| if self._inference_task is None or self._inference_task.done(): | ||
| # Create weakref to self so task can detect when client is GC'd | ||
| weak_self = weakref.ref(self) | ||
| self._inference_task = asyncio.create_task(self._inference_loop(weak_self)) | ||
| _LOGGER.info("TransformersLLMClient started background inference task") |
There was a problem hiding this comment.
Per CLAUDE.md async logging guidance, TransformersLLMClient started background inference task never logs id(self) or another client identifier, so if multiple clients run concurrently you cannot trace which instance started its background task; can we include the client id (e.g., id(self)) in this log?
| def _ensure_inference_task_started(self) -> None: | |
| """Lazy-start the background inference task if not already running.""" | |
| if self._inference_task is None or self._inference_task.done(): | |
| # Create weakref to self so task can detect when client is GC'd | |
| weak_self = weakref.ref(self) | |
| self._inference_task = asyncio.create_task(self._inference_loop(weak_self)) | |
| _LOGGER.info("TransformersLLMClient started background inference task") | |
| def _ensure_inference_task_started(self) -> None: | |
| """Lazy-start the background inference task if not already running.""" | |
| if self._inference_task is None or self._inference_task.done(): | |
| # Create weakref to self so task can detect when client is GC'd | |
| weak_self = weakref.ref(self) | |
| self._inference_task = asyncio.create_task(self._inference_loop(weak_self)) | |
| _LOGGER.info("TransformersLLMClient(id=%s) started background inference task", id(self)) |
Finding type: AI Coding Guidelines
There was a problem hiding this comment.
Not necessary, IMO
There was a problem hiding this comment.
The log currently just says TransformersLLMClient started background inference task, so when multiple clients are running there’s no way to tell which instance triggered it; including id(self) (or similar) gives us a lightweight identifier for debugging without changing functionality.
| # Process batch | ||
| _LOGGER.debug("Processing batch of %d request(s)", len(batch)) | ||
| await self._process_batch(batch) | ||
|
|
||
| except Exception as e: | ||
| _LOGGER.exception("Error in inference loop: %s", e) |
There was a problem hiding this comment.
Per CLAUDE.md async logging guidance, the Error in inference loop exception log never includes id(self) (or similar client identifier), so failed batches cannot be traced back to a particular client instance; can we include the client id in this log?
| # Process batch | |
| _LOGGER.debug("Processing batch of %d request(s)", len(batch)) | |
| await self._process_batch(batch) | |
| except Exception as e: | |
| _LOGGER.exception("Error in inference loop: %s", e) | |
| # Process batch | |
| _LOGGER.debug("Processing batch of %d request(s)", len(batch)) | |
| await self._process_batch(batch) | |
| except Exception as e: | |
| _LOGGER.exception("Error in inference loop for client %s: %s", hex(id(self)), e) |
Finding type: AI Coding Guidelines
There was a problem hiding this comment.
Not necessary IMO
There was a problem hiding this comment.
I hear you, but tracing the background task failures to a specific client instance is hard without an identifier—the exception already includes the client’s ID now, so we can correlate Error in inference loop for client … messages with the right request queue.
|
|
||
| except Exception as e: | ||
| _LOGGER.exception("Error processing batch: %s", e) |
There was a problem hiding this comment.
Per CLAUDE.md async logging guidance, the Error processing batch exception log never includes id(self) (or similar client identifier), so downstream errors cannot be tied back to a particular client instance when multiple are running; can we include the client id here as well?
| except Exception as e: | |
| _LOGGER.exception("Error processing batch: %s", e) | |
| except Exception as e: | |
| _LOGGER.exception("Error processing batch in client id=%s: %s", id(self), e) |
Finding type: AI Coding Guidelines
| # Create minimal tokenizer | ||
| minimal_tokenizer = transformers.GPT2Tokenizer.from_pretrained( | ||
| "gpt2", # Uses cached tokenizer vocab | ||
| model_max_length=32, | ||
| ) | ||
| minimal_tokenizer.pad_token = minimal_tokenizer.eos_token | ||
| minimal_tokenizer.padding_side = "left" | ||
| # Add minimal chat template | ||
| minimal_tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{% endfor %}" |
There was a problem hiding this comment.
[Testing] The “no-download” integration test still calls transformers.GPT2Tokenizer.from_pretrained("gpt2") (lines 422‑430), which hits the HuggingFace hub whenever the vocab isn’t already cached on the test machine. That reintroduces a network dependency into the unit test suite and will make CI fail in air‑gapped environments. Please construct the tokenizer fully in-code (e.g., build a tiny tokenizers.Tokenizer/PreTrainedTokenizerFast with a hard-coded vocab, or write minimal vocab/merges files to a temporary directory and load from that path) so the test is deterministic and offline.
Context for Agents
The “no-download” integration test still calls `transformers.GPT2Tokenizer.from_pretrained("gpt2")` (lines 422‑430), which hits the HuggingFace hub whenever the vocab isn’t already cached on the test machine. That reintroduces a network dependency into the unit test suite and will make CI fail in air‑gapped environments. Please construct the tokenizer fully in-code (e.g., build a tiny `tokenizers.Tokenizer`/`PreTrainedTokenizerFast` with a hard-coded vocab, or write minimal vocab/merges files to a temporary directory and load from that path) so the test is deterministic and offline.
File: src/ares/contrib/transformers_client_test.py
Line: 430| @functools.cached_property | ||
| def _inference_task(self) -> asyncio.Task[None]: | ||
| """Lazy-initialized background inference task. | ||
|
|
||
| The task automatically exits when the client is garbage collected via weakref. | ||
| """ | ||
| weak_self = weakref.ref(self) | ||
| task = asyncio.create_task(self._inference_loop(weak_self)) | ||
| _LOGGER.info("TransformersLLMClient started background inference task") | ||
| return task | ||
|
|
||
| @functools.cached_property | ||
| def _device(self) -> str: | ||
| """Resolved device.""" |
There was a problem hiding this comment.
TransformersLLMClient logs TransformersLLMClient started background inference task without id(self) even though this async background task can run concurrently for multiple clients, so tracing is ambiguous; CLAUDE.md requires async logging to include object identifiers—can we log the client id in these statements?
Finding type: AI Coding Guidelines
Prompt for AI Agents:
In src/ares/contrib/transformers_client.py around lines 154 to 167, the _inference_task
cached_property logs "TransformersLLMClient started background inference task" without
identifying which client started the task. Update the logging call in that method to
include a unique identifier for the client (for example id(self)) and optionally the
model_name so concurrent clients are distinguishable. Make the log message something
like: "TransformersLLMClient started background inference task (id=%s, model=%s)" and
pass id(self) and self.model_name as parameters to the logger.
| @pytest.mark.asyncio | ||
| async def test_integration_with_minimal_model(): | ||
| """Integration test with a minimal GPT2 model. | ||
|
|
||
| Creates a tiny GPT2 model from scratch for testing the full pipeline. | ||
| Note: Downloads GPT2 tokenizer vocab on first run (~500KB, cached after). | ||
| """ | ||
| # Create minimal GPT2 config - vocab_size must match GPT2Tokenizer (50257) | ||
| config = transformers.GPT2Config( | ||
| vocab_size=50257, # Must match GPT2Tokenizer vocab | ||
| n_positions=32, | ||
| n_ctx=32, | ||
| n_embd=32, | ||
| n_layer=2, | ||
| n_head=4, | ||
| ) | ||
|
|
||
| minimal_model = transformers.GPT2LMHeadModel(config) | ||
| minimal_model.eval() | ||
|
|
||
| # GPT2 tokenizer is lightweight and cached after first download | ||
| minimal_tokenizer = transformers.GPT2Tokenizer.from_pretrained( | ||
| "gpt2", | ||
| model_max_length=32, | ||
| ) |
There was a problem hiding this comment.
test_integration_with_minimal_model calls transformers.GPT2Tokenizer.from_pretrained("gpt2"), which performs a real network download inside src/ares/contrib/transformers_client_test.py; CLAUDE.md mandates that unit tests under src/ must mock external services and real integration tests belong under integration_tests/, so this test should either mock the tokenizer download or move to the integration test suite (CLAUDE.md).
Finding type: AI Coding Guidelines
Prompt for AI Agents:
In src/ares/contrib/transformers_client_test.py around lines 286 to 310, the
test_integration_with_minimal_model function calls
transformers.GPT2Tokenizer.from_pretrained("gpt2") which performs a real network
download (not allowed for unit tests under src/). Refactor by mocking that call: replace
the direct from_pretrained invocation with a context manager that patches
transformers.GPT2Tokenizer.from_pretrained to return the locally constructed
minimal_tokenizer (use mock.patch.object(transformers.GPT2Tokenizer, "from_pretrained",
return_value=minimal_tokenizer)) so no network I/O occurs; alternatively, if you intend
this to be a true integration test, move the whole test function into integration_tests/
and update imports accordingly.
| self._generate_batch, | ||
| chat_conversations, | ||
| temperature=temperature, | ||
| max_new_tokens=max_new_tokens, | ||
| ) |
There was a problem hiding this comment.
_generate_batch is only called with chat_conversations, temperature, and max_new_tokens, so any other LLMRequest.to_chat_completion_kwargs() fields (e.g. top_p, stop/stop_sequences, tools, etc.) never reach _model.generate; those request options now have no effect when using this client
Finding type: Logical Bugs
Prompt for AI Agents:
In src/ares/contrib/transformers_client.py around lines 279 to 283, the call to
self._generate_batch only forwards (chat_conversations, temperature, max_new_tokens) so
other per-request options returned by to_chat_completion_kwargs (e.g., top_p,
stop/stop_sequences, any decoding options or tool hints) are lost. Refactor so that when
grouping requests you capture and propagate the full generation kwargs for the group
(use the first request's kwargs as the group's canonical kwargs, and only group requests
whose kwargs are identical or compatible), then change _process_batch/_generate_batch
signatures to accept a dict of generation kwargs and pass those through into the
tokenizer/model calls (model.generate and any tokenizer settings) instead of only
passing temperature and max_new_tokens. Ensure to update the grouping logic and tests
accordingly so stop sequences, top_p, etc. are applied per-request group.
There was a problem hiding this comment.
Are we ok with this?
| outputs = self._model.generate( | ||
| **inputs, # type: ignore[arg-type] | ||
| max_new_tokens=max_new_tokens, | ||
| temperature=temperature, | ||
| do_sample=temperature > 0, | ||
| pad_token_id=self._tokenizer.pad_token_id, | ||
| ) | ||
|
|
There was a problem hiding this comment.
self._model.generate can return a ModelOutput when config.return_dict_in_generate=True (e.g. some HF configs), but we immediately do generated_ids = outputs[:, input_lengths:] as if a tensor, so the call will fail (no __getitem__) for those models; can we either pass return_dict_in_generate=False to generate or pull outputs.sequences before slicing?
| outputs = self._model.generate( | |
| **inputs, # type: ignore[arg-type] | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| do_sample=temperature > 0, | |
| pad_token_id=self._tokenizer.pad_token_id, | |
| ) | |
| outputs = self._model.generate( | |
| **inputs, # type: ignore[arg-type] | |
| max_new_tokens=max_new_tokens, | |
| temperature=temperature, | |
| do_sample=temperature > 0, | |
| pad_token_id=self._tokenizer.pad_token_id, | |
| return_dict_in_generate=False, | |
| ) |
Finding type: Logical Bugs
rsmith49
left a comment
There was a problem hiding this comment.
This is cool @joshgreaves!
A couple comments, only one I feel pretty strongly about is truncation since that seems to be really important for local models, given how quickly context length explodes in coding problems.
| await self._request_queue.put(ValueAndFuture(value=req, future=future)) | ||
| return await future | ||
|
|
||
| async def _inference_loop(self, weak_self: weakref.ReferenceType) -> None: |
There was a problem hiding this comment.
IMO the performance gains by doing this batch aggregation loop (for people running LLMs locally) are not going to outweight the debugging cost for inference issues compared to just running all single inference. If we want to leave it as a cool feature that's fine. But it is going to be a nightmare to debug inference errors.
That being said, it is a cool implementation and fun to see
There was a problem hiding this comment.
This is a good point, it might make sense to have a separate client for this.
It makes a pretty massive difference speed-wise doing this, so at least for this implementation I'd like to keep it in.
There was a problem hiding this comment.
Commit a71bcfc addressed this comment by improving failure behavior in the batching inference loop: it tracks collected batch items and, upon any exception, logs the error and propagates the exception to all pending requests’ futures so callers don’t hang and inference errors are easier to debug.
| pass | ||
|
|
||
|
|
||
| def _detect_device() -> str: |
There was a problem hiding this comment.
The torch logic here is nice, noting here so I remember to move it to a shared contrib/utils or something that can be used by transformer-lens as well (and any other local inference things)
| input_texts, | ||
| return_tensors="pt", | ||
| padding=True, | ||
| truncation=True, |
There was a problem hiding this comment.
I've found truncation is actually the most complicated part of local inference - for instance, truncation=true will fully take out the <|im_end|> or </assistant> or whatever tags exist that indicate the LLM should respond to the user, instead of completing the user turn.
I think we should make truncation_strategy a first class implementation (probably in a contrib/utils module so other local inference methods can utilize it too). But at the minimum, we should include some kind of instance method or init param that specifies how truncation is done. For now something like the below would work
@dataclasses.dataclass(frozen=True, kw_only=True)
class TransformersLLMClient(llm_clients.LLMClient):
...
truncation_strategy: str | Callable[[str], str] = "auto"
def _generate_batch(
...
if callable(self.truncation_strategy):
input_texts = [self.truncation_strategy(text) for text in input_texts]
hf_truncation = False
elif self.truncation_strategy == "auto":
hf_truncation = True
inputs: transformers.BatchEncoding = self._tokenizer(
input_texts,
return_tensors="pt",
padding=True,
truncation=True,
).to(self._device)
mwilliammyers
left a comment
There was a problem hiding this comment.
LGTM! Thanks Josh!
| self._generate_batch, | ||
| chat_conversations, | ||
| temperature=temperature, | ||
| max_new_tokens=max_new_tokens, | ||
| ) |
There was a problem hiding this comment.
Are we ok with this?
Generated description
Below is a concise technical summary of the changes proposed in this PR:
graph LR TransformersLLMClient_call_("TransformersLLMClient.__call__"):::added TransformersLLMClient_inference_loop_("TransformersLLMClient._inference_loop"):::added TransformersLLMClient_process_batch_("TransformersLLMClient._process_batch"):::added TransformersLLMClient_generate_batch_("TransformersLLMClient._generate_batch"):::added TransformersLLMClient_tokenizer_("TransformersLLMClient._tokenizer"):::added TransformersLLMClient_model_("TransformersLLMClient._model"):::added TRANSFORMERS_("TRANSFORMERS"):::added TORCH_("TORCH"):::added TransformersLLMClient_call_ -- "Queues requests, starts background task, returns future for response" --> TransformersLLMClient_inference_loop_ TransformersLLMClient_inference_loop_ -- "Batches grouped requests by params and invokes processing" --> TransformersLLMClient_process_batch_ TransformersLLMClient_process_batch_ -- "Converts requests to conversations and runs generation in thread" --> TransformersLLMClient_generate_batch_ TransformersLLMClient_generate_batch_ -- "Tokenizes chat messages and decodes generated token IDs" --> TransformersLLMClient_tokenizer_ TransformersLLMClient_generate_batch_ -- "Calls model.generate with batched inputs to produce outputs" --> TransformersLLMClient_model_ TransformersLLMClient_model_ -- "Loads AutoModelForCausalLM from transformers with dtype/device" --> TRANSFORMERS_ TransformersLLMClient_tokenizer_ -- "Loads AutoTokenizer, sets padding and pad/eos tokens" --> TRANSFORMERS_ TransformersLLMClient_generate_batch_ -- "Moves tensors to device, generates under no_grad, computes usage" --> TORCH_ TransformersLLMClient_model_ -- "Resolves dtype, moves model to device, sets eval mode" --> TORCH_ classDef added stroke:#15AA7A classDef removed stroke:#CD5270 classDef modified stroke:#EDAC4C linkStyle default stroke:#CBD5E1,font-size:13pxIntroduces a
TransformersLLMClientto enable local model inference with automatic request batching and device-specific optimizations. Updates dependency groups and CI workflows to support the new client and ensure consistent environment synchronization.llamacppoptional dependency tollamacpp_clientfor naming consistency across contrib modules.Modified files (4)
Latest Contributors(2)
TransformersLLMClientwhich provides an asynchronous interface for HuggingFace models, featuring automatic batching of concurrent requests and intelligent device/dtype selection.Modified files (3)
Latest Contributors(2)