
Comments (4)

mengwang-mw commented on June 15, 2024

I have encountered the same issue: with multiclass output, the summary_plot function generates an interaction plot where a summary bar plot is expected.

I manually fixed this by editing the SHAP source code and changing the data type of the TreeExplainer output from a NumPy array back to a list.

Here is what I did in detail: I went to https://github.com/shap/shap/blob/master/shap/explainers/_tree.py and commented out lines 515-516. After that, I successfully generated the summary plot with multi-class output.

This error is due to a change in version 0.45.0: the output was changed from a list to a NumPy array, as can be seen in lines 410-411 of https://github.com/shap/shap/blob/master/shap/explainers/_tree.py, so I reversed this change to fix the issue.
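For anyone who would rather not patch the installed package, the same list-to-array reversal can probably be done on the caller's side. The following is only a sketch: the helper name `to_per_class_list` is made up for illustration, and it assumes the multiclass array is laid out as (n_samples, n_features, n_classes), which should be checked against the installed version.

```python
import numpy as np


def to_per_class_list(shap_values):
    """Convert multiclass SHAP values to the older list-of-arrays format.

    Hypothetical helper, not part of shap. Assumes a 3D input is laid out as
    (n_samples, n_features, n_classes). Lists and 2D arrays are returned
    unchanged, since they are already in a per-class or single-output form.
    """
    if isinstance(shap_values, list):
        return shap_values
    values = np.asarray(shap_values)
    if values.ndim == 3:
        # One (n_samples, n_features) slice per class, as in the list format
        return [values[:, :, k] for k in range(values.shape[2])]
    return values


# Synthetic stand-in for explainer output: 4 samples, 3 features, 2 classes
fake_values = np.arange(24).reshape(4, 3, 2)
per_class = to_per_class_list(fake_values)
print(len(per_class), per_class[0].shape)
```

The resulting list could then be passed to shap.summary_plot in place of the raw array, leaving the library source untouched.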


wiktorolszowy commented on June 15, 2024

It is not XGBoost-specific, as I have the same problem with SHAP values derived from CatBoost and LightGBM models. It is related to shap.summary_plot.


wiktorolszowy commented on June 15, 2024

Well spotted! I think for the majority of cases, a shortcut with a C++ implementation of Tree SHAP is used, so these 2 lines need to be commented out too (the same data transformation as in the lines you pointed to):

https://github.com/shap/shap/blob/86d8bc58a42e9e11901ad506f5c27f55fa4f0349/shap/explainers/_tree.py#L478C1-L479C49

Commenting these lines out most likely has some side effects, but without these lines the SHAP summary plot indeed works for multi-class classification models. Thanks!


Omranic commented on June 15, 2024

I encountered the same problem, and switching back to version 0.44.1 resolved it for me.

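Below is a short script that demonstrates the issue:

import pandas as pd
import shap
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Create a synthetic dataset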
X, y = make_classification(n_samples=100, n_features=5, n_informative=3, n_redundant=1, n_clusters_per_class=1, n_classes=3, random_state=42)
features = [f"Feature {i}" for i in range(X.shape[1])]
X = pd.DataFrame(X, columns=features)

# Train a RandomForest model
model = RandomForestClassifier(n_estimators=50, random_state=42)
model.fit(X, y)

# Create the SHAP Explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Plot SHAP values for each class
shap.summary_plot(shap_values, X, plot_type="bar", class_names=['Class 0', 'Class 1', 'Class 2'])

Here are the screenshots for both versions:

[Image: Screenshot 2024-06-10 at 11 15 50 AM]
[Image: Screenshot 2024-06-10 at 11 32 27 AM]
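Since the behaviour reported in this thread depends on the installed version (0.44.1 gives the bar plot, 0.45.0 does not), a script could check the version before deciding how to handle the SHAP values. This is a rough sketch: the function names are made up for illustration, and the 0.45 threshold is taken from the comments above rather than from the release notes.

```python
import re
from importlib import metadata


def returns_array_for_multiclass(version: str) -> bool:
    """Guess whether this shap version returns a NumPy array for multiclass.

    Hypothetical helper. The (0, 45) threshold is an assumption based on the
    reports in this thread.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (0, 45)


def installed_shap_version():
    """Return the installed shap version string, or None if shap is absent."""
    try:
        return metadata.version("shap")
    except metadata.PackageNotFoundError:
        return None


version = installed_shap_version()
if version is None:
    print("shap is not installed")
elif returns_array_for_multiclass(version):
    print(f"shap {version}: multiclass output is expected to be an array")
else:
    print(f"shap {version}: multiclass output is expected to be a list")
```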

