Skip to content
Merged
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ repos:
- id: isort
name: isort (python)
args: ["--profile", "black", "--filter-files"]
- repo: https://github.com/ibm/detect-secrets
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this and the following lines be committed or were they only for debugging?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's intended. We want isort consistent with black and only on tracked files, hence --profile black --filter-files.

rev: 0.13.1+ibm.62.dss
hooks:
- id: detect-secrets
additional_dependencies: [ boxsdk<4 ]
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"files": "^.secrets.baseline$",
"lines": null
},
"generated_at": "2025-07-26T17:25:52Z",
"generated_at": "2026-01-08T19:07:43Z",
"plugins_used": [
{
"name": "AWSKeyDetector"
Expand Down Expand Up @@ -187,7 +187,7 @@
}
]
},
"version": "0.13.1+ibm.61.dss",
"version": "0.13.1+ibm.62.dss",
"word_list": {
"file": null,
"hash": null
Expand Down
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ pytest -m "vllm"
If you find ICX360 useful, please star the repository and cite our work as follows:
```
@misc{wei2025icx360,
title={{ICX360}: {In-Context eXplainability} 360 Toolkit},
title={{ICX360}: {In-Context eXplainability} 360 Toolkit},
author={Dennis Wei and Ronny Luss and Xiaomeng Hu and Lucas Monteiro Paes and Pin-Yu Chen and Karthikeyan Natesan Ramamurthy and Erik Miehling and Inge Vejsbjerg and Hendrik Strobelt},
year={2025},
eprint={2511.10879},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2511.10879},
url={https://arxiv.org/abs/2511.10879},
}
```

Expand All @@ -139,4 +139,3 @@ Let's form a community around this toolkit! Ask a question, raise an issue, or ex
## IBM ❤️ Open Source AI

The first release of ICX360 has been brought to you by IBM in the hope of building a larger community around this topic.

6 changes: 5 additions & 1 deletion icx360/algorithms/mexgen/mexgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from icx360.algorithms.lbbe import LocalBBExplainer
from icx360.utils.model_wrappers import GeneratedOutput, HFModel
from icx360.utils.scalarizers import ProbScalarizedModel, TextScalarizedModel
from icx360.utils.segmenters import SpaCySegmenter, exclude_non_alphanumeric, merge_non_alphanumeric
from icx360.utils.segmenters import (
SpaCySegmenter,
exclude_non_alphanumeric,
merge_non_alphanumeric,
)


class MExGenExplainer(LocalBBExplainer):
Expand Down
2 changes: 1 addition & 1 deletion icx360/utils/model_wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
"""

from .base_model_wrapper import GeneratedOutput, Model
from .huggingface import HFModel
from .huggingface import HFModel, PipelineHFModel
from .vllm import VLLMModel
106 changes: 106 additions & 0 deletions icx360/utils/model_wrappers/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,109 @@ def generate(self, inputs, chat_template=False, system_prompt=None, tokenizer_kw
output_obj = GeneratedOutput(output_ids=output_ids, output_text=output_text, output_token_count=output_token_count)

return output_obj


class PipelineHFModel(HFModel):
    """
    HFModel-compatible wrapper around a SteeringPipeline.

    Attributes:
        _pipeline (aisteer360.algorithms.core.SteeringPipeline):
            AISteer360 SteeringPipeline object.
        _model (transformers model object):
            Underlying model object.
        _tokenizer (transformers tokenizer):
            Tokenizer corresponding to model.
        _device (str):
            Device on which the model resides.
        _runtime_kwargs (dict or None):
            Optional per-call parameters for controls at runtime.
    """

    def __init__(self, pipeline, tokenizer, runtime_kwargs: dict | None = None):
        """
        Initialize PipelineHFModel wrapper.

        Args:
            pipeline (aisteer360.algorithms.core.SteeringPipeline):
                AISteer360 SteeringPipeline object.
            tokenizer (transformers tokenizer):
                Tokenizer corresponding to model.
            runtime_kwargs (dict or None):
                Optional per-call parameters for controls at runtime.
        """
        # Reuse HFModel's setup (sets _model, _tokenizer, _device) on the
        # pipeline's underlying model.
        super().__init__(pipeline.model, tokenizer)

        self._pipeline = pipeline
        self._runtime_kwargs = runtime_kwargs

    def generate(
        self,
        inputs,
        chat_template: bool = False,
        system_prompt: str | None = None,
        tokenizer_kwargs: dict | None = None,
        text_only: bool = True,
        **kwargs,
    ):
        """
        Generate response from SteeringPipeline.

        Args:
            inputs (str or List[str] or List[List[str]]):
                A single input text, a list of input texts, or a list of segmented texts.
            chat_template (bool):
                Whether to apply chat template.
            system_prompt (str or None):
                System prompt to include in chat template.
            tokenizer_kwargs (dict or None):
                Additional keyword arguments for tokenizer (default None, treated as {}).
            text_only (bool):
                Return only generated text (default) or an object containing additional outputs.
            **kwargs (dict):
                Additional keyword arguments for pipeline.

        Returns:
            output_obj (List[str] or icx360.utils.model_wrappers.GeneratedOutput):
                If text_only == True, a list of generated texts corresponding to inputs.
                If text_only == False, a GeneratedOutput object containing the following:
                    output_ids: (num_inputs, output_token_count) torch.Tensor of generated token IDs.
                    output_text: List of generated texts.
                    output_token_count: Maximum number of generated tokens.
        """
        # Avoid a mutable default argument: a shared module-level dict would
        # persist mutations across calls.
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        encoding = self.convert_input(
            inputs,
            chat_template=chat_template,
            system_prompt=system_prompt,
            **tokenizer_kwargs,
        )
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]

        with torch.no_grad():
            output_ids = self._pipeline.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                runtime_kwargs=self._runtime_kwargs,
                **kwargs,
            )
        # SteeringPipeline already truncates output tokens to generated tokens
        # only, so do not slice off the prompt length again here.

        output_text = self._tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        if text_only:
            return output_text
        else:
            return GeneratedOutput(
                output_ids=output_ids,
                output_text=output_text,
                output_token_count=output_ids.shape[1],
            )
63 changes: 61 additions & 2 deletions icx360/utils/scalarizers/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from icx360.utils.model_wrappers import HFModel, VLLMModel
from icx360.utils.model_wrappers import HFModel, PipelineHFModel, VLLMModel
from icx360.utils.scalarizers import Scalarizer
from icx360.utils.segmenters import find_unit_boundaries
from icx360.utils.toma import toma_get_probs
Expand Down Expand Up @@ -84,7 +84,9 @@ def scalarize_output(self, inputs=None, outputs=None, ref_input=None, ref_output

# Compute log probabilities of reference output tokens conditioned on inputs
# Also find token boundaries of units of the reference output
if isinstance(self.model, HFModel):
if isinstance(self.model, PipelineHFModel):
log_probs, boundaries = self._compute_log_probs_pipeline(inputs, ref_output, **kwargs)
elif isinstance(self.model, HFModel):
log_probs, boundaries = self._compute_log_probs_hf(inputs, ref_output, **kwargs)
elif isinstance(self.model, VLLMModel):
log_probs, boundaries = self._compute_log_probs_vllm(inputs, ref_output, **kwargs)
Expand Down Expand Up @@ -178,6 +180,63 @@ def _compute_log_probs_hf(self, inputs, ref_output, **kwargs):

return log_probs, boundaries

def _compute_log_probs_pipeline(self, inputs, ref_output, **kwargs):
    """
    Compute log probabilities of reference output tokens conditioned on inputs when self.model is a PipelineHFModel.

    Delegates to the underlying SteeringPipeline.compute_logprobs.

    Args:
        inputs (transformers.BatchEncoding):
            BatchEncoding of inputs produced by tokenizer.
        ref_output (icx360.utils.model_wrappers.GeneratedOutput):
            Reference output object containing a sequence of token IDs (ref_output.output_ids).
        **kwargs (dict):
            Additional keyword arguments for model.

    Returns:
        log_probs ((num_inputs, gen_length) torch.Tensor):
            Log probabilities of reference output tokens.
        boundaries (List[int]):
            Token boundaries of units of the reference output.

    Raises:
        TypeError: If self.model is not a PipelineHFModel.
    """
    if not isinstance(self.model, PipelineHFModel):
        raise TypeError("_compute_log_probs_pipeline requires a PipelineHFModel")

    pipeline_model = self.model  # icx360.utils.model_wrappers.PipelineHFModel
    pipeline = pipeline_model._pipeline  # aisteer360.algorithms.core.SteeringPipeline

    # inputs is a transformers.BatchEncoding from convert_input()
    input_ids = inputs["input_ids"]
    attention_mask = inputs.get("attention_mask", None)

    # Move reference output token IDs to the model's device.
    # Tensor.to() is a no-op when the tensor is already there, so no
    # explicit device comparison is needed.
    ref_output_ids = ref_output.output_ids.to(pipeline_model._device)

    with torch.no_grad():
        log_probs = pipeline.compute_logprobs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            ref_output_ids=ref_output_ids,
            runtime_kwargs=pipeline_model._runtime_kwargs,
            **kwargs,
        )

    # Decode reference output tokens, blanking out special tokens.
    # .tolist() yields plain ints (hashable), allowing O(1) membership
    # tests against the set of special token IDs; avoids shadowing the
    # builtin `id` as the loop variable.
    tokenizer = pipeline_model._tokenizer
    special_ids = set(tokenizer.all_special_ids)
    tokens = [
        "" if token_id in special_ids else tokenizer.decode(token_id)
        for token_id in ref_output_ids[0].tolist()
    ]
    # Find token boundaries of units of the reference output
    boundaries = find_unit_boundaries(ref_output.output_text[0], tokens)

    # log_probs must be shape: (num_inputs, gen_length)
    return log_probs, boundaries


def _compute_log_probs_vllm(self, inputs, ref_output, max_inputs_per_call=200, **kwargs):
"""
Compute log probabilities of reference output tokens conditioned on inputs for a VLLMModel.
Expand Down
6 changes: 5 additions & 1 deletion icx360/utils/segmenters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@
"""

from .spacy import SpaCySegmenter
from .utils import exclude_non_alphanumeric, find_unit_boundaries, merge_non_alphanumeric
from .utils import (
exclude_non_alphanumeric,
find_unit_boundaries,
merge_non_alphanumeric,
)