
Add mechanistic interpretability toolkit and 20Q case study #91

Merged
rsmith49 merged 29 commits into main from narmeen/mi-q20 on Feb 19, 2026
Conversation

@Narmeen07
Contributor

Summary

  • Mechanistic interpretability contrib module (src/ares/contrib/mech_interp/): TransformerLens integration for ARES environments, including HookedTransformerLLMClient, ActivationCapture for trajectory-level caching, and hook utilities (zero ablation, mean ablation, path patching, InterventionManager)
  • Twenty Questions environment (src/ares/environments/twenty_questions.py): A lightweight env where an LLM plays 20 Questions against an oracle that flags invalid (non-yes/no) questions — designed as a testbed for studying rule-following behaviour
  • End-to-end 20Q case study (examples/20q_case_study/): Data collection, linear probing (87% accuracy predicting invalid questions from residual stream), per-step Contrastive Activation Addition (CAA) steering (69% → 58% invalid rate at α=4.0), steering vector evolution visualisation, and a tutorial notebook summarising all experiments
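The per-step CAA steering described above reduces to a mean-difference vector added back into the residual stream. A minimal numpy sketch with synthetic stand-in data (the function names here are illustrative, not the case study's actual API):

```python
import numpy as np

def contrastive_steering_vector(valid: np.ndarray, invalid: np.ndarray) -> np.ndarray:
    """CAA vector: mean(valid) - mean(invalid) over pooled activations.

    valid/invalid: [n_samples, d_model] pooled residual-stream features.
    """
    return valid.mean(axis=0) - invalid.mean(axis=0)

def apply_steering(resid: np.ndarray, v: np.ndarray, alpha: float) -> np.ndarray:
    """Add the scaled steering vector to every residual-stream position."""
    return resid + alpha * v

# Synthetic stand-in data for valid vs. invalid-question activations
rng = np.random.default_rng(0)
d_model = 8
valid = rng.normal(loc=1.0, size=(16, d_model))
invalid = rng.normal(loc=-1.0, size=(16, d_model))

v = contrastive_steering_vector(valid, invalid)
steered = apply_steering(rng.normal(size=(4, d_model)), v, alpha=4.0)
```

The `alpha=4.0` mirrors the scale reported in the summary; in the actual case study the vector is computed per step from cached Phase 1 activations.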

Test plan

  • Verify `uv run pytest` passes (no new test files — case study is experimental scripts)
  • Review the `mech_interp` module API (`ActivationCapture`, `HookedTransformerLLMClient`, hook utilities)
  • Run the tutorial notebook on a GPU machine to verify the end-to-end flow
  • Check that the Twenty Questions environment works via `ares.make("20q")`

rsmith49 and others added 14 commits January 30, 2026 11:30
- Add TwentyQuestionsEnvironment: a lightweight, container-free environment
  for the 20 Questions game with an LLM oracle
- Register "20q" preset in the registry
- Add mech_interp_20q_multimodel_probing example for training linear probes
  on mid-layer residual streams
- Add scikit-learn and matplotlib to transformer-lens dependency group
- Update CLAUDE.md with public API docs, proxy section, and CI/CD details

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add probe_invalid_question.py: trains linear probes on mid-layer
  residual stream activations to predict oracle "Invalid Question"
  responses per step
- Add inspect_20q_generations.py: utility to inspect 20Q episode
  generations
- Update oracle prompt in twenty_questions.py to be more lenient
  with question extraction from player messages containing reasoning
- Move collect_20q_data.py and probe_invalid_question.py into
  examples/20q_case_study/, renaming probe to phase1_probe.py
- Add phase2_steer.py: computes contrastive activation steering
  vectors from Phase 1 cached activations, then runs new episodes
  with steering at the target step across multiple alpha values
- Multi-GPU support via ThreadPoolExecutor work queue pattern
- Per-step logging and grouped bar chart visualization
…ture

- Capture all forward-pass activations during generate() and concatenate
  along seq dim to support prompt-aware pooling strategies
- Add per-step steering vectors instead of single target step
- Add system_prompt support to TwentyQuestionsEnvironment and presets
- Prepend system_prompt in HookedTransformerLLMClient chat formatting
- Improve target step selection using probe accuracy * invalid rate
- Add inspect mode for sequential single-GPU readable output
- Add analyse_steering.py for post-hoc analysis of steering results
- Pre-compute deterministic secret words for cross-condition reproducibility
- Add end-to-end Jupyter tutorial (ares_mi_20q_tutorial.ipynb) covering
  data collection, probing, steering vector analysis, and intervention
- Add steering vector evolution visualisation script (cosine similarity
  heatmap, norm per step, PCA trajectory)
- Update analyse_steering.py with paper-ready tables and qualitative examples
- Fix system prompt handling in phase2_steer.py
- Remove deprecated standalone probing/generation scripts (consolidated
  into 20q_case_study/)
- Add matplotlib dependency to pyproject.toml
Comment on lines +236 to +267
```python
    target_step: int,
    pooling: str,
    min_class_samples: int,
) -> np.ndarray:
    """Compute steering vector from training activations at *target_step*.

    Returns v_valid - v_invalid (shape [d_model]).
    """
    valid_features: list[np.ndarray] = []
    invalid_features: list[np.ndarray] = []

    for ep in episodes:
        if ep["episode_idx"] not in train_eps:
            continue
        for step in ep["steps"]:
            if step["step_idx"] != target_step:
                continue
            activation = step.get("activation")
            if activation is None:
                continue
            prompt_len = step.get("prompt_len", 0)
            feature = _pool_activation(activation, pooling, prompt_len=prompt_len)
            if step["is_invalid"]:
                invalid_features.append(feature)
            else:
                valid_features.append(feature)

    n_valid = len(valid_features)
    n_invalid = len(invalid_features)
    print(f"  [{pooling:6s}] step {target_step}: {n_valid} valid, {n_invalid} invalid (train)")

    if n_valid < min_class_samples or n_invalid < min_class_samples:
```
Important

[Reliability] 3. Resource Leak: Model Not Cleaned Up on Exception
File: `examples/20q_case_study/phase2_steer.py`
Lines: 236-267

```python
async def _run_steered_episodes_for_device(...):
    with _MODEL_LOAD_LOCK:
        model = HookedTransformer.from_pretrained(MODEL_NAME, device="cpu", dtype=torch.bfloat16)
    model = model.to(device)
    # ... lots of work ...
    finally:
        hook_point.remove_hooks("fwd")

    del model, tokenizer, client
    _free_gpu_memory()
```

Issue: The `finally` block only removes hooks. If an exception occurs after model loading but before reaching the `del` statements at the end, the model remains in memory. The `del` statements should be inside the `finally` block to guarantee cleanup.

Fix:

```python
async def _run_steered_episodes_for_device(...):
    model = None
    tokenizer = None
    client = None
    try:
        with _MODEL_LOAD_LOCK:
            model = HookedTransformer.from_pretrained(MODEL_NAME, device="cpu", dtype=torch.bfloat16)
        model = model.to(device)
        # ... work ...
    finally:
        if model is not None:
            hook_point.remove_hooks("fwd")
            del model
        if tokenizer is not None:
            del tokenizer
        if client is not None:
            del client
        _free_gpu_memory()
```
Comment on lines +92 to +111
```python
) -> torch.Tensor:
    # self.model.reset_hooks(direction="fwd", including_permanent=False)

    if self.fwd_hooks is not None:
        for name_or_id_fn, hook_fn in self.fwd_hooks:
            # Check if this is a StateIdFn, or if it is the standard format to run in all states
            if inspect.iscoroutinefunction(name_or_id_fn):
                name = await name_or_id_fn(state)
            else:
                name = name_or_id_fn

            self.model.add_hook(name, hook_fn, is_permanent=False)  # type: ignore

    with torch.no_grad():
        # TODO: Should we make use of the `__call__(..., return_type="logits | loss")` here instead?
        # Generate completion
        # Note: HookedTransformer.generate returns full sequence including input
        outputs = self.model.generate(
            input_ids,
            **gen_kwargs,
```
Important

[Logic] 7. Inconsistent Hook Management Pattern
File: `src/ares/contrib/mech_interp/hooked_transformer_client.py`
Lines: 92-111

```python
async def _call_with_hooks(self, input_ids, state, **gen_kwargs):
    if self.fwd_hooks is not None:
        for name_or_id_fn, hook_fn in self.fwd_hooks:
            if inspect.iscoroutinefunction(name_or_id_fn):
                name = await name_or_id_fn(state)
            else:
                name = name_or_id_fn
            self.model.add_hook(name, hook_fn, is_permanent=False)

    with torch.no_grad():
        outputs = self.model.generate(input_ids, **gen_kwargs)

    # Hooks are NOT removed here
    return outputs
```

Issue: Hooks are added with `is_permanent=False` but never explicitly removed in `_call_with_hooks()`. The commented-out `self.model.reset_hooks()` calls suggest this was intentional, but it's unclear whether TransformerLens automatically cleans up non-permanent hooks or whether they accumulate across calls. This could cause hooks to fire on subsequent unrelated forward passes.

Recommendation: Add a `finally` block to ensure hooks are removed after generation, or document the lifecycle guarantees of `is_permanent=False` hooks prominently.
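One generic way to follow this recommendation is an add/try/finally wrapper so hooks are always detached after generation. A pure-Python sketch with a toy `HookRegistry` standing in for the TransformerLens model (hypothetical names, not the repository's API):

```python
class HookRegistry:
    """Toy stand-in for a model's hook registry (hypothetical, not TransformerLens)."""
    def __init__(self):
        self.hooks = []

    def add_hook(self, name, fn):
        self.hooks.append((name, fn))

    def remove_hooks(self):
        self.hooks.clear()

def generate_with_hooks(model: HookRegistry, hooks, generate_fn):
    """Add hooks, run generation, and always remove them afterwards."""
    for name, fn in hooks:
        model.add_hook(name, fn)
    try:
        return generate_fn()
    finally:
        # Runs even if generate_fn raises, so hooks never
        # leak into later, unrelated forward passes.
        model.remove_hooks()

model = HookRegistry()
out = generate_with_hooks(model, [("blocks.0.hook_resid_post", lambda x: x)], lambda: "ok")
print(out, len(model.hooks))
```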


Comment on lines +237 to +274


````python
def automatic_activation_capture(model: HookedTransformer) -> ActivationCapture:
    """Create an ActivationCapture that automatically records steps during generation.

    This wraps the model's generate method to automatically call start_step() and
    end_step() around each generation, making it seamless to use with ARES environments.

    Args:
        model: HookedTransformer to capture activations from.

    Returns:
        ActivationCapture instance with automatic step tracking.

    Example:
        ```python
        model = HookedTransformer.from_pretrained("gpt2-small")

        with automatic_activation_capture(model) as capture:
            client = HookedTransformerLLMClient(model=model)
            # Now activations are captured automatically during client calls
            async with env:
                ts = await env.reset()
                while not ts.last():
                    response = await client(ts.observation)
                    ts = await env.step(response)

            trajectory = capture.get_trajectory()
        ```
    """
    capture = ActivationCapture(model)

    # Wrap model.generate to auto-capture
    original_generate = model.generate

    def wrapped_generate(*args, **kwargs):
        capture.start_step()
        result = original_generate(*args, **kwargs)
````
Important

[Logic] 11. Incomplete Docstring for `automatic_activation_capture`
File: `src/ares/contrib/mech_interp/activation_capture.py`
Lines: 237-274

```python
def automatic_activation_capture(model: HookedTransformer) -> ActivationCapture:
    """Create an ActivationCapture that automatically records steps during generation.
    ...
    """
    capture = ActivationCapture(model)
    original_generate = model.generate
    def wrapped_generate(*args, **kwargs):
        capture.start_step()
        result = original_generate(*args, **kwargs)
        capture.end_step()
        return result
    model.generate = wrapped_generate
    return capture
```

Issue: The docstring doesn't warn that this mutates the model object globally. If the same model instance is used elsewhere without `automatic_activation_capture`, the wrapper persists, causing unintended side effects. Additionally, there's no mechanism to restore the original `generate` method.

Recommendation: Document this behavior prominently and consider returning a context manager that restores the original method on exit.
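The restore-on-exit idea can be sketched with `contextlib.contextmanager`; `Model` below is a toy stand-in for the real `HookedTransformer`, and `capture_generations` is a hypothetical name:

```python
import contextlib

class Model:
    """Toy stand-in for HookedTransformer (hypothetical)."""
    def generate(self, prompt):
        return prompt.upper()

@contextlib.contextmanager
def capture_generations(model):
    """Wrap model.generate to record outputs, restoring the original on exit."""
    captured = []
    original = model.generate

    def wrapped(*args, **kwargs):
        result = original(*args, **kwargs)
        captured.append(result)
        return result

    model.generate = wrapped
    try:
        yield captured
    finally:
        # Undo the monkey-patch even if the body raised, so the
        # wrapper cannot leak into later uses of the same model.
        model.generate = original

m = Model()
with capture_generations(m) as captured:
    m.generate("is it an animal?")
print(captured)
```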


Comment on lines +150 to +161
```python
# Leave room for generation
max_input_tokens = max_position - self.max_new_tokens
input_ids = input_ids[:, :max_input_tokens]
num_input_tokens = input_ids.shape[-1]

# Prepare generation kwargs
gen_kwargs = {
    "max_new_tokens": self.max_new_tokens,
    **self.generation_kwargs,
}

# TODO: This should be more generic - why temperature specifically?
```
Important

[Reliability] Context overflow silently truncates input without warning: When `num_input_tokens + self.max_new_tokens > max_position`, the code truncates input to `max_position - max_new_tokens`. This can remove critical task context (e.g. the original bug report in a SWE-bench task), causing the agent to fail without user awareness.

Fix: Log a warning when truncation occurs:

```python
if num_input_tokens + self.max_new_tokens > max_position:
    max_input_tokens = max_position - self.max_new_tokens
    _LOGGER.warning(
        "Input truncated from %d to %d tokens (context limit: %d, reserved for generation: %d)",
        num_input_tokens,
        max_input_tokens,
        max_position,
        self.max_new_tokens,
    )
    input_ids = input_ids[:, :max_input_tokens]
    num_input_tokens = input_ids.shape[-1]
```

joshgreaves previously approved these changes Feb 19, 2026
Contributor

@joshgreaves joshgreaves left a comment

Thanks Narmeen! LGTM

This is a really great example, excited for people to see this case study!

I think given the timeline, the best thing we can do is to just make a small bump to consistency in style. The most obvious is imports. It's a small thing, but we want people to see ultra-consistent usage patterns for ARES.

If you and @rsmith49 have time, it would be great to address the comments if possible before merging. If it's too much, please add a TODO and add to the Notion doc.

Contributor

There are a handful of improvements I should make here, but we can come back to them later.

pyproject.toml Outdated
"harbor>=0.1.32",
"httpx>=0.28.1",
"jinja2>=3.1.6",
"matplotlib>=3.10.8",
Contributor
Suggested change
"matplotlib>=3.10.8",

It's not needed by core ARES, and it's already included under examples.

ruff.toml Outdated
```toml
allowed-confusables = ["α"]

# Exclude notebooks
exclude = ["*.ipynb"]
```
Contributor
I think we should still ruff format notebooks 🙃

I recommend:

  • Removing this
  • Running ruff format on your ipynb
  • Fixing any ruff check issues (ok to use # noqa comments where appropriate or disable file-wide for math notation)

ruff.toml Outdated
line-ending = "auto"

# Ignore Notebooks
exclude = ["*.ipynb"]
Contributor

Duplicated, also would like it removed

Comment on lines +33 to +36
```python
import asyncio
from transformer_lens import HookedTransformer
from ares.contrib.mech_interp import HookedTransformerLLMClient, ActivationCapture
from ares.environments import swebench_env
```
Contributor
ARES should use consistent style across all documentation (google-style)

Also applies to others in this README

```python
loop = asyncio.get_running_loop()
with ThreadPoolExecutor(max_workers=n_devices) as executor:
    futures = [
        loop.run_in_executor(executor, _run_device_worker, device, work_queue, MAX_STEPS_PER_EPISODE)
```
Contributor
I don't think `run_in_executor` is necessary; it's a slight overcomplication. I think what we actually want is to make `_run_device_worker` async, and then just use `futures = [asyncio.create_task(...) for .. in ...]`.

We shouldn't really be using any threading code (including `threading.Lock`) in asyncio code unless there is a very specific need being solved. `asyncio.to_thread` doesn't count; we should use that where appropriate.

Let's just add a TODO for now, since this is working.
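For reference, the suggested task-per-device pattern might look roughly like this (toy stand-ins for devices and episode work; not the script's actual code):

```python
import asyncio

async def run_device_worker(device: str, queue: asyncio.Queue, results: list) -> None:
    """Drain the shared work queue, handling one item at a time on this device."""
    while True:
        try:
            episode = queue.get_nowait()
        except asyncio.QueueEmpty:
            return
        await asyncio.sleep(0)  # stand-in for actually running a steered episode
        results.append((device, episode))

async def main() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    for episode in range(6):
        queue.put_nowait(episode)
    results: list = []
    # One asyncio task per device, instead of one thread per device
    tasks = [
        asyncio.create_task(run_device_worker(device, queue, results))
        for device in ("cuda:0", "cuda:1")
    ]
    await asyncio.gather(*tasks)
    return results

results = asyncio.run(main())
print(sorted(episode for _, episode in results))
```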

```python
# ---------------------------------------------------------------------------

SEED = 42
np.random.seed(SEED)
```
Contributor
Don't do this in the global scope. Do it in the `if __name__ == "__main__":` block.

```python
# uv run --no-sync python examples/20q_case_study/phase2_steer.py

import asyncio
from concurrent.futures import ThreadPoolExecutor
```
Contributor
Import style

Contributor
If there's one small change to make for consistency, please update imports to be Google-style, i.e. import modules directly, not functions or classes.

Contributor
Really nice notebook!

…ok polish

- Convert all imports to Google style (import modules, not classes/functions)
- Move matplotlib from core deps to examples group in pyproject.toml
- Remove ruff notebook exclusions and fix resulting lint errors
- Add docstrings, CUDA guards, and TODO for async refactor in case study scripts
- Fix install instructions in example 07 and mech_interp README
- Fix typos in 20Q case study README
- Polish notebook: add open problems section, fix citations, update code examples

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Comment on lines +44 to +53
```python
def zero_ablation_hook(activation: torch.Tensor, hook: hook_points.HookPoint) -> torch.Tensor:  # noqa: ARG001
    ablated = activation.clone()

    if heads is not None:
        # For attention patterns: [batch, head, query_pos, key_pos]
        if len(ablated.shape) == 4:
            ablated[:, heads, :, :] = 0.0
        # For attention outputs: [batch, pos, head_index, d_head]
        elif len(ablated.shape) == 4:
            ablated[:, :, heads, :] = 0.0
```
Important

[Logic] In `create_zero_ablation_hook` the two branches inside the `heads is not None` block both check `len(ablated.shape) == 4`:

```python
if len(ablated.shape) == 4:
    ablated[:, heads, :, :] = 0.0
elif len(ablated.shape) == 4:
    ablated[:, :, heads, :] = 0.0
```

Because the `elif` uses the same condition as the `if`, it is never reached. As a result, when the activation is shaped `[batch, pos, head_index, d_head]` (e.g. attention outputs), no heads are actually zeroed: the hook returns the unmodified tensor. Replace the second condition with a distinct case (or a plain `else`) so both layout variants are handled:

Suggested change:

```python
        if heads is not None:
            # Attention patterns: [batch, head, query_pos, key_pos]
            if len(ablated.shape) == 4 and ablated.shape[1] == ablated.shape[2]:
                ablated[:, heads, :, :] = 0.0
            # Attention outputs: [batch, pos, head_index, d_head]
            else:
                ablated[:, :, heads, :] = 0.0
```

This ensures that the specified heads are zeroed regardless of whether the head dimension is axis 1 or axis 2.


Narmeen07 and others added 2 commits February 19, 2026 13:51
Users can now download pre-computed activations and experiment results
from withmartian/ares-20q-case-study on Hugging Face Hub, enabling the
notebook to run without a GPU.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Comment on lines +47 to +55
```python
import json

metadata = {
    "model_name": self.model_name,
    "num_steps": len(self),
    "step_metadata": self.step_metadata,
}
with open(path / "metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)
```
Important

[Logic] ⚠️ `TrajectoryActivations.save` always JSON-serializes `step_metadata`, but the example above (`capture.record_step_metadata({"action": response})`) records full `LLMResponse` objects that aren't JSON serializable.

As soon as a user records any non-primitive metadata (LLM responses, tensors, dataclasses, etc.), this block raises `TypeError: Object of type LLMResponse is not JSON serializable` and prevents saving the trajectory at all.

To avoid data loss, either (a) coerce arbitrary metadata to something serializable (e.g. `json.dump(..., default=str)` or a custom encoder), or (b) persist metadata via `torch.save`/pickle like the activation tensors. For example:

```python
with open(path / "metadata.json", "w") as f:
    json.dump(metadata, f, indent=2, default=str)
```

Without this change the advertised metadata feature can't be used safely.
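The failure mode and the `default=str` workaround can be demonstrated in isolation (the `LLMResponse` dataclass below is a stand-in for the real response type):

```python
import io
import json
from dataclasses import dataclass

@dataclass
class LLMResponse:
    """Stand-in for a non-JSON-serializable response object (hypothetical)."""
    text: str

metadata = {
    "model_name": "gpt2-small",
    "step_metadata": [{"action": LLMResponse("Is it alive?")}],
}

# Plain json serialization raises TypeError on the LLMResponse object
try:
    json.dumps(metadata)
    raised = False
except TypeError:
    raised = True

# default=str coerces anything the encoder doesn't understand via str()
buf = io.StringIO()
json.dump(metadata, buf, indent=2, default=str)
print(raised, "Is it alive?" in buf.getvalue())
```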


Comment on lines +112 to +113

```python
        return twenty_questions.TwentyQuestionsEnvironment(
```
Important

[Reliability] Unlike `HarborSpec.get_env`, the new `TwentyQuestionsSpec.get_env` never validates the selector output. If a caller requests an empty slice/shard (e.g. `ares.make("20q:0:0")`), `selected_objects` ends up empty and `random.choice(self._objects)` in `TwentyQuestionsEnvironment.reset()` raises `IndexError` deep inside the episode. This makes presets behave unpredictably and is hard to diagnose. Please guard against empty selections up front and raise a clear `ValueError`, mirroring the check you already have in `HarborSpec`.

Suggested change:

```python
        selected_objects = selector(list(self.objects))
        if not selected_objects:
            raise ValueError("Task selector produced no objects for Twenty Questions preset")

        return twenty_questions.TwentyQuestionsEnvironment(
            objects=tuple(selected_objects),
            oracle_model=self.oracle_model,
            step_limit=self.step_limit,
            system_prompt=self.system_prompt,
            tracker=tracker,
        )
```

Contributor

@joshgreaves joshgreaves left a comment

LGTM!

@rsmith49 rsmith49 merged commit d279d3c into main Feb 19, 2026
2 of 3 checks passed
@rsmith49 rsmith49 deleted the narmeen/mi-q20 branch February 19, 2026 16:38