

SAGE

SAGE (Shapley Additive Global importancE) is a game-theoretic approach for understanding black-box machine learning models. It quantifies each feature's importance based on how much predictive power it contributes, and it accounts for complex feature interactions using the Shapley value.

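SAGE was introduced in the NeurIPS 2020 paper "Understanding Global Feature Contributions With Additive Importance Measures" (see References), but if you're new to Shapley values, you may want to start by reading this blog post.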

Install

The easiest way to get started is to install the sage-importance package with pip:

pip install sage-importance

Alternatively, you can clone the repository and install the package in your Python environment as follows:

git clone https://github.com/iancovert/sage.git
cd sage
pip install .

Usage

SAGE is model-agnostic, so you can use it with any kind of machine learning model (linear models, GBMs, neural networks, etc.). All you need to do is set up an imputer to handle held-out features and then estimate the Shapley values:

import sage

# Get data
x, y = ...
feature_names = ...

# Get model
model = ...

# Set up an imputer to handle missing features
imputer = sage.MarginalImputer(model, x[:128])

# Set up an estimator
estimator = sage.PermutationEstimator(imputer, 'mse')

# Calculate SAGE values
sage_values = estimator(x, y)
sage_values.plot(feature_names)

The result will look like this: [figure: example SAGE values plot, not included]

Our implementation supports several features to make estimating the Shapley values easier:

  • Uncertainty estimation: confidence intervals are provided for each feature's importance value.
  • Convergence detection: convergence is determined based on the size of the confidence intervals, and a progress bar displays the estimated time until convergence.
  • Model conversion: our back-end requires models to be represented in a consistent format, and this conversion step is performed automatically for XGBoost, CatBoost, LightGBM, sklearn and PyTorch models. If you're using a different kind of model, it needs to be converted to a callable function (see here for examples).
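The uncertainty and convergence features above can be made concrete with a small example. The following is a minimal sketch of one way such a stopping rule can work, assuming per-feature estimates are averaged over sampled permutations. It tracks a running mean and standard error with Welford's algorithm. Sampling stops once the largest standard error is small relative to the spread between the largest and smallest estimates. `RunningStats`, `has_converged`, and the 0.025 threshold are illustrative choices, not sage's API.

```python
import numpy as np


class RunningStats:
    """Welford's online mean/variance for a vector of per-feature estimates."""

    def __init__(self, num_features):
        self.n = 0
        self.mean = np.zeros(num_features)
        self.m2 = np.zeros(num_features)  # running sum of squared deviations

    def update(self, sample):
        # sample: one permutation's marginal contributions, shape (num_features,)
        self.n += 1
        delta = sample - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (sample - self.mean)

    @property
    def std_error(self):
        # Standard error of each running mean; undefined before two samples.
        if self.n < 2:
            return np.full_like(self.mean, np.inf)
        return np.sqrt(self.m2 / (self.n - 1) / self.n)


def has_converged(stats, thresh=0.025, min_gap=1e-12):
    """Hypothetical stopping rule: largest std error relative to the value spread."""
    gap = max(stats.mean.max() - stats.mean.min(), min_gap)  # avoid dividing by 0
    ratio = stats.std_error.max() / gap
    return ratio < thresh
```

The `min_gap` floor only avoids the division by zero. If every estimate is identical and the standard errors stay positive, this rule still never reports convergence, so a real sampling loop should also cap the number of iterations.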

Examples

Check out the following notebooks to get started:

  • Bike: a simple example using XGBoost that shows how to calculate SAGE values and Shapley Effects (an alternative explanation when no labels are available)
  • Credit: generate explanations using a surrogate model to approximate the conditional distribution (using CatBoost)
  • Airbnb: calculate SAGE values with grouped features (using a PyTorch MLP)
  • Bank: a model monitoring example that uses SAGE to identify features that hurt the model's performance (using CatBoost)
  • MNIST: shows strategies to accelerate convergence for datasets with many features (feature grouping, different imputation setups)
  • Consistency: verifies that our various Shapley value estimators return the same results (see the estimators listed below)
  • Calibration: verifies that SAGE's confidence intervals are representative of the uncertainty across runs
  • Losses: shows how SAGE can be used in classification with alternative loss functions

If you want to replicate the experiments described in our paper, see this separate repository.

More details

This repository gives you flexibility in how you generate explanations. The main choices are described below.

1. Feature removal approach

The original SAGE paper proposes marginalizing out missing features using their conditional distribution. Since this is challenging to implement in practice, several approximations are available. For example, you can:

  1. Use default values for missing features (see MNIST for an example). This is a fast but low-quality approximation.
  2. Sample features from the marginal distribution (see Bike for an example). This approximation is discussed in the SAGE paper.
  3. Train a supervised surrogate model (see Credit for an example). This approach is described in this paper, and it can provide a better approximation than the other approaches. However, it requires training an additional model (typically a neural network).
  4. Train a model that accommodates missingness. This approach is not shown here, but it's described in this paper.
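The sketch below gives a rough illustration of approaches 1 and 2. It computes a model's prediction when only the features in a boolean mask `subset` are known. It assumes `model` is a callable that maps a 2-D NumPy array to one prediction per row. The function names are hypothetical and are not part of sage; marginal sampling is the approach taken by `MarginalImputer` in the usage example.

```python
import numpy as np


def impute_default(model, x, subset, default):
    """Approach 1 (sketch): replace held-out features with fixed default values."""
    x_imp = np.array(x, dtype=float, copy=True)
    held_out = ~subset
    x_imp[:, held_out] = default[held_out]  # same defaults for every row
    return model(x_imp)


def impute_marginal(model, x, subset, background):
    """Approach 2 (sketch): average predictions over background rows.

    Each input row is paired with every background row; features in `subset`
    keep the input's values, and held-out features come from the background row.
    """
    x = np.asarray(x, dtype=float)
    background = np.asarray(background, dtype=float)
    n, m = x.shape[0], background.shape[0]
    x_rep = np.repeat(x, m, axis=0)        # row i repeated m times
    bg_rep = np.tile(background, (n, 1))   # background cycled n times
    held_out = ~subset
    x_rep[:, held_out] = bg_rep[:, held_out]
    preds = np.asarray(model(x_rep)).reshape(n, m)
    return preds.mean(axis=1)              # expectation over the marginal
```

The marginal version evaluates the model on n × m rows. This is why a small background set, such as the 128 rows in the usage example, keeps it affordable.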

2. Explanation type

Two types of explanations can be calculated, both based on Shapley values:

  1. SAGE. This approach quantifies how much each feature improves the model's performance (this is the default).
  2. Shapley Effects. Described in this paper, this explanation method quantifies the model's sensitivity to each feature. Since Shapley Effects is closely related to SAGE (see here for details), our implementation generates this type of explanation when labels are not provided. See the Bike notebook for an example.
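The connection between the two explanation types can be summarized through their value functions. The equations below are a compact restatement written for this README, not a quotation from the papers. They assume that removed features are handled through the conditional expectation, so f_∅ = E[f(X)]. They also assume that Shapley Effects arise from SAGE by using squared error and treating the model output f(X) as the label. The last equality follows from the law of total variance.

```latex
% SAGE: loss reduction from knowing the features in S
v(S) = \mathbb{E}\big[\ell(f_\varnothing, Y)\big] - \mathbb{E}\big[\ell(f_S(X_S), Y)\big],
\qquad f_S(x_S) = \mathbb{E}\big[f(X) \mid X_S = x_S\big], \quad f_\varnothing = \mathbb{E}[f(X)]

% Shapley Effects: squared-error loss, with the label Y replaced by f(X)
v(S) = \mathbb{E}\big[(f(X) - \mathbb{E}[f(X)])^2\big] - \mathbb{E}\big[(f(X) - f_S(X_S))^2\big]
     = \operatorname{Var}\big(\mathbb{E}[f(X) \mid X_S]\big)

% In both cases, feature i's importance is its Shapley value over D = \{1, \dots, d\}
\phi_i(v) = \frac{1}{d} \sum_{S \subseteq D \setminus \{i\}} \binom{d-1}{|S|}^{-1} \big[v(S \cup \{i\}) - v(S)\big]
```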

3. Shapley value estimator

Shapley values are computationally costly to calculate exactly, so we provide several estimation approaches:

  1. Permutation sampling. This is the approach described in the original paper (see PermutationEstimator).
  2. KernelSAGE. This is a linear regression-based estimator that's similar to KernelSHAP (see KernelEstimator). It's described in this paper, and the Bank notebook shows an example.
  3. Iterated sampling. This is a variation on the permutation sampling approach where we calculate Shapley values sequentially for each feature (see IteratedEstimator). This enables faster convergence for features with low variance, but it can result in wider confidence intervals.
  4. Sign estimation. This method estimates SAGE values to a lower precision by focusing only on their sign (i.e., whether they help or hurt performance). It's implemented in SignEstimator, and the Bank notebook shows an example.

All of these approaches estimate the same quantities, so their results should agree up to estimation error (see Consistency), but they may differ in convergence speed. Permutation sampling is a good approach to start with. KernelSAGE may converge a bit faster, but its uncertainty is spread more evenly among the features rather than being highest for the more important features.
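The core idea of permutation sampling fits in a few lines of NumPy. This is a simplified illustration, not sage's `PermutationEstimator`, and it makes several assumptions. It uses MSE loss and imputes held-out features by averaging predictions over a background set (marginal sampling). It draws one random (x, y) pair per permutation and credits each feature with the loss reduction observed when that feature is added. `sage_permutation` and its parameters are hypothetical names.

```python
import numpy as np


def sage_permutation(model, x, y, background, num_permutations=200, seed=0):
    """Sketch of permutation-sampling SAGE with MSE loss and marginal imputation.

    Assumes `model` maps an (n, d) array to n predictions.
    """
    rng = np.random.default_rng(seed)
    n, d = x.shape
    totals = np.zeros(d)

    def loss_with(xi, yi, subset):
        # Known features come from xi; held-out ones vary over the background rows.
        rows = np.array(background, dtype=float)
        rows[:, subset] = xi[subset]
        pred = model(rows).mean()
        return (pred - yi) ** 2

    for _ in range(num_permutations):
        i = rng.integers(n)                   # random (x, y) sample
        subset = np.zeros(d, dtype=bool)      # start with no features known
        prev = loss_with(x[i], y[i], subset)
        for j in rng.permutation(d):
            subset[j] = True
            curr = loss_with(x[i], y[i], subset)
            totals[j] += prev - curr          # loss reduction from adding feature j
            prev = curr
    return totals / num_permutations
```

The contributions along each permutation telescope. As a result, the estimates always sum to the average of loss(no features) minus loss(all features) over the sampled rows.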

4. Grouped features

Rather than removing features individually, you can specify groups of features to be removed jointly. This will likely speed up convergence because there are fewer feature subsets. See Airbnb for an example.
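With grouped features, the Shapley value is taken over groups rather than individual features. This means d features split into g groups give 2^g subsets instead of 2^d. The bookkeeping can be sketched as follows, assuming groups are given as lists of column indices (an illustrative format). The helper expands a mask over groups into a mask over features, so every feature in a selected group is revealed jointly. `group_mask` is a hypothetical helper, not sage's `GroupedMarginalImputer`.

```python
import numpy as np


def group_mask(group_subset, groups, num_features):
    """Expand a boolean mask over groups into a boolean mask over features."""
    mask = np.zeros(num_features, dtype=bool)
    for selected, cols in zip(group_subset, groups):
        if selected:
            mask[cols] = True  # reveal every column in this group together
    return mask


# Example: six features in three groups.
groups = [[0, 1], [2], [3, 4, 5]]
print(group_mask([True, False, True], groups, 6))
```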


References

Ian Covert, Scott Lundberg, Su-In Lee. "Understanding Global Feature Contributions With Additive Importance Measures." NeurIPS 2020

Ian Covert, Scott Lundberg, Su-In Lee. "Explaining by Removing: A Unified Framework for Model Explanation." JMLR 2021

Ian Covert, Su-In Lee. "Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression." AISTATS 2021

Art Owen. "Sobol' Indices and Shapley Value." SIAM/ASA Journal on Uncertainty Quantification 2014

sage's People

Contributors

andresalgaba, denizy7, dependabot[bot], iancovert, j-adamczyk, karelze, pre-commit-ci[bot]


sage's Issues

pip install

Hi,
I would like to create a pull request to push my setup.py file to allow a simplified installation with pip.

Could you allow pull requests?

Unstable SAGE values

Hi,

I ran SAGE on my test data, which has 5k rows and 500 columns. The SAGE values I get are very small (~0.003) and change sign and magnitude across runs. I tried both permutation sampling and iterated sampling, but the difference is not very clear. Could you help me understand how to resolve this issue and how n_samples might affect my results?

Zero-One Loss in Classification

Thanks for this awesome library. 💯

In #7 you discussed the use of alternative loss functions in classification.

I'm working on a use case, where I perform classification with different classifiers, but I mainly care about the accuracy of the prediction, rather than the predicted probabilities, as some classifiers yield hard probabilities only. As such, I wanted to swap the cross-entropy loss for the zero-one loss.

I extended utils.py (see here) and added the new loss function:

import numpy as np


class ZeroOneLoss:
    '''Zero-one loss that expects probabilities.'''
    def __init__(self, reduction='mean'):
        assert reduction in ('none', 'mean')
        self.reduction = reduction

    def __call__(self, pred, target):
        # Add a dimension to prediction probabilities if necessary.
        if pred.ndim == 1:
            pred = pred[:, np.newaxis]
        if pred.shape[1] == 1:
            pred = np.append(1 - pred, pred, axis=1)

        if target.ndim == 1:
            # Class labels.
            loss = (np.argmax(pred, axis=1) != target).astype(float)
        elif target.ndim == 2:
            # Probabilistic labels.
            loss = (np.argmax(pred, axis=1) != np.argmax(target, axis=1)).astype(float)
        else:
            raise ValueError('incorrect labels shape for zero-one loss')

        if self.reduction == 'mean':
            return np.mean(loss)
        return loss

and call it like this:

imputer = sage.MarginalImputer(model, train)
estimator = sage.KernelEstimator(imputer, 'zero one')

I was wondering if there is anything more to consider when using alternative loss functions, in particular the zero-one loss, with SAGE.
Would you also be interested in a PR?

Parallelized computation

First of all, many thanks for the amazing package and the very clear paper and blog post linked in the README. Please correct me if I am mistaken, but it looks like you are not using multiprocessing (I am thinking of joblib) to compute the terms of the Shapley values in parallel. Is there a technical reason for this? If not, are you planning to add it in the future?

SAGE values on cross-validation

Hello! On my dataset, SAGE values depend quite a lot on the train-test split. Would it be correct to average the SAGE values' means and standard deviations across cross-validation folds?

Possibility to use presegmented images

Hello everyone,
I wonder if it is possible to use pre-segmented images. In my case, I would like to explain models trained on MRI brain scans. I also have a segmentation for each scan that marks different brain structures. The positions of those structures differ across scans, but the number of areas stays the same. Now, I would like to compute the SAGE importance for each brain structure. Many thanks in advance.

Mismatch between feature importances from `GroupedMarginalImputer` and `MarginalImputer`

@iancovert Thanks for your work on SAGE.

Recently, I've experimented with a use case, where I needed to obtain global feature importances for groups of features and individual features. I used sage.GroupedMarginalImputer(model, test[:64], groups)/ sage.MarginalImputer(model, test[:64]) in conjunction with sage.PermutationEstimator(imputer, 'mse', n_jobs=-1).

My initial assumption was that the importance of a feature group should roughly match the summed importance of the individual features within the group, as long as the background samples are identical, groups do not overlap, etc. However, I noticed relatively large differences, which I could not explain.

I was able to reproduce this behaviour with a slightly altered version of the airbnb notebook:
https://colab.research.google.com/gist/KarelZe/67d2e4445a8971c72a36fd846977270a/airbnb.ipynb

Example from notebook for group with multiple features:

df_grouped.loc["location (grouped)"]
> 0    2136.538846
> Name: location (grouped), dtype: float64
df_individual.loc[['latitude', 'longitude', 'neighbourhood', 'neighbourhood_group']].sum()
> 0    2372.024399
> dtype: float64

Example from notebook for group with one feature:

df_grouped.loc["room_type"]
> 0    2597.069309
> Name: room_type, dtype: float64
df_individual.loc["room_type"]
> 0    2623.990063
> Name: room_type, dtype: float64

In my private datasets, I also had more severe cases, where individual features had a high SAGE value but the corresponding group did not, and vice versa.

What I'd like to know is whether it is fair to assume that GroupedMarginalImputer and MarginalImputer yield the same results, in particular when groups consist of only one feature (as in the second example). It would also be helpful to know where the differences in SAGE values between the two imputation strategies come from.

SAGE with NLP/Huggingface

Hello - thanks for the fantastic library.

My project involves predicting a continuous measure from text, and I'm currently using average SHAP values for global token importance. I think SAGE is better suited for this, so I was wondering whether there is (or will be) support for Hugging Face NLP models. Thanks!

All negative SAGE values in classification

Dear @iancovert,

I perform binary classification using CatBoost and try to determine feature importances with GroupedMarginalImputer and PermutationEstimator, but I only obtain negative SAGE values.

My test set X_test is relatively large (9861576, 15). I provide the imputer with a random sample X_importance of the test set of size (512, 15), similar to the notebook.
My labels are [-1,1], if relevant.

My code

# ...
# use callable as catboost is already with pool incl. categoricals
def call_catboost(X):
    if feature_str == "ml":
        X = pd.DataFrame(X, columns=X_importance.columns)
        # Update the selected columns in the original DataFrame
        X[cat_features] = X.iloc[:, cat_idx].astype(int)
        # pass cat indices
        return clf.predict_proba(Pool(X, cat_features=cat_idx))
    else:
        return clf.predict_proba(X)  # <- used here

# apply group based imputation + estimate importances in terms of cross-entropy
imputer = GroupedMarginalImputer(call_catboost, X_importance, groups)
estimator = PermutationEstimator(imputer, "cross entropy")

# calculate values over entire test set
sage_values = estimator(X_test.values, y_test.values)

# save sage values + std deviation to data frame
result = pd.DataFrame(index=group_names, data={"values": sage_values.values, "std": sage_values.std})

Obtained Result:
512 background samples: [screenshot not included]

256 background samples: [screenshot not included]
As visible in the screenshots, all SAGE values are negative. I noticed that if I decrease the subset passed to GroupedMarginalImputer to only 256 samples one group is slightly positive. Uncertainties are relatively low.

To my understanding, such a result would imply that all features (or groups) degrade the model's performance. This is somewhat counter-intuitive to me, as the classifier performs strongly on the test set. I've seen the notebooks and the examples in the SAGE paper where single corrupted features degrade performance, but never every feature.

Is there some implementation error on my side e.g., misunderstanding of background samples or is such a result plausible?

Thanks for your help.

PS:
A similar issue was discussed in #2, but it seems to be resolved.

PPS:
I also ran the experiment with my implementation of the zero one loss (see #18) but the output is similar.

Explanation about new changes in the SAGE package and addition of Model sensitivity module.

Hi @iancovert, I was among the very first users of SAGE, and I found it very helpful for my use cases. However, I see that a lot has changed since the first release, both in the code and in its applications.

Could you provide an updated README for this new release and also explain the model sensitivity module that you have added?

I would also like to point out that the run times for SAGE on a machine with 8 GB of RAM, for 20k rows and 500 columns, are a bit high, which is potentially a roadblock to its wide adoption. Is there a possibility that the run time can be optimized?

TreeSAGE?

Hi,

Do you think it is possible to adapt the SAGE approach with a faster approximation for tree-based models, like the TreeSHAP algorithm for Shapley values?

License

Hi Ian.

I came across this package today from your very nice paper on the arXiv "Understanding Global Feature Contributions Through Additive Importance Measures." I am excited to try out this code, but I don't see a license. Would you be willing to clarify what the license is and add a LICENSE.txt file?

Thanks!

Shape mismatch on XGB.Classifier

Hi. I hope and suspect this is just a misunderstanding on my part, but I cannot make sage work with the XGBoost classifier.
I have included a minimal example below, using the Boston Housing dataset. Here, XGBRegressor works fine, but XGBClassifier produces an error, although y_pred has the same shape in both cases.
The error* arises in utils.py in __call__(self, pred, target), line 159, probably because the if clause on line 156 isn't triggered, so pred doesn't get reshaped (?). I'm not sure but would appreciate your input. And thanks for a great package :-)

*ValueError: operands could not be broadcast together with shapes (512,2) (512,)

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import xgboost as xgb
import sage
boston = load_boston()
boston_dataset = pd.DataFrame(boston.data, columns=boston.feature_names)
boston_dataset['MEDV'] =  boston.target
features = ["RM", "AGE", "TAX", "CRIM", "PTRATIO"]
x_data = np.array(boston_dataset[features])
medv = np.array(boston_dataset.MEDV)
mean = np.mean(medv)
y_data = np.array([1 if _m < mean else 0 for _m in medv]) # make targets binary
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.33, random_state=42)
#model = xgb.XGBClassifier(use_label_encoder=False).fit(x_train, y_train) # Doesn't work
model = xgb.XGBRegressor().fit(x_train, y_train) # Works
y_pred = model.predict(x_test)
print(y_pred.shape)
imputer = sage.MarginalImputer(model, x_train[:512])
estimator = sage.PermutationEstimator(imputer, 'mse')
sage_values = estimator(x_test, y_test)
sage_values.plot(features)

Exception encountered when calling layer "gru" (type GRU).

I am getting the error below when trying to use SAGE with text data and a GRU layer.

InternalError: Exception encountered when calling layer "gru" (type GRU).
Failed to call ThenRnnForward with model config: [rnn_mode, rnn_input_mode, rnn_direction_mode]: 3, 0, 0 , [num_layers, input_size, num_units, dir_count, max_seq_length, batch_size, cell_num_units]: [1, 64, 64, 1, 32, 250000, 0] [Op:CudnnRNN]

Call arguments received:
• inputs=tf.Tensor(shape=(250000, 32, 64), dtype=float32)
• mask=None
• training=False
• initial_state=None

Model:
Model: "sequential"

Layer (type)                          Output Shape      Param #
embedding (Embedding)                 (None, 32, 64)    768000
spatial_dropout1d (SpatialDropout1D)  (None, 32, 64)    0
gru (GRU)                             (None, 32, 64)    24960
dropout (Dropout)                     (None, 32, 64)    0
gru_1 (GRU)                           (None, 64)        24960
dropout_1 (Dropout)                   (None, 64)        0
dense (Dense)                         (None, 32)        2080
dropout_2 (Dropout)                   (None, 32)        0
dense_1 (Dense)                       (None, 100)       3300
dense_2 (Dense)                       (None, 1)         101

Total params: 823,401
Trainable params: 823,401
Non-trainable params: 0

PermutationEstimator runs infinitely when gap = 0

This is a bug report.

Using GPR (GaussianProcessRegressor) with certain datasets leads the PermutationEstimator to run indefinitely (or for what feels like forever), and it keeps running into a divide-by-zero warning.

Here's an excerpt of the code where the bug happened:

import sage
from sklearn.gaussian_process import GaussianProcessRegressor
...

gpr = GaussianProcessRegressor(...)
gpr.fit(X_train, y_train)

imputer   = sage.MarginalImputer(gpr, X_train)
estimator = sage.PermutationEstimator(imputer, "mse")
importance = estimator(X_, y_, bar=False)

Here's the warning that keeps being emitted, filling up stderr:

sage/permutation_estimator.py:126: RuntimeWarning: divide by zero encountered in double_scalars
    ratio = std / gap
