Commit 5bd1c76

Merge pull request #11 from IBM/mexgen
Handle outputs segmented into multiple units (with ProbScalarizedModel only)
2 parents c34e237 + 96cafa4

11 files changed: 925 additions & 756 deletions

File tree

examples/mexgen/RAG.ipynb: 98 additions & 129 deletions (large diff not rendered)
examples/mexgen/question_answering.ipynb: 80 additions & 80 deletions (large diff not rendered)
examples/mexgen/summarization.ipynb: 403 additions & 396 deletions (large diff not rendered)
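
The three notebooks above exercise the new output-segmentation option end to end. Below is a minimal usage sketch, assuming CLIME is importable from the module path in this diff and that `model` is an already-wrapped generator; the constructor arguments and model setup are not shown in this commit and are assumptions here:

    from icx360.algorithms.mexgen.clime import CLIME

    explainer = CLIME(model)  # `model` is a wrapped generator; constructor args assumed
    result = explainer.explain_instance(
        "The sky is blue. Water is wet.",  # input text to attribute over
        segment_type="s",                  # segment the input into sentences
        segment_type_output="s",           # new: segment the output into sentences too
    )
    # With a segmented output, importance scores gain one column per output sentence
    # (per the PR title, multi-unit outputs are supported with ProbScalarizedModel only).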

icx360/algorithms/mexgen/clime.py: 62 additions & 33 deletions
@@ -32,7 +32,7 @@ class CLIME(MExGenExplainer):
     based on the model's inputs or outputs.
     """
     def explain_instance(self, input_orig, unit_types="p", output_orig=None,
-                         ind_segment=True, segment_type="s", max_phrase_length=10,
+                         ind_segment=True, segment_type="s", max_phrase_length=10, segment_type_output=None,
                          model_params={}, scalarize_params={},
                          oversampling_factor=10, max_units_replace=2, empty_subset=True, replacement_str="",
                          num_nonzeros=None, debias=True):
@@ -51,14 +51,18 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,
                 "n" for not to be perturbed/attributed to.
                 If str, applies to all units in input_orig, otherwise unit-specific.
             output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
-                [output] Output for original input if provided, otherwise None.
+                [output] Output for original input.
+                Can be a single unit (str), segmented into units (List[str]), a GeneratedOutput object, or None.
             ind_segment (bool or List[bool]):
                 [segmentation] Whether to segment input text.
                 If bool, applies to all units; if List[bool], applies to each unit individually.
             segment_type (str):
                 [segmentation] Type of units to segment into: "s" for sentences, "w" for words, "ph" for phrases.
             max_phrase_length (int):
                 [segmentation] Maximum phrase length in terms of spaCy tokens (default 10).
+            segment_type_output (str or None):
+                [segmentation] Type of units to segment output text into:
+                "s" for sentences, "ph" for phrases, None for no segmentation.
             model_params (dict):
                 Additional keyword arguments for model generation (for the self.model.generate() method).
             scalarize_params (dict):
@@ -101,6 +105,8 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,

         # 2) Generate output for original input or wrap provided output
         output_orig = self.generate_or_wrap_output(input_orig, output_orig, model_params)
+        # Segment output text if needed
+        output_orig = self.segment_output(output_orig, segment_type_output, max_phrase_length)

         # 3) Enumerate subsets of units that will be perturbed/replaced
         idx_replace = (np.array(unit_types) != "n").nonzero()[0]
@@ -130,7 +136,7 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,
                 coef[key], intercept[key], num_nonzeros_out[key] = fit_linear_model(features, target[key].cpu().numpy(), subset_weights, num_nonzeros, debias)

         else:
-            # Single target vector
+            # Single target array (could contain multiple columns)
             coef, intercept, num_nonzeros_out = fit_linear_model(features, target.cpu().numpy(), subset_weights, num_nonzeros, debias)

         # 8) Construct output dictionary
@@ -186,8 +192,8 @@ def fit_linear_model(features, target, sample_weights, num_nonzeros, debias):
     Args:
         features ((num_perturb, num_units) np.ndarray):
             Feature values.
-        target ((num_perturb,) np.ndarray):
-            Target values to predict.
+        target ((num_perturb,) or (num_perturb, num_output_units) np.ndarray):
+            Target values to predict (one column for each output unit).
         sample_weights ((num_perturb,) np.ndarray):
             Sample weights.
         num_nonzeros (int or None):
@@ -196,51 +202,74 @@ def fit_linear_model(features, target, sample_weights, num_nonzeros, debias):
             Refit linear model with no penalty after selecting features.

     Returns:
-        coef ((num_units,) np.ndarray):
-            Coefficients of linear model.
-        intercept (float):
-            Intercept of linear model.
-        num_nonzeros (int):
-            Actual number of non-zero coefficients.
+        coef ((num_units,) or (num_units, num_output_units) np.ndarray):
+            Coefficients of linear model(s) (one per output unit).
+        intercept (float or (num_output_units,) np.ndarray):
+            Intercept(s) of linear model(s) (one per output unit).
+        num_nonzeros (List[int]):
+            Actual numbers of non-zero coefficients.
     """
     num_units = features.shape[1]
+    # Promote target array to 2D if needed
+    target = target[:, None] if target.ndim == 1 else target
+    num_output_units = target.shape[1]

     if num_nonzeros is None:
         # Fit dense linear model over the units that were perturbed (`active`)
         active = features.any(axis=0).nonzero()[0]
-        coef = np.zeros(num_units)
+        coef = np.zeros((num_units, num_output_units))
         lr = LinearRegression()
         lr.fit(features[:, active], target, sample_weight=sample_weights)
-        coef[active] = lr.coef_
+        coef[active, :] = lr.coef_.T
         intercept = lr.intercept_

     else:
         # Fit sparse linear model

         # Center feature and target values
         features_mean = features.mean(axis=0)
-        target_mean = target.mean()
+        target_mean = target.mean(axis=0)
         features_centered = features - features_mean
         target_centered = target - target_mean

-        # Call lars_path to obtain sparse linear model with num_nonzeros coefficients
-        # NOTE: may return fewer than num_nonzeros if coefficients leave the active set
-        alphas, active, coef = lars_path(np.sqrt(sample_weights)[:, None] * features_centered, np.sqrt(sample_weights) * target_centered, max_iter=num_nonzeros, method="lasso", return_path=False)
-
-        if debias:
-            coef = np.zeros(num_units)
-            if len(active):
-                # Refit linear model on selected features with no penalty
-                lr = LinearRegression()
-                lr.fit(features[:, active], target, sample_weight=sample_weights)
-                coef[active] = lr.coef_
-                intercept = lr.intercept_
+        # Initialize outputs
+        coef = np.zeros((num_units, num_output_units))
+        intercept = np.zeros(num_output_units)
+        active = [None] * num_output_units
+
+        # Iterate over output units
+        for u in range(num_output_units):
+            # Call lars_path to obtain sparse linear model with num_nonzeros coefficients
+            # NOTE: may return fewer than num_nonzeros if coefficients leave the active set
+            alphas, active[u], coef[:, u] = lars_path(np.sqrt(sample_weights)[:, None] * features_centered,
+                                                      np.sqrt(sample_weights) * target_centered[:, u],
+                                                      max_iter=num_nonzeros,
+                                                      method="lasso",
+                                                      return_path=False)
+
+            if debias:
+                coef[:, u] = np.zeros(num_units)
+                if len(active[u]):
+                    # Refit linear model on selected features with no penalty
+                    lr = LinearRegression()
+                    lr.fit(features[:, active[u]], target[:, u], sample_weight=sample_weights)
+                    coef[active[u], u] = lr.coef_
+                    intercept[u] = lr.intercept_
+                else:
+                    # No active set, coefficients all zero
+                    intercept[u] = target_mean[u]
             else:
-                # No active set, coefficients all zero
-                intercept = target_mean
-        else:
-            # Compute intercept to account for centering
-            intercept = target_mean - coef @ features_mean
-
+                # Compute intercept to account for centering
+                intercept[u] = target_mean[u] - coef[:, u] @ features_mean
+
+    if num_output_units == 1:
+        coef, intercept = coef.squeeze(axis=1), intercept.squeeze()
+    # Actual number(s) of non-zero coefficients
+    if type(active[0]) is int:
+        # Single active set (single list of indices) so number of non-zeros is same for all output units
+        num_nonzeros = [len(active)] * num_output_units
+    else:
+        # Multiple active sets, one for each output unit
+        num_nonzeros = map(len, active)
     # Negate coefficients so that important units have positive coefficients
-    return -coef, intercept, len(active)
+    return -coef, intercept, num_nonzeros
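
For reference, the per-output-unit sparse fit that `fit_linear_model` now performs can be reproduced in isolation with scikit-learn. A minimal sketch on synthetic data; the shapes, centering, and debiasing refit mirror the diff, while the data and variable names are illustrative:

    import numpy as np
    from sklearn.linear_model import LinearRegression, lars_path

    rng = np.random.default_rng(0)
    X = rng.integers(0, 2, size=(40, 6)).astype(float)  # binary perturbation features
    Y = X @ rng.normal(size=(6, 3))                     # 3 output units (columns)
    w = np.ones(40)                                     # sample weights

    Xc, Yc = X - X.mean(axis=0), Y - Y.mean(axis=0)     # center, as in the diff
    coef = np.zeros((6, 3))
    for u in range(Y.shape[1]):                         # lars_path expects a 1-D target,
        _, active, coef[:, u] = lars_path(              # hence one call per output unit
            np.sqrt(w)[:, None] * Xc, np.sqrt(w) * Yc[:, u],
            max_iter=2, method="lasso", return_path=False)
        if len(active):                                 # debias: unpenalized refit
            lr = LinearRegression().fit(X[:, active], Y[:, u], sample_weight=w)
            coef[active, u] = lr.coef_

    print(coef)  # one sparse coefficient column per output unit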

icx360/algorithms/mexgen/lshap.py: 16 additions & 7 deletions
@@ -32,7 +32,7 @@ class LSHAP(MExGenExplainer):
     based on the model's inputs or outputs.
     """
     def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output_orig=None,
-                         ind_segment=True, segment_type="s", max_phrase_length=10,
+                         ind_segment=True, segment_type="s", max_phrase_length=10, segment_type_output=None,
                          model_params={}, scalarize_params={},
                          num_neighbors=2, max_units_replace=2, replacement_str=""):
         """
@@ -53,14 +53,18 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output
                 [input] Indicator of units to attribute to ("of interest").
                 Default None means np.array(unit_types) != "n".
             output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
-                [output] Output for original input if provided, otherwise None.
+                [output] Output for original input.
+                Can be a single unit (str), segmented into units (List[str]), a GeneratedOutput object, or None.
             ind_segment (bool or List[bool]):
                 [segmentation] Whether to segment input text.
                 If bool, applies to all units; if List[bool], applies to each unit individually.
             segment_type (str):
                 [segmentation] Type of units to segment into: "s" for sentences, "w" for words, "ph" for phrases.
             max_phrase_length (int):
                 [segmentation] Maximum phrase length in terms of spaCy tokens (default 10).
+            segment_type_output (str or None):
+                [segmentation] Type of units to segment output text into:
+                "s" for sentences, "ph" for phrases, None for no segmentation.
             model_params (dict):
                 Additional keyword arguments for model generation (for the self.model.generate() method).
             scalarize_params (dict):
@@ -106,6 +110,9 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output

         # 2) Generate output for original input or wrap provided output
         output_orig = self.generate_or_wrap_output(input_orig, output_orig, model_params)
+        # Segment output text if needed
+        output_orig = self.segment_output(output_orig, segment_type_output, max_phrase_length)
+        num_output_units = 1 if type(output_orig.output_text[0]) is str else len(output_orig.output_text[0])

         # 3) Initialize quantities
         # Initialize importance scores
@@ -115,7 +122,7 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output
             for key in self.scalarized_model.sim_scores:
                 importance_scores[key] = np.zeros(num_units)
         else:
-            importance_scores = np.zeros(num_units)
+            importance_scores = np.zeros((num_units, num_output_units))

         # Initialize quantities associated with units of interest
         idx_replace_i = [None] * len(idx_interest)
@@ -187,21 +194,23 @@ def explain_instance(self, input_orig, unit_types="p", ind_interest=None, output
                    importance_scores[key][idx_interest[i]] = np.inner(scalar_outputs_excl_interest - scalar_outputs_incl_interest, 1 / normalization)

            else:
-                # Extract scalarized output corresponding to original input/empty subset
-                scalar_output_orig = scalar_outputs[0].item()
+                # Extract scalarized output(s) corresponding to original input/empty subset
+                scalar_output_orig = scalar_outputs[[0]].cpu().numpy()
                # Extract scalarized outputs for this unit of interest
                scalar_outputs_excl_interest = scalar_outputs[idx_excl_interest].cpu().numpy()
                scalar_outputs_incl_interest = scalar_outputs[idx_incl_interest].cpu().numpy()
                # Prepend output corresponding to empty subset
-                scalar_outputs_excl_interest = np.append(scalar_output_orig, scalar_outputs_excl_interest)
+                scalar_outputs_excl_interest = np.append(scalar_output_orig, scalar_outputs_excl_interest, axis=0)

                # 9) Compute Shapley values
                normalization = get_normalization_constants(len(idx_replace_i[i]), max_units_replace) * (max_units_replace + 1)
-                importance_scores[idx_interest[i]] = np.inner(scalar_outputs_excl_interest - scalar_outputs_incl_interest, 1 / normalization)
+                importance_scores[idx_interest[i]] = np.dot(1 / normalization, scalar_outputs_excl_interest - scalar_outputs_incl_interest)

        # 10) Construct output dictionary
        if type(importance_scores) is not dict:
            # Convert importance_scores to dictionary
+            if num_output_units == 1:
+                importance_scores = importance_scores.squeeze(axis=1)
            if isinstance(self.scalarized_model, ProbScalarizedModel):
                # Label scores with type of scalarizer
                importance_scores = {"prob": importance_scores}

icx360/algorithms/mexgen/mexgen.py: 32 additions & 6 deletions
@@ -12,7 +12,7 @@
 from icx360.algorithms.lbbe import LocalBBExplainer
 from icx360.utils.model_wrappers import GeneratedOutput, HFModel
 from icx360.utils.scalarizers import ProbScalarizedModel, TextScalarizedModel
-from icx360.utils.segmenters import SpaCySegmenter, exclude_non_alphanumeric
+from icx360.utils.segmenters import SpaCySegmenter, exclude_non_alphanumeric, merge_non_alphanumeric


 class MExGenExplainer(LocalBBExplainer):
@@ -115,7 +115,8 @@ def generate_or_wrap_output(self, input_orig, output_orig=None, model_params={})
             input_orig (List[str]):
                 Original input segmented into units.
             output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
-                Output for original input if provided, otherwise None.
+                Output for original input.
+                Can be a single unit (str), segmented into units (List[str]), a GeneratedOutput object, or None.
             model_params (dict):
                 Additional keyword arguments for model generation (for the self.model.generate() method).

@@ -130,11 +131,8 @@ def generate_or_wrap_output(self, input_orig, output_orig=None, model_params={})
             # Generate output for original input
             output_orig = self.model.generate([input_orig], text_only=False, **model_params)
         elif type(output_orig) in (str, list):
-            if type(output_orig) is str:
-                output_orig = [output_orig]
-
             # Wrap output text in a GeneratedOutput object
-            output_orig = GeneratedOutput(output_text=output_orig)
+            output_orig = GeneratedOutput(output_text=[output_orig])

             if isinstance(self.model, HFModel):
                 # Also include output token IDs for HFModel
@@ -145,3 +143,31 @@ def generate_or_wrap_output(self, input_orig, output_orig=None, model_params={})
             raise TypeError("output_orig must be a str, List[str], GeneratedOutput, or None.")

         return output_orig
+
+    def segment_output(self, output_orig, segment_type_output=None, max_phrase_length=10):
+        """
+        Segment output text (if needed).
+
+        Args:
+            output_orig (icx360.utils.model_wrappers.GeneratedOutput):
+                Object containing output for original input, in particular output text (output_orig.output_text).
+            segment_type_output (str or None):
+                Type of units to segment into: "s" for sentences, "ph" for phrases, None for no segmentation.
+            max_phrase_length (int):
+                Maximum phrase length in terms of spaCy tokens (default 10).
+
+        Returns:
+            output_orig (icx360.utils.model_wrappers.GeneratedOutput):
+                Output object with possibly segmented text.
+        """
+        if type(output_orig.output_text[0]) is str and segment_type_output is not None:
+            # Output text not already segmented and segmentation requested, call segmenter
+            output_orig.output_text[0], _, _ = self.segmenter.segment_units(output_orig.output_text[0],
+                                                                            unit_types="p",
+                                                                            segment_type=segment_type_output,
+                                                                            max_phrase_length=max_phrase_length)
+
+            # Merge non-alphanumeric units into adjacent units
+            output_orig.output_text[0] = merge_non_alphanumeric(output_orig.output_text[0])
+
+        return output_orig
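
Under the revised wrapping convention, `output_text[0]` is either a `str` (a single output unit) or a `List[str]` (already segmented), which is how the explainers detect `num_output_units`. A quick illustration; the real `GeneratedOutput` lives in `icx360.utils.model_wrappers`, so a minimal stand-in is defined here to keep the snippet self-contained:

    class GeneratedOutput:  # minimal stand-in for illustration only
        def __init__(self, output_text=None):
            self.output_text = output_text

    # Unsegmented output: output_text[0] is a str (a single unit)
    single = GeneratedOutput(output_text=["The sky is blue. Water is wet."])
    # Pre-segmented output: output_text[0] is a List[str] (multiple units)
    multi = GeneratedOutput(output_text=[["The sky is blue.", "Water is wet."]])

    for out in (single, multi):
        # Mirrors how lshap.py derives num_output_units after segment_output()
        n = 1 if type(out.output_text[0]) is str else len(out.output_text[0])
        print(n)  # prints 1, then 2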
