UFuncTypeError when explaining lightgbm model #314

Open
CharlesCousyn opened this issue Feb 5, 2025 · 4 comments

Comments

@CharlesCousyn
Contributor

I have a LightGBM classifier that works well.
But when I execute:

import shapiq

explainer = shapiq.TreeExplainer(
    model=model.named_steps["classifier"],  # my LGBMClassifier inside an sklearn Pipeline
    max_order=1,
    min_order=1,
    index="k-SII",
    class_index=0,
)
interaction_values = explainer.explain(x[42])
print(interaction_values)

I got this error:

---------------------------------------------------------------------------
UFuncTypeError                            Traceback (most recent call last)
Cell In[16], line 1
----> 1 interaction_values = explainer.explain(x[42])
      2 print(interaction_values)

File ~/Code/isa-ml/isa-ml-jobs/.venv/lib/python3.10/site-packages/shapiq/explainer/_base.py:99, in Explainer.explain(self, x, *args, **kwargs)
     88 def explain(self, x: np.ndarray, *args, **kwargs) -> InteractionValues:
     89     """Explain a single prediction in terms of interaction values.
     90 
     91     Args:
   (...)
     97         The interaction values of the prediction.
     98     """
---> 99     explanation = self.explain_function(x=x, *args, **kwargs)
    100     if explanation.min_order == 0:
    101         explanation[()] = explanation.baseline_value

File ~/Code/isa-ml/isa-ml-jobs/.venv/lib/python3.10/site-packages/shapiq/explainer/tree/explainer.py:93, in TreeExplainer.explain_function(self, x, **kwargs)
     91 interaction_values: list[InteractionValues] = []
     92 for explainer in self._treeshapiq_explainers:
---> 93     tree_explanation = explainer.explain(x)
     94     interaction_values.append(tree_explanation)
     96 # combine the explanations for all trees

File ~/Code/isa-ml/isa-ml-jobs/.venv/lib/python3.10/site-packages/shapiq/explainer/tree/treeshapiq.py:151, in TreeSHAPIQ.explain(self, x)
    147 self.shapley_interactions = np.zeros(
    148     int(sp.special.binom(self._n_features_in_tree, order)), dtype=float
    149 )
    150 self._prepare_variables_for_order(interaction_order=order)
--> 151 self._compute_shapley_interaction_values(x_relevant, order=order, node_id=0)
    152 # append the computed Shapley Interaction values to the result
    153 interactions = np.append(interactions, self.shapley_interactions.copy())

File ~/Code/isa-ml/isa-ml-jobs/.venv/lib/python3.10/site-packages/shapiq/explainer/tree/treeshapiq.py:281, in TreeSHAPIQ._compute_shapley_interaction_values(self, x, order, node_id, summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down, depth)
    276     summary_poly_up[depth] = (
    277         summary_poly_down[depth] * self._edge_tree.empty_predictions[node_id]
    278     )
    279 else:  # not a leaf -> continue recursion
    280     # left child
--> 281     self._compute_shapley_interaction_values(
    282         x,
    283         order=order,
    284         node_id=left_child,
    285         summary_poly_down=summary_poly_down,
    286         summary_poly_up=summary_poly_up,
    287         interaction_poly_down=interaction_poly_down,
    288         quotient_poly_down=quotient_poly_down,
    289         depth=depth + 1,
    290     )
    291     summary_poly_up[depth] = (
    292         summary_poly_up[depth + 1] * self.D_powers[current_height - left_height]
    293     )
    294     # right child

File ~/Code/isa-ml/isa-ml-jobs/.venv/lib/python3.10/site-packages/shapiq/explainer/tree/treeshapiq.py:295, in TreeSHAPIQ._compute_shapley_interaction_values(self, x, order, node_id, summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down, depth)
    291     summary_poly_up[depth] = (
    292         summary_poly_up[depth + 1] * self.D_powers[current_height - left_height]
    293     )
    294     # right child
--> 295     self._compute_shapley_interaction_values(
    296         x,
    297         order=order,
    298         node_id=right_child,
    299         summary_poly_down=summary_poly_down,
    300         summary_poly_up=summary_poly_up,
    301         interaction_poly_down=interaction_poly_down,
    302         quotient_poly_down=quotient_poly_down,
    303         depth=depth + 1,
    304     )
    305     summary_poly_up[depth] += (
    306         summary_poly_up[depth + 1] * self.D_powers[current_height - right_height]
    307     )
    309 # if node is not the root node -> calculate the Shapley Interaction values for the node

File ~/Code/isa-ml/isa-ml-jobs/.venv/lib/python3.10/site-packages/shapiq/explainer/tree/treeshapiq.py:281, in TreeSHAPIQ._compute_shapley_interaction_values(self, x, order, node_id, summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down, depth)
    276     summary_poly_up[depth] = (
    277         summary_poly_down[depth] * self._edge_tree.empty_predictions[node_id]
    278     )
    279 else:  # not a leaf -> continue recursion
    280     # left child
--> 281     self._compute_shapley_interaction_values(
    282         x,
    283         order=order,
    284         node_id=left_child,
    285         summary_poly_down=summary_poly_down,
    286         summary_poly_up=summary_poly_up,
    287         interaction_poly_down=interaction_poly_down,
    288         quotient_poly_down=quotient_poly_down,
    289         depth=depth + 1,
    290     )
    291     summary_poly_up[depth] = (
    292         summary_poly_up[depth + 1] * self.D_powers[current_height - left_height]
    293     )
    294     # right child

File ~/Code/isa-ml/isa-ml-jobs/.venv/lib/python3.10/site-packages/shapiq/explainer/tree/treeshapiq.py:281, in TreeSHAPIQ._compute_shapley_interaction_values(self, x, order, node_id, summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down, depth)
    276     summary_poly_up[depth] = (
    277         summary_poly_down[depth] * self._edge_tree.empty_predictions[node_id]
    278     )
    279 else:  # not a leaf -> continue recursion
    280     # left child
--> 281     self._compute_shapley_interaction_values(
    282         x,
    283         order=order,
    284         node_id=left_child,
    285         summary_poly_down=summary_poly_down,
    286         summary_poly_up=summary_poly_up,
    287         interaction_poly_down=interaction_poly_down,
    288         quotient_poly_down=quotient_poly_down,
    289         depth=depth + 1,
    290     )
    291     summary_poly_up[depth] = (
    292         summary_poly_up[depth + 1] * self.D_powers[current_height - left_height]
    293     )
    294     # right child

File ~/Code/isa-ml/isa-ml-jobs/.venv/lib/python3.10/site-packages/shapiq/explainer/tree/treeshapiq.py:244, in TreeSHAPIQ._compute_shapley_interaction_values(self, x, order, node_id, summary_poly_down, summary_poly_up, interaction_poly_down, quotient_poly_down, depth)
    242 # if node is not a leaf -> set activations for children nodes accordingly
    243 if not is_leaf:
--> 244     if x[child_edge_feature] <= feature_threshold:
    245         activations[left_child], activations[right_child] = True, False
    246     else:

UFuncTypeError: ufunc 'less_equal' did not contain a loop with signature matching types (<class 'numpy.dtypes.Float64DType'>, <class 'numpy.dtypes.StrDType'>) -> None
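The failing comparison can be reproduced with NumPy alone. This is a minimal sketch that assumes what the error message suggests: the sample value is a float64 scalar and the threshold is a NumPy string scalar. The sketch catches the broader TypeError, of which NumPy's UFuncTypeError is a subclass.

```python
import numpy as np

# Feature value taken from a float64 sample, as with x[child_edge_feature]
value = np.float64(3.0)
# Threshold stored as a NumPy string, as the error message suggests
threshold = np.str_("3||7||9")

try:
    _ = value <= threshold
    raised = False
except TypeError as err:  # UFuncTypeError subclasses TypeError
    raised = True
    print(f"{type(err).__name__}: {err}")
```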

I searched a little, and it seems that self._tree.thresholds contains some string values, such as (3||7||9||11||14||17||18||28||31||49||53||63||68||70||72||73||74||75||82||84||90||96||107||108||112||114||124||128||129||130||135||139).

I don't know why yet; I will look into it later.
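A string like this appears to be LightGBM's encoding of a categorical split. The sketch below shows how such a node could route a sample. It assumes the node layout of LightGBM's `Booster.dump_model()` output: `decision_type` is `"=="` for categorical splits and `"<="` for numerical ones, and the category codes that go to the left child are joined by `||`. `parse_categorical_threshold` and `goes_left` are hypothetical helpers, not part of shapiq or LightGBM. The sketch also ignores missing-value handling (`default_left`, `missing_type`).

```python
def parse_categorical_threshold(threshold: str) -> frozenset[int]:
    """Parse a LightGBM-style categorical threshold such as "3||7||9" into category codes."""
    return frozenset(int(code) for code in threshold.split("||"))


def goes_left(node: dict, value: float) -> bool:
    """Return True if `value` is routed to the left child of the given split node.

    Hypothetical helper; assumes dump_model()-style keys "decision_type" and "threshold".
    """
    if node["decision_type"] == "==":
        # Categorical split: listed category codes go left, everything else goes right.
        # A float such as 7.0 matches the integer code 7 (equal values, equal hashes).
        return value in parse_categorical_threshold(node["threshold"])
    # Numerical split ("<="): ordinary threshold comparison
    return value <= node["threshold"]


categorical_node = {"decision_type": "==", "threshold": "3||7||9"}
numerical_node = {"decision_type": "<=", "threshold": 0.5}
print(goes_left(categorical_node, 7.0))  # True
print(goes_left(categorical_node, 4.0))  # False
print(goes_left(numerical_node, 0.25))   # True
```

A membership test replaces the `<=` comparison only for categorical nodes, so numerical splits behave exactly as before.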

@CharlesCousyn
Contributor Author

It seems to be linked to LightGBM's categorical feature support: shap/shap#170

@CharlesCousyn
Contributor Author

My quick-and-dirty fix would look like this:

if isinstance(feature_threshold, str):
    # Categorical split (LightGBM): the threshold lists the category codes that
    # go to the left child, separated by "||", e.g. "3||7||9"
    try:
        # Parse the threshold string into a set of allowed category codes
        allowed_categories = {float(val) for val in feature_threshold.split("||")}
    except ValueError as e:
        raise ValueError(f"Could not parse categorical threshold: {feature_threshold}") from e
    # For categorical splits, use membership testing:
    if x[child_edge_feature] in allowed_categories:
        activations[left_child], activations[right_child] = True, False
    else:
        activations[left_child], activations[right_child] = False, True
else:
    # For numerical splits, use the regular threshold comparison:
    if x[child_edge_feature] <= feature_threshold:
        activations[left_child], activations[right_child] = True, False
    else:
        activations[left_child], activations[right_child] = False, True

instead of this code in shapiq/explainer/tree/treeshapiq.py:244:

if x[child_edge_feature] <= feature_threshold:
    activations[left_child], activations[right_child] = True, False
else:
    activations[left_child], activations[right_child] = False, True

What do you think? Maybe the thresholds should be handled when loading the model instead?

@mmschlk
Owner

mmschlk commented Feb 5, 2025

I see, good point. You are right: the categorical features are the problem. I was not aware of how LightGBM does this comparison; I always thought it worked like sklearn and treated every feature as numerical. It might be that I did something wrong in the conversion.

While I think your fix would work, we should think carefully before adding logic to the explanation loop. I'd rather catch such cases during the tree conversion. ... This would, however, be a more elaborate change. I want to think about this a bit more. :)
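Handling this at conversion time could look roughly like the sketch below. It parses every categorical threshold once, so the recursive explanation loop only needs a dictionary lookup and a comparison. The sketch assumes the raw per-node thresholds are a sequence that mixes numbers and LightGBM-style strings such as "3||7||9", with None for leaves. `split_thresholds` and `route_left` are hypothetical names, not shapiq's actual tree-conversion API.

```python
from __future__ import annotations

import numpy as np


def split_thresholds(raw_thresholds) -> tuple[np.ndarray, dict[int, frozenset[float]]]:
    """Separate numerical and categorical thresholds once (hypothetical helper).

    Numerical nodes keep their float threshold. Categorical nodes (strings such
    as "3||7||9") get NaN in the numeric array plus an entry in the returned dict.
    Leaves (None) stay NaN.
    """
    numeric = np.full(len(raw_thresholds), np.nan, dtype=float)
    categorical: dict[int, frozenset[float]] = {}
    for node_id, threshold in enumerate(raw_thresholds):
        if isinstance(threshold, str):  # also true for np.str_
            categorical[node_id] = frozenset(float(code) for code in threshold.split("||"))
        elif threshold is not None:
            numeric[node_id] = float(threshold)
    return numeric, categorical


def route_left(node_id: int, value: float, numeric: np.ndarray, categorical: dict) -> bool:
    """Return True if `value` goes to the left child of `node_id` (hypothetical helper)."""
    categories = categorical.get(node_id)
    if categories is not None:
        # Categorical split: membership in the pre-parsed category set
        return value in categories
    # Numerical split: plain threshold comparison
    return bool(value <= numeric[node_id])


numeric, categorical = split_thresholds([0.5, "3||7||9", None])
print(route_left(0, 0.25, numeric, categorical))  # True
print(route_left(1, 7.0, numeric, categorical))   # True
print(route_left(1, 4.0, numeric, categorical))   # False
```

Doing the parsing once during conversion keeps the string handling out of the per-sample recursion, which matches the concern about adding work to the explanation loop.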

Does your fix work for you for the time being?

@CharlesCousyn
Contributor Author

Yes, I can confirm the dirty fix works properly!
I have my SHAP interaction values! :)
