Add a transformers client in contrib. #89

Merged
mwilliammyers merged 15 commits into main from transformers-client
Feb 24, 2026
Conversation

joshgreaves (Contributor) commented Feb 11, 2026

Generated description

Below is a concise technical summary of the changes proposed in this PR:

graph LR
TransformersLLMClient_call_("TransformersLLMClient.__call__"):::added
TransformersLLMClient_inference_loop_("TransformersLLMClient._inference_loop"):::added
TransformersLLMClient_process_batch_("TransformersLLMClient._process_batch"):::added
TransformersLLMClient_generate_batch_("TransformersLLMClient._generate_batch"):::added
TransformersLLMClient_tokenizer_("TransformersLLMClient._tokenizer"):::added
TransformersLLMClient_model_("TransformersLLMClient._model"):::added
TRANSFORMERS_("TRANSFORMERS"):::added
TORCH_("TORCH"):::added
TransformersLLMClient_call_ -- "Queues requests, starts background task, returns future for response" --> TransformersLLMClient_inference_loop_
TransformersLLMClient_inference_loop_ -- "Batches grouped requests by params and invokes processing" --> TransformersLLMClient_process_batch_
TransformersLLMClient_process_batch_ -- "Converts requests to conversations and runs generation in thread" --> TransformersLLMClient_generate_batch_
TransformersLLMClient_generate_batch_ -- "Tokenizes chat messages and decodes generated token IDs" --> TransformersLLMClient_tokenizer_
TransformersLLMClient_generate_batch_ -- "Calls model.generate with batched inputs to produce outputs" --> TransformersLLMClient_model_
TransformersLLMClient_model_ -- "Loads AutoModelForCausalLM from transformers with dtype/device" --> TRANSFORMERS_
TransformersLLMClient_tokenizer_ -- "Loads AutoTokenizer, sets padding and pad/eos tokens" --> TRANSFORMERS_
TransformersLLMClient_generate_batch_ -- "Moves tensors to device, generates under no_grad, computes usage" --> TORCH_
TransformersLLMClient_model_ -- "Resolves dtype, moves model to device, sets eval mode" --> TORCH_
classDef added stroke:#15AA7A
classDef removed stroke:#CD5270
classDef modified stroke:#EDAC4C
linkStyle default stroke:#CBD5E1,font-size:13px

Introduces a TransformersLLMClient to enable local model inference with automatic request batching and device-specific optimizations. Updates dependency groups and CI workflows to support the new client and ensure consistent environment synchronization.
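A usage sketch (hedged: the constructor field shown and the request list are assumptions; the PR only shows that __call__ queues a request and awaits the batched response):

import asyncio

from ares.contrib.transformers_client import TransformersLLMClient

async def run_requests(requests: list) -> list:
    client = TransformersLLMClient(model_name="gpt2")  # field name assumed
    # Concurrent calls share one request queue; the background inference
    # loop groups compatible requests and generates them as a single batch.
    return await asyncio.gather(*(client(req) for req in requests))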

Topic | Details

Infrastructure | Updates GitHub Actions workflows to sync all dependency groups and renames the llamacpp optional dependency to llamacpp_client for naming consistency across contrib modules.
Modified files (4):
  • .github/workflows/ruff.yml
  • .github/workflows/unit-tests.yml
  • pyproject.toml
  • src/ares/contrib/llama_cpp.py
Latest Contributors (2):
User | Commit | Date
Narmeen07 | Add-mechanistic-interp... | February 19, 2026
joshua.greaves@gmail.com | Bump-ARES-to-0.0.2-72 | January 29, 2026

Transformers Client | Implements TransformersLLMClient, which provides an asynchronous interface for HuggingFace models, featuring automatic batching of concurrent requests and intelligent device/dtype selection.
Modified files (3):
  • pyproject.toml
  • src/ares/contrib/transformers_client.py
  • src/ares/contrib/transformers_client_test.py
Latest Contributors (2):
User | Commit | Date
Narmeen07 | Add-mechanistic-interp... | February 19, 2026
joshua.greaves@gmail.com | Bump-ARES-to-0.0.2-72 | January 29, 2026
This pull request was reviewed by Baz.

@joshgreaves joshgreaves requested a review from rsmith49 February 11, 2026 23:29
Comment on lines +201 to +207
    def _ensure_inference_task_started(self) -> None:
        """Lazy-start the background inference task if not already running."""
        if self._inference_task is None or self._inference_task.done():
            # Create weakref to self so task can detect when client is GC'd
            weak_self = weakref.ref(self)
            self._inference_task = asyncio.create_task(self._inference_loop(weak_self))
            _LOGGER.info("TransformersLLMClient started background inference task")
Contributor

Per CLAUDE.md async logging guidance, TransformersLLMClient started background inference task never logs id(self) or another client identifier, so if multiple clients run concurrently you cannot trace which instance started its background task; can we include the client id (e.g., id(self)) in this log?

Suggested change
    def _ensure_inference_task_started(self) -> None:
        """Lazy-start the background inference task if not already running."""
        if self._inference_task is None or self._inference_task.done():
            # Create weakref to self so task can detect when client is GC'd
            weak_self = weakref.ref(self)
            self._inference_task = asyncio.create_task(self._inference_loop(weak_self))
            _LOGGER.info("TransformersLLMClient started background inference task")
    def _ensure_inference_task_started(self) -> None:
        """Lazy-start the background inference task if not already running."""
        if self._inference_task is None or self._inference_task.done():
            # Create weakref to self so task can detect when client is GC'd
            weak_self = weakref.ref(self)
            self._inference_task = asyncio.create_task(self._inference_loop(weak_self))
            _LOGGER.info("TransformersLLMClient(id=%s) started background inference task", id(self))

Finding type: AI Coding Guidelines

Contributor Author

Not necessary, IMO

Contributor

The log currently just says TransformersLLMClient started background inference task, so when multiple clients are running there’s no way to tell which instance triggered it; including id(self) (or similar) gives us a lightweight identifier for debugging without changing functionality.

Comment on lines +266 to +271
                # Process batch
                _LOGGER.debug("Processing batch of %d request(s)", len(batch))
                await self._process_batch(batch)

            except Exception as e:
                _LOGGER.exception("Error in inference loop: %s", e)
Contributor

Per CLAUDE.md async logging guidance, the Error in inference loop exception log never includes id(self) (or similar client identifier), so failed batches cannot be traced back to a particular client instance; can we include the client id in this log?

Suggested change
                # Process batch
                _LOGGER.debug("Processing batch of %d request(s)", len(batch))
                await self._process_batch(batch)
            except Exception as e:
                _LOGGER.exception("Error in inference loop: %s", e)
                # Process batch
                _LOGGER.debug("Processing batch of %d request(s)", len(batch))
                await self._process_batch(batch)
            except Exception as e:
                _LOGGER.exception("Error in inference loop for client %s: %s", hex(id(self)), e)

Finding type: AI Coding Guidelines

Contributor Author

Not necessary IMO

Contributor

I hear you, but tracing the background task failures to a specific client instance is hard without an identifier—the exception already includes the client’s ID now, so we can correlate Error in inference loop for client … messages with the right request queue.

Comment on lines +314 to +316

        except Exception as e:
            _LOGGER.exception("Error processing batch: %s", e)
Contributor

Per CLAUDE.md async logging guidance, the Error processing batch exception log never includes id(self) (or similar client identifier), so downstream errors cannot be tied back to a particular client instance when multiple are running; can we include the client id here as well?

Suggested change
        except Exception as e:
            _LOGGER.exception("Error processing batch: %s", e)
        except Exception as e:
            _LOGGER.exception("Error processing batch in client id=%s: %s", id(self), e)

Finding type: AI Coding Guidelines

Comment on lines +422 to +430
    # Create minimal tokenizer
    minimal_tokenizer = transformers.GPT2Tokenizer.from_pretrained(
        "gpt2",  # Uses cached tokenizer vocab
        model_max_length=32,
    )
    minimal_tokenizer.pad_token = minimal_tokenizer.eos_token
    minimal_tokenizer.padding_side = "left"
    # Add minimal chat template
    minimal_tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{% endfor %}"

Important

[Testing] The “no-download” integration test still calls transformers.GPT2Tokenizer.from_pretrained("gpt2") (lines 422‑430), which hits the HuggingFace hub whenever the vocab isn’t already cached on the test machine. That reintroduces a network dependency into the unit test suite and will make CI fail in air‑gapped environments. Please construct the tokenizer fully in-code (e.g., build a tiny tokenizers.Tokenizer/PreTrainedTokenizerFast with a hard-coded vocab, or write minimal vocab/merges files to a temporary directory and load from that path) so the test is deterministic and offline.


File: src/ares/contrib/transformers_client_test.py
Line: 430
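One offline construction along the lines suggested above, a minimal sketch assuming only the tokenizers package (bundled with transformers); the tiny vocab here is hypothetical:

import tokenizers
import transformers

# Build a tiny word-level tokenizer from a hard-coded vocab: no network I/O.
vocab = {"<|endoftext|>": 0, "hello": 1, "world": 2}
tok = tokenizers.Tokenizer(tokenizers.models.WordLevel(vocab, unk_token="<|endoftext|>"))
tok.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()

minimal_tokenizer = transformers.PreTrainedTokenizerFast(
    tokenizer_object=tok,
    model_max_length=32,
    eos_token="<|endoftext|>",
    pad_token="<|endoftext|>",
)
minimal_tokenizer.padding_side = "left"
minimal_tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{% endfor %}"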

Comment on lines +154 to +167
    @functools.cached_property
    def _inference_task(self) -> asyncio.Task[None]:
        """Lazy-initialized background inference task.

        The task automatically exits when the client is garbage collected via weakref.
        """
        weak_self = weakref.ref(self)
        task = asyncio.create_task(self._inference_loop(weak_self))
        _LOGGER.info("TransformersLLMClient started background inference task")
        return task

    @functools.cached_property
    def _device(self) -> str:
        """Resolved device."""
Contributor

TransformersLLMClient logs TransformersLLMClient started background inference task without id(self) even though this async background task can run concurrently for multiple clients, so tracing is ambiguous; CLAUDE.md requires async logging to include object identifiers—can we log the client id in these statements?

Finding type: AI Coding Guidelines

Prompt for AI Agents:

In src/ares/contrib/transformers_client.py around lines 154 to 167, the _inference_task
cached_property logs "TransformersLLMClient started background inference task" without
identifying which client started the task. Update the logging call in that method to
include a unique identifier for the client (for example id(self)) and optionally the
model_name so concurrent clients are distinguishable. Make the log message something
like: "TransformersLLMClient started background inference task (id=%s, model=%s)" and
pass id(self) and self.model_name as parameters to the logger.
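Concretely, the suggested call would look something like this (a sketch; model_name is assumed to be a field on the client):

_LOGGER.info(
    "TransformersLLMClient started background inference task (id=%s, model=%s)",
    id(self),
    self.model_name,
)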


Comment on lines +286 to +310
@pytest.mark.asyncio
async def test_integration_with_minimal_model():
    """Integration test with a minimal GPT2 model.

    Creates a tiny GPT2 model from scratch for testing the full pipeline.
    Note: Downloads GPT2 tokenizer vocab on first run (~500KB, cached after).
    """
    # Create minimal GPT2 config - vocab_size must match GPT2Tokenizer (50257)
    config = transformers.GPT2Config(
        vocab_size=50257,  # Must match GPT2Tokenizer vocab
        n_positions=32,
        n_ctx=32,
        n_embd=32,
        n_layer=2,
        n_head=4,
    )

    minimal_model = transformers.GPT2LMHeadModel(config)
    minimal_model.eval()

    # GPT2 tokenizer is lightweight and cached after first download
    minimal_tokenizer = transformers.GPT2Tokenizer.from_pretrained(
        "gpt2",
        model_max_length=32,
    )
Contributor

test_integration_with_minimal_model calls transformers.GPT2Tokenizer.from_pretrained("gpt2"), which performs a real network download inside src/ares/contrib/transformers_client_test.py; CLAUDE.md mandates that unit tests under src/ must mock external services and real integration tests belong under integration_tests/, so this test should either mock the tokenizer download or move to the integration test suite (CLAUDE.md).

Finding type: AI Coding Guidelines

Prompt for AI Agents:

In src/ares/contrib/transformers_client_test.py around lines 286 to 310, the
test_integration_with_minimal_model function calls
transformers.GPT2Tokenizer.from_pretrained("gpt2") which performs a real network
download (not allowed for unit tests under src/). Refactor by mocking that call: replace
the direct from_pretrained invocation with a context manager that patches
transformers.GPT2Tokenizer.from_pretrained to return the locally constructed
minimal_tokenizer (use mock.patch.object(transformers.GPT2Tokenizer, "from_pretrained",
return_value=minimal_tokenizer)) so no network I/O occurs; alternatively, if you intend
this to be a true integration test, move the whole test function into integration_tests/
and update imports accordingly.
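A minimal sketch of that patch, assuming minimal_tokenizer is the tokenizer built in-code earlier in the test; the client construction at the end is purely illustrative:

from unittest import mock

import transformers

with mock.patch.object(
    transformers.GPT2Tokenizer, "from_pretrained", return_value=minimal_tokenizer
):
    # Anything under test that calls GPT2Tokenizer.from_pretrained("gpt2")
    # now receives the local tokenizer instead of hitting the network.
    client = TransformersLLMClient(model_name="gpt2")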


Comment on lines +279 to +283
            self._generate_batch,
            chat_conversations,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
Contributor

_generate_batch is only called with chat_conversations, temperature, and max_new_tokens, so any other LLMRequest.to_chat_completion_kwargs() fields (e.g. top_p, stop/stop_sequences, tools, etc.) never reach _model.generate, and those request options have no effect when using this client.

Finding type: Logical Bugs

Prompt for AI Agents:

In src/ares/contrib/transformers_client.py around lines 279 to 283, the call to
self._generate_batch only forwards (chat_conversations, temperature, max_new_tokens) so
other per-request options returned by to_chat_completion_kwargs (e.g., top_p,
stop/stop_sequences, any decoding options or tool hints) are lost. Refactor so that when
grouping requests you capture and propagate the full generation kwargs for the group
(use the first request's kwargs as the group's canonical kwargs, and only group requests
whose kwargs are identical or compatible), then change _process_batch/_generate_batch
signatures to accept a dict of generation kwargs and pass those through into the
tokenizer/model calls (model.generate and any tokenizer settings) instead of only
passing temperature and max_new_tokens. Ensure to update the grouping logic and tests
accordingly so stop sequences, top_p, etc. are applied per-request group.
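A hedged sketch of that grouping, assuming each queued item wraps the request as item.value with a to_chat_completion_kwargs() method (per the PR); the helper itself is hypothetical:

def group_by_generation_kwargs(batch: list) -> list[tuple[dict, list]]:
    """Group queued requests so each group shares identical generation kwargs."""
    grouped: dict[str, tuple[dict, list]] = {}
    for item in batch:
        kwargs = item.value.to_chat_completion_kwargs()
        # repr-based key: stable even for unhashable values like stop-sequence lists.
        key = repr(sorted(kwargs.items()))
        grouped.setdefault(key, (kwargs, []))[1].append(item)
    return list(grouped.values())

Each group's canonical kwargs can then flow through _process_batch/_generate_batch into model.generate, so top_p, stop sequences, etc. are no longer dropped.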


Collaborator

Are we ok with this?

Comment on lines +330 to +337
            outputs = self._model.generate(
                **inputs,  # type: ignore[arg-type]
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=self._tokenizer.pad_token_id,
            )

Contributor

self._model.generate can return a ModelOutput when config.return_dict_in_generate=True (e.g. some HF configs), but we immediately do generated_ids = outputs[:, input_lengths:] as if it were a tensor, so the call will fail (no tensor-style __getitem__) for those models; can we either pass return_dict_in_generate=False to generate or pull outputs.sequences before slicing?

Suggested change
            outputs = self._model.generate(
                **inputs,  # type: ignore[arg-type]
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=self._tokenizer.pad_token_id,
            )
            outputs = self._model.generate(
                **inputs,  # type: ignore[arg-type]
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                pad_token_id=self._tokenizer.pad_token_id,
                return_dict_in_generate=False,
            )

Finding type: Logical Bugs
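Alternatively, a small helper could unwrap the dict-style output before slicing; a sketch, not the PR's code:

import torch

def _as_sequences(outputs) -> torch.Tensor:
    """Unwrap generate() output, which is a ModelOutput when return_dict_in_generate=True."""
    return outputs if isinstance(outputs, torch.Tensor) else outputs.sequences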

rsmith49 previously approved these changes Feb 13, 2026
Contributor

@rsmith49 rsmith49 left a comment

This is cool @joshgreaves!

A couple of comments; the only one I feel pretty strongly about is truncation, since that seems to be really important for local models, given how quickly context length explodes in coding problems.

        await self._request_queue.put(ValueAndFuture(value=req, future=future))
        return await future

    async def _inference_loop(self, weak_self: weakref.ReferenceType) -> None:
Contributor

IMO the performance gains from this batch aggregation loop (for people running LLMs locally) are not going to outweigh the debugging cost for inference issues compared to just running everything as single inference. If we want to leave it as a cool feature that's fine. But it is going to be a nightmare to debug inference errors.

That being said, it is a cool implementation and fun to see

Contributor Author

This is a good point; it might make sense to have a separate client for this.
It makes a pretty massive difference speed-wise doing this, so at least for this implementation I'd like to keep it in.

Contributor

Commit a71bcfc addressed this comment by improving failure behavior in the batching inference loop: it tracks collected batch items and, upon any exception, logs the error and propagates the exception to all pending requests’ futures so callers don’t hang and inference errors are easier to debug.
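A runnable sketch of that behavior; the names and batching details here are illustrative, not the commit's exact code:

import asyncio

async def process_batch(requests: list) -> list:
    raise RuntimeError("inference failed")  # stand-in for a model error

async def inference_loop(queue: asyncio.Queue) -> None:
    while True:
        items = [await queue.get()]  # items are (request, future) pairs
        while not queue.empty():
            items.append(queue.get_nowait())  # aggregate everything queued
        try:
            results = await process_batch([req for req, _ in items])
            for (_, fut), res in zip(items, results):
                fut.set_result(res)
        except Exception as exc:
            # The fix from a71bcfc: fail every pending future so callers
            # see the error immediately instead of hanging on their awaits.
            for _, fut in items:
                if not fut.done():
                    fut.set_exception(exc)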

    pass


def _detect_device() -> str:
Contributor

The torch logic here is nice, noting here so I remember to move it to a shared contrib/utils or something that can be used by transformer-lens as well (and any other local inference things)
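For reference, the detection presumably looks something like this sketch (the actual logic lives in src/ares/contrib/transformers_client.py):

import torch

def _detect_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():  # Apple Silicon
        return "mps"
    return "cpu"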

            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
Contributor

I've found truncation is actually the most complicated part of local inference; for instance, truncation=True will fully take out the <|im_end|> or </assistant> or whatever tags exist that indicate the LLM should respond to the user, instead of completing the user's turn.

I think we should make truncation_strategy a first-class implementation (probably in a contrib/utils module so other local inference methods can utilize it too). But at the minimum, we should include some kind of instance method or init param that specifies how truncation is done. For now, something like the below would work:

@dataclasses.dataclass(frozen=True, kw_only=True)
class TransformersLLMClient(llm_clients.LLMClient):
    ...
    truncation_strategy: str | Callable[[str], str] = "auto"

    def _generate_batch(
        ...
        if callable(self.truncation_strategy):
            input_texts = [self.truncation_strategy(text) for text in input_texts]
            hf_truncation = False
        elif self.truncation_strategy == "auto":
            hf_truncation = True

        inputs: transformers.BatchEncoding = self._tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=hf_truncation,  # forward the resolved strategy instead of hard-coding True
        ).to(self._device)

Contributor Author

I added a TODO

Collaborator

@mwilliammyers mwilliammyers left a comment

LGTM! Thanks Josh!


@mwilliammyers mwilliammyers merged commit c804aa2 into main Feb 24, 2026
4 checks passed
@mwilliammyers mwilliammyers deleted the transformers-client branch February 24, 2026 23:17