Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
345 changes: 216 additions & 129 deletions examples/mexgen/RAG.ipynb

Large diffs are not rendered by default.

271 changes: 166 additions & 105 deletions examples/mexgen/question_answering.ipynb

Large diffs are not rendered by default.

1,045 changes: 524 additions & 521 deletions examples/mexgen/summarization.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions icx360/algorithms/mexgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
Module containing submodules for the MExGen base explainer and the MExGen C-LIME and MExGen L-SHAP explainers
"""

from .mexgen import MExGenExplainer
from .clime import CLIME
from .lshap import LSHAP
64 changes: 12 additions & 52 deletions icx360/algorithms/mexgen/clime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
import numpy as np
from sklearn.linear_model import LinearRegression, lars_path

from icx360.algorithms.lbbe import LocalBBExplainer
from icx360.algorithms.mexgen import MExGenExplainer
from icx360.utils.scalarizers import ProbScalarizedModel, TextScalarizedModel
from icx360.utils.segmenters import SpaCySegmenter, exclude_non_alphanumeric
from icx360.utils.subset_utils import mask_subsets, sample_subsets


class CLIME(LocalBBExplainer):
class CLIME(MExGenExplainer):
"""
MExGen C-LIME explainer

Expand All @@ -31,43 +31,11 @@ class CLIME(LocalBBExplainer):
"Scalarized model" that further wraps `model` with a method for computing scalar values
based on the model's inputs or outputs.
"""
def __init__(self, model, segmenter="en_core_web_trf", scalarizer="prob", **kwargs):
"""
Initialize MExGen C-LIME explainer.

Args:
model (icx360.utils.model_wrappers.Model):
Model to explain, wrapped in an icx360.utils.model_wrappers.Model object.
segmenter (str):
Name of spaCy model to use in segmenter (icx360.utils.segmenters.SpaCySegmenter).
scalarizer (str):
Type of scalarizer to use.
"prob": probability of generating original output conditioned on perturbed inputs
(instantiates an icx360.utils.scalarizers.ProbScalarizedModel).
"text": similarity scores between original output and perturbed outputs
(instantiates an icx360.utils.scalarizers.TextScalarizedModel).
**kwargs (dict):
Additional keyword arguments for initializing scalarizer.

Raises:
ValueError: If `scalarizer` is not "prob" or "text".
"""
self.model = model

# Instantiate segmenter
self.segmenter = SpaCySegmenter(segmenter)

# Instantiate scalarized model
if scalarizer == "prob":
self.scalarized_model = ProbScalarizedModel(model)
elif scalarizer == "text":
self.scalarized_model = TextScalarizedModel(model, **kwargs)
else:
raise ValueError("Scalarizer not supported")

def explain_instance(self, input_orig, unit_types="p", ind_segment=True, segment_type="s", max_phrase_length=10,
model_params={}, scalarize_params={}, oversampling_factor=10, max_units_replace=2,
empty_subset=True, replacement_str="", num_nonzeros=None, debias=True):
def explain_instance(self, input_orig, unit_types="p", output_orig=None,
ind_segment=True, segment_type="s", max_phrase_length=10,
model_params={}, scalarize_params={},
oversampling_factor=10, max_units_replace=2, empty_subset=True, replacement_str="",
num_nonzeros=None, debias=True):
"""
Explain model output by attributing it to parts of the input text.

