@@ -32,7 +32,7 @@ class CLIME(MExGenExplainer):
3232 based on the model's inputs or outputs.
3333 """
3434 def explain_instance (self , input_orig , unit_types = "p" , output_orig = None ,
35- ind_segment = True , segment_type = "s" , max_phrase_length = 10 ,
35+ ind_segment = True , segment_type = "s" , max_phrase_length = 10 , segment_type_output = None ,
3636 model_params = {}, scalarize_params = {},
3737 oversampling_factor = 10 , max_units_replace = 2 , empty_subset = True , replacement_str = "" ,
3838 num_nonzeros = None , debias = True ):
@@ -51,14 +51,18 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,
5151 "n" for not to be perturbed/attributed to.
5252 If str, applies to all units in input_orig, otherwise unit-specific.
5353 output_orig (str or List[str] or icx360.utils.model_wrappers.GeneratedOutput or None):
54- [output] Output for original input if provided, otherwise None.
54+ [output] Output for original input.
55+ Can be a single unit (str), segmented into units (List[str]), a GeneratedOutput object, or None.
5556 ind_segment (bool or List[bool]):
5657 [segmentation] Whether to segment input text.
5758 If bool, applies to all units; if List[bool], applies to each unit individually.
5859 segment_type (str):
5960 [segmentation] Type of units to segment into: "s" for sentences, "w" for words, "ph" for phrases.
6061 max_phrase_length (int):
6162 [segmentation] Maximum phrase length in terms of spaCy tokens (default 10).
63+ segment_type_output (str or None):
64+ [segmentation] Type of units to segment output text into:
65+ "s" for sentences, "ph" for phrases, None for no segmentation.
6266 model_params (dict):
6367 Additional keyword arguments for model generation (for the self.model.generate() method).
6468 scalarize_params (dict):
@@ -101,6 +105,8 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,
101105
102106 # 2) Generate output for original input or wrap provided output
103107 output_orig = self .generate_or_wrap_output (input_orig , output_orig , model_params )
108+ # Segment output text if needed
109+ output_orig = self .segment_output (output_orig , segment_type_output , max_phrase_length )
104110
105111 # 3) Enumerate subsets of units that will be perturbed/replaced
106112 idx_replace = (np .array (unit_types ) != "n" ).nonzero ()[0 ]
@@ -130,7 +136,7 @@ def explain_instance(self, input_orig, unit_types="p", output_orig=None,
130136 coef [key ], intercept [key ], num_nonzeros_out [key ] = fit_linear_model (features , target [key ].cpu ().numpy (), subset_weights , num_nonzeros , debias )
131137
132138 else :
133- # Single target vector
139+ # Single target array (could contain multiple columns)
134140 coef , intercept , num_nonzeros_out = fit_linear_model (features , target .cpu ().numpy (), subset_weights , num_nonzeros , debias )
135141
136142 # 8) Construct output dictionary
@@ -186,8 +192,8 @@ def fit_linear_model(features, target, sample_weights, num_nonzeros, debias):
186192 Args:
187193 features ((num_perturb, num_units) np.ndarray):
188194 Feature values.
189- target ((num_perturb,) np.ndarray):
190- Target values to predict.
195+ target ((num_perturb,) or (num_perturb, num_output_units) np.ndarray):
196+ Target values to predict (one column for each output unit) .
191197 sample_weights ((num_perturb,) np.ndarray):
192198 Sample weights.
193199 num_nonzeros (int or None):
@@ -196,51 +202,74 @@ def fit_linear_model(features, target, sample_weights, num_nonzeros, debias):
196202 Refit linear model with no penalty after selecting features.
197203
198204 Returns:
199- coef ((num_units,) np.ndarray):
200- Coefficients of linear model.
201- intercept (float):
202- Intercept of linear model.
203- num_nonzeros (int):
204- Actual number of non-zero coefficients.
205+ coef ((num_units,) or (num_units, num_output_units) np.ndarray):
206+ Coefficients of linear model(s) (one per output unit) .
207+ intercept (float or (num_output_units,) np.ndarray ):
208+ Intercept(s) of linear model(s) (one per output unit) .
209+ num_nonzeros (List[ int] ):
210+ Actual numbers of non-zero coefficients.
205211 """
206212 num_units = features .shape [1 ]
213+ # Promote target array to 2D if needed
214+ target = target [:, None ] if target .ndim == 1 else target
215+ num_output_units = target .shape [1 ]
207216
208217 if num_nonzeros is None :
209218 # Fit dense linear model over the units that were perturbed (`active`)
210219 active = features .any (axis = 0 ).nonzero ()[0 ]
211- coef = np .zeros (num_units )
220+ coef = np .zeros (( num_units , num_output_units ) )
212221 lr = LinearRegression ()
213222 lr .fit (features [:, active ], target , sample_weight = sample_weights )
214- coef [active ] = lr .coef_
223+ coef [active , : ] = lr .coef_ . T
215224 intercept = lr .intercept_
216225
217226 else :
218227 # Fit sparse linear model
219228
220229 # Center feature and target values
221230 features_mean = features .mean (axis = 0 )
222- target_mean = target .mean ()
231+ target_mean = target .mean (axis = 0 )
223232 features_centered = features - features_mean
224233 target_centered = target - target_mean
225234
226- # Call lars_path to obtain sparse linear model with num_nonzeros coefficients
227- # NOTE: may return fewer than num_nonzeros if coefficients leave the active set
228- alphas , active , coef = lars_path (np .sqrt (sample_weights )[:, None ] * features_centered , np .sqrt (sample_weights ) * target_centered , max_iter = num_nonzeros , method = "lasso" , return_path = False )
229-
230- if debias :
231- coef = np .zeros (num_units )
232- if len (active ):
233- # Refit linear model on selected features with no penalty
234- lr = LinearRegression ()
235- lr .fit (features [:, active ], target , sample_weight = sample_weights )
236- coef [active ] = lr .coef_
237- intercept = lr .intercept_
235+ # Initialize outputs
236+ coef = np .zeros ((num_units , num_output_units ))
237+ intercept = np .zeros (num_output_units )
238+ active = [None ] * num_output_units
239+
240+ # Iterate over output units
241+ for u in range (num_output_units ):
242+ # Call lars_path to obtain sparse linear model with num_nonzeros coefficients
243+ # NOTE: may return fewer than num_nonzeros if coefficients leave the active set
244+ alphas , active [u ], coef [:, u ] = lars_path (np .sqrt (sample_weights )[:, None ] * features_centered ,
245+ np .sqrt (sample_weights ) * target_centered [:, u ],
246+ max_iter = num_nonzeros ,
247+ method = "lasso" ,
248+ return_path = False )
249+
250+ if debias :
251+ coef [:, u ] = np .zeros (num_units )
252+ if len (active [u ]):
253+ # Refit linear model on selected features with no penalty
254+ lr = LinearRegression ()
255+ lr .fit (features [:, active [u ]], target [:, u ], sample_weight = sample_weights )
256+ coef [active [u ], u ] = lr .coef_
257+ intercept [u ] = lr .intercept_
258+ else :
259+ # No active set, coefficients all zero
260+ intercept [u ] = target_mean [u ]
238261 else :
239- # No active set, coefficients all zero
240- intercept = target_mean
241- else :
242- # Compute intercept to account for centering
243- intercept = target_mean - coef @ features_mean
244-
262+ # Compute intercept to account for centering
263+ intercept [u ] = target_mean [u ] - coef [:, u ] @ features_mean
264+
265+ if num_output_units == 1 :
266+ coef , intercept = coef .squeeze (axis = 1 ), intercept .squeeze ()
267+ # Actual number(s) of non-zero coefficients
268+ if isinstance (active [0 ], (int , np .integer )):
269+ # Single active set (single array/list of indices, e.g. from the dense branch) so number of non-zeros is same for all output units
270+ num_nonzeros = [len (active )] * num_output_units
271+ else :
272+ # Multiple active sets, one for each output unit; materialize to a list to match the documented List[int] return
273+ num_nonzeros = [len (a ) for a in active ]
245274 # Negate coefficients so that important units have positive coefficients
246- return - coef , intercept , len ( active )
275+ return - coef , intercept , num_nonzeros
0 commit comments