Expand All @@ -82,6 +50,8 @@ def explain_instance(self, input_orig, unit_types="p", ind_segment=True, segment
"p" for paragraph, "s" for sentence, "w" for word,
"n" for not to be perturbed/attributed to.
If str, applies to all units in input_orig, otherwise unit-specific.
output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
[output] Output for original input if provided, otherwise None.
ind_segment (bool or List[bool]):
[segmentation] Whether to segment input text.
If bool, applies to all units; if List[bool], applies to each unit individually.
Expand Down Expand Up @@ -126,21 +96,11 @@ def explain_instance(self, input_orig, unit_types="p", ind_segment=True, segment
One or more sets of attribution scores (labelled by the type of scalarizer).
"""
# 1) Segment input text if needed
if type(ind_segment) is bool:
ind_segment = [ind_segment]
if type(input_orig) is str or any(ind_segment):
# Call segmenter
input_orig, unit_types, _ = self.segmenter.segment_units(input_orig, ind_segment, unit_types, segment_type=segment_type, max_phrase_length=max_phrase_length)
# Exclude units without alphanumeric characters from perturbation
unit_types = exclude_non_alphanumeric(unit_types, input_orig)
input_orig, unit_types = self.segment_input(input_orig, unit_types, ind_segment, segment_type, max_phrase_length)
num_units = len(input_orig)

if type(unit_types) is str:
# Expand to list
unit_types = [unit_types] * num_units

# 2) Generate output for original input
output_orig = self.model.generate([input_orig], text_only=False, **model_params)
# 2) Generate output for original input or wrap provided output
output_orig = self.generate_or_wrap_output(input_orig, output_orig, model_params)

# 3) Enumerate subsets of units that will be perturbed/replaced
idx_replace = (np.array(unit_types) != "n").nonzero()[0]
Expand Down
62 changes: 10 additions & 52 deletions icx360/algorithms/mexgen/lshap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@

import numpy as np

from icx360.algorithms.lbbe import LocalBBExplainer
from icx360.algorithms.mexgen import MExGenExplainer
from icx360.utils.scalarizers import ProbScalarizedModel, TextScalarizedModel
from icx360.utils.segmenters import SpaCySegmenter, exclude_non_alphanumeric
from icx360.utils.subset_utils import mask_subsets, sample_subsets


class LSHAP(LocalBBExplainer):
class LSHAP(MExGenExplainer):
"""
MExGen L-SHAP explainer

Expand All @@ -32,42 +31,9 @@ class LSHAP(LocalBBExplainer):
"Scalarized model" that further wraps `model` with a method for computing scalar values
based on the model's inputs or outputs.
"""
def __init__(self, model, segmenter="en_core_web_trf", scalarizer="prob", **kwargs):
"""
Initialize MExGen L-SHAP explainer.

Args:
model (icx360.utils.model_wrappers.Model):
Model to explain, wrapped in an icx360.utils.model_wrappers.Model object.
segmenter (str):
Name of spaCy model to use in segmenter (icx360.utils.segmenters.SpaCySegmenter).
scalarizer (str):
Type of scalarizer to use.
"prob": probability of generating original output conditioned on perturbed inputs
(instantiates an icx360.utils.scalarizers.ProbScalarizedModel).
"text": similarity scores between original output and perturbed outputs
(instantiates an icx360.utils.scalarizers.TextScalarizedModel).
**kwargs (dict):
Additional keyword arguments for initializing scalarizer.

Raises:
ValueError: If `scalarizer` is not "prob" or "text".
"""
self.model = model

# Instantiate segmenter
self.segmenter = SpaCySegmenter(segmenter)

# Instantiate scalarized model
if scalarizer == "prob":
self.scalarized_model = ProbScalarizedModel(model)
elif scalarizer == "text":
self.scalarized_model = TextScalarizedModel(model, **kwargs)
else:
raise ValueError("Scalarizer not supported")

def explain_instance(self, input_orig, unit_types="p", ind_interest=None, ind_segment=True, segment_type="s",
max_phrase_length=10, model_params={}, scalarize_params={},
def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output_orig=None,
ind_segment=True, segment_type="s", max_phrase_length=10,
model_params={}, scalarize_params={},
num_neighbors=2, max_units_replace=2, replacement_str=""):
"""
Explain model output by attributing it to parts of the input text.
Expand All @@ -86,6 +52,8 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, ind_se
ind_interest (bool or List[bool] or None):
[input] Indicator of units to attribute to ("of interest").
Default None means np.array(unit_types) != "n".
output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
[output] Output for original input if provided, otherwise None.
ind_segment (bool or List[bool]):
[segmentation] Whether to segment input text.
If bool, applies to all units; if List[bool], applies to each unit individually.
Expand Down Expand Up @@ -123,19 +91,9 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, ind_se
One or more sets of attribution scores (labelled by the type of scalarizer).
"""
# 1) Segment input text if needed
if type(ind_segment) is bool:
ind_segment = [ind_segment]
if type(input_orig) is str or any(ind_segment):
# Call segmenter
input_orig, unit_types, _ = self.segmenter.segment_units(input_orig, ind_segment, unit_types, segment_type=segment_type, max_phrase_length=max_phrase_length)
# Exclude units without alphanumeric characters from perturbation
unit_types = exclude_non_alphanumeric(unit_types, input_orig)
input_orig, unit_types = self.segment_input(input_orig, unit_types, ind_segment, segment_type, max_phrase_length)
num_units = len(input_orig)

# Expand to list if needed
if type(unit_types) is str:
unit_types = [unit_types] * num_units

if ind_interest is None:
# Default is to attribute to all units that can be perturbed
ind_interest = np.array(unit_types) != "n"
Expand All @@ -146,8 +104,8 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, ind_se
idx_interest = ind_interest.nonzero()[0]
idx_replace = (np.array(unit_types) != "n").nonzero()[0]

# 2) Generate output for original input
output_orig = self.model.generate([input_orig], text_only=False, **model_params)
# 2) Generate output for original input or wrap provided output
output_orig = self.generate_or_wrap_output(input_orig, output_orig, model_params)

# 3) Initialize quantities
# Initialize importance scores
Expand Down
147 changes: 147 additions & 0 deletions icx360/algorithms/mexgen/mexgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
Base class for MExGen explainers.

The MExGen framework is described in:
Multi-Level Explanations for Generative Language Models.
Lucas Monteiro Paes and Dennis Wei et al.
The 63rd Annual Meeting of the Association for Computational Linguistics (ACL 2025).
https://arxiv.org/abs/2403.14459
"""
# Assisted by watsonx Code Assistant in formatting and augmenting docstrings.

from icx360.algorithms.lbbe import LocalBBExplainer
from icx360.utils.model_wrappers import GeneratedOutput, HFModel
from icx360.utils.scalarizers import ProbScalarizedModel, TextScalarizedModel
from icx360.utils.segmenters import SpaCySegmenter, exclude_non_alphanumeric


class MExGenExplainer(LocalBBExplainer):
    """
    Base class for MExGen explainers.

    Holds the state shared by the MExGen explainers (wrapped model, segmenter,
    scalarized model) and provides the two preprocessing steps they have in common:
    segmenting the input (`segment_input`) and obtaining the output for the
    original input (`generate_or_wrap_output`).

    Attributes:
        model (icx360.utils.model_wrappers.Model):
            Model to explain, wrapped in an icx360.utils.model_wrappers.Model object.
        segmenter (icx360.utils.segmenters.SpaCySegmenter):
            Object for segmenting input text into units using a spaCy model.
        scalarized_model (icx360.utils.scalarizers.Scalarizer):
            "Scalarized model" that further wraps `model` with a method for computing scalar values
            based on the model's inputs or outputs.
    """
    def __init__(self, model, segmenter="en_core_web_trf", scalarizer="prob", **kwargs):
        """
        Initialize MExGen explainer.

        Args:
            model (icx360.utils.model_wrappers.Model):
                Model to explain, wrapped in an icx360.utils.model_wrappers.Model object.
            segmenter (str):
                Name of spaCy model to use in segmenter (icx360.utils.segmenters.SpaCySegmenter).
            scalarizer (str):
                Type of scalarizer to use.
                    "prob": probability of generating original output conditioned on perturbed inputs
                        (instantiates an icx360.utils.scalarizers.ProbScalarizedModel).
                    "text": similarity scores between original output and perturbed outputs
                        (instantiates an icx360.utils.scalarizers.TextScalarizedModel).
            **kwargs (dict):
                Additional keyword arguments for initializing scalarizer.
                Only forwarded when `scalarizer` is "text"; ignored for "prob".

        Raises:
            ValueError: If `scalarizer` is not "prob" or "text".
        """
        self.model = model

        # Instantiate segmenter
        self.segmenter = SpaCySegmenter(segmenter)

        # Instantiate scalarized model
        if scalarizer == "prob":
            self.scalarized_model = ProbScalarizedModel(model)
        elif scalarizer == "text":
            self.scalarized_model = TextScalarizedModel(model, **kwargs)
        else:
            raise ValueError("Scalarizer not supported")

    def segment_input(self, input_orig, unit_types="p", ind_segment=True, segment_type="s", max_phrase_length=10):
        """
        Segment input text (if needed).

        The segmenter is called when `input_orig` is a single string or when any
        element of `ind_segment` is true; otherwise `input_orig` is returned as given.
        In both cases `unit_types` is returned as a list with one entry per unit.

        Args:
            input_orig (str or List[str]):
                Input text as a single unit (if str) or segmented sequence of units (List[str]).
            unit_types (str or List[str]):
                Types of units in input_orig.
                    "p" for paragraph, "s" for sentence, "w" for word,
                    "n" for not to be perturbed/attributed to.
                If str, applies to all units in input_orig, otherwise unit-specific.
            ind_segment (bool or List[bool]):
                Whether to segment input text.
                If bool, applies to all units; if List[bool], applies to each unit individually.
            segment_type (str):
                Type of units to segment into: "s" for sentences, "w" for words, "ph" for phrases.
            max_phrase_length (int):
                Maximum phrase length in terms of spaCy tokens (default 10).

        Returns:
            input_orig (List[str]):
                Segmented input text.
            unit_types (List[str]):
                Updated types of units.
        """
        # Convert ind_segment to list if needed
        if isinstance(ind_segment, bool):
            ind_segment = [ind_segment]

        # Segment input text if needed.
        # isinstance (rather than an exact type match) so that str subclasses are
        # treated as a single unit of text instead of as a sequence of units.
        if isinstance(input_orig, str) or any(ind_segment):
            # Call segmenter
            input_orig, unit_types, _ = self.segmenter.segment_units(input_orig, ind_segment, unit_types,
                                                                     segment_type=segment_type,
                                                                     max_phrase_length=max_phrase_length)
            # Exclude units without alphanumeric characters from perturbation
            unit_types = exclude_non_alphanumeric(unit_types, input_orig)

        # Expand a single unit type to one entry per unit
        if isinstance(unit_types, str):
            unit_types = [unit_types] * len(input_orig)

        return input_orig, unit_types

    def generate_or_wrap_output(self, input_orig, output_orig=None, model_params=None):
        """
        Generate output for original input or wrap provided output in a GeneratedOutput object.

        Args:
            input_orig (List[str]):
                Original input segmented into units.
            output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
                Output for original input if provided, otherwise None.
            model_params (dict or None):
                Additional keyword arguments for model generation (for the self.model.generate() method).
                None (the default) means no additional arguments.
                Only used when `output_orig` is None.

        Returns:
            output_orig (icx360.utils.model_wrappers.GeneratedOutput):
                Object containing output for original input.

        Raises:
            TypeError: If `output_orig` is not str, List[str], GeneratedOutput, or None.
        """
        if output_orig is None:
            # Generate output for original input
            generate_kwargs = {} if model_params is None else model_params
            return self.model.generate([input_orig], text_only=False, **generate_kwargs)

        if isinstance(output_orig, GeneratedOutput):
            # Already wrapped, nothing to do
            return output_orig

        if isinstance(output_orig, str):
            output_orig = [output_orig]
        elif not isinstance(output_orig, list):
            raise TypeError("output_orig must be a str, List[str], GeneratedOutput, or None.")

        # Wrap output text in a GeneratedOutput object
        output_orig = GeneratedOutput(output_text=output_orig)

        if isinstance(self.model, HFModel):
            # Also include output token IDs for HFModel
            output_orig.output_ids = self.model.convert_input(output_orig.output_text)["input_ids"]
            output_orig.output_token_count = output_orig.output_ids.shape[1]

        return output_orig