HarsanyiNet

This repository contains the Python implementation for HarsanyiNet, "HarsanyiNet: Computing Accurate Shapley Values in a Single Forward Propagation", ICML 2023.

HarsanyiNet is an interpretable network architecture that makes inferences on an input sample and simultaneously computes the exact Shapley values of the input variables in a single forward propagation (see the paper cited below for details).
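The idea of obtaining Shapley values from Harsanyi interactions rests on a standard game-theoretic identity: the Shapley value of variable i equals the sum of I(S)/|S| over all coalitions S that contain i, where I(S) is the Harsanyi dividend of S. The sketch below checks this identity on a hypothetical three-variable toy game. It is not the repository's network code, and every function name in it is made up for illustration.

```python
from itertools import combinations
from math import factorial


def subsets(items):
    """Yield every subset of `items` as a frozenset."""
    items = tuple(items)
    for r in range(len(items) + 1):
        for combo in combinations(items, r):
            yield frozenset(combo)


def harsanyi_dividends(v, n):
    """I(S) = sum over T subset of S of (-1)^(|S|-|T|) * v(T)."""
    return {
        S: sum((-1) ** (len(S) - len(T)) * v(T) for T in subsets(S))
        for S in subsets(range(n))
    }


def shapley_from_dividends(dividends, n):
    """Share each dividend I(S) equally among the members of S."""
    phi = [0.0] * n
    for S, value in dividends.items():
        for i in S:
            phi[i] += value / len(S)
    return phi


def shapley_by_definition(v, n):
    """Classic weighted sum of marginal contributions, for comparison."""
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for S in subsets(others):
            weight = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
            phi[i] += weight * (v(S | {i}) - v(S))
    return phi


def toy_game(S):
    """Hypothetical game: two main effects and two interaction terms."""
    value = 0.0
    if 0 in S:
        value += 1.0
    if 1 in S:
        value += 2.0
    if {0, 1} <= S:
        value += 3.0
    if {0, 1, 2} <= S:
        value -= 1.5
    return value


if __name__ == "__main__":
    dividends = harsanyi_dividends(toy_game, 3)
    print(shapley_from_dividends(dividends, 3))
    print(shapley_by_definition(toy_game, 3))
```

Both functions enumerate all 2^n coalitions, so this brute-force check is only practical for a handful of variables.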

Install

HarsanyiNet can be installed in a Python 3 environment:

pip3 install git+https://github.com/csluchen/harsanyinet

The torchtoolbox package also needs to be installed:

pip3 install torchtoolbox

Alternatively, you can set up a conda environment:

conda create --name harsanyinet python=3.9
conda activate harsanyinet
pip3 install -r requirements.txt

How to use

HarsanyiNet-CNN

To train the model, use commands like the following:

  • CIFAR-10 dataset
python train.py
  • MNIST dataset
python train.py --dataset='MNIST' --num_layers=4 --channels=32 --beta=100 --gamma=0.05

Alternatively, you can use the pre-trained HarsanyiNet models available on Google Drive: download pretrained_model.zip and unzip it so that the checkpoints sit at a path like ./pretrained_model/{DATASET}/.../model_pths/{DATASET}.pth.

To compute Shapley values using HarsanyiNet in a single forward propagation, use a command like the following:

python shapley.py --save_path='./pretrained_model' --model_path='model_pths/CIFAR10.pth' --num_layers=10 --channels=256 --beta=1000 --gamma=1 

HarsanyiNet-MLP

Datasets

We provide implementations for three tabular datasets from the UCI repository: Census, Yeast, and Commercial (TV News).

Getting Started

To get started, run python utils/tabular/data_preprocess.py to download and preprocess the data. The preprocessed data will be stored as an np.ndarray in data/{DATASET}/. Alternatively, you can use utils/data.py to build the dataloader directly, since it already incorporates this preprocessing step.

To train the model, use the following code:

  • Census dataset
python train_tabular.py
  • Yeast dataset
python train_tabular.py --dataset Yeast --n_attributes 8
  • Commercial (TV News) dataset
python train_tabular.py --dataset Commercial --n_attributes 10

Note:

  • For the Census dataset, we provide the pretrained model under pretrained_model/Census.pth.
  • For the Yeast and Commercial datasets, we do not provide pretrained models because neither dataset has an official data split. We randomly split each dataset into 80% training data and 20% testing data.
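As an illustration of the 80%/20% random split mentioned in the note above, here is a minimal sketch that shuffles sample indices with a fixed seed. The function name, the seed, and the sample count are assumptions made for this example; the repository's own splitting code may differ.

```python
import random


def split_indices(n_samples, test_fraction=0.2, seed=0):
    """Return (train_idx, test_idx) from a seeded random shuffle."""
    rng = random.Random(seed)          # fixed seed keeps the split reproducible
    order = list(range(n_samples))
    rng.shuffle(order)
    n_test = round(n_samples * test_fraction)
    return order[n_test:], order[:n_test]


if __name__ == "__main__":
    # 1000 is a placeholder sample count, not the size of any real dataset.
    train_idx, test_idx = split_indices(1000)
    print(len(train_idx), len(test_idx))
```

The returned index lists can be used to index the preprocessed arrays, for example X[train_idx] for a NumPy array X.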

To compute Shapley values using HarsanyiNet in a single forward propagation, use the following code:

  • Census
    • using the provided pretrained_model:
    python shapley_tabular.py --model_path pretrained_model/Census.pth
    
    • if you have trained your own model
     python shapley_tabular.py
    
  • Yeast
   python shapley_tabular.py --dataset Yeast --n_attributes 8
  • Commercial (TV News)
   python shapley_tabular.py --dataset Commercial --n_attributes 10

More details

Comparing Shapley values computed by HarsanyiNet and other methods

To compute the root mean squared error (RMSE) between the Shapley values computed by HarsanyiNet and those estimated by the sampling method, use the following command:

python shapley.py --sampling=True --runs=20000

Note: the larger the number of iterations (runs) of the sampling method, the more accurate the sampling method is and the longer it takes for the code to run.
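The trade-off in the note above can be seen with a generic permutation-sampling estimator, sketched below on a hypothetical three-variable toy game whose exact Shapley values are known in closed form. This is a textbook Monte Carlo estimator and not necessarily the sampling code used by shapley.py; the `runs` parameter only mirrors the name of the command-line flag.

```python
import random
from math import sqrt


def sampled_shapley(v, n, runs, seed=0):
    """Average marginal contributions over `runs` random orderings."""
    rng = random.Random(seed)
    phi = [0.0] * n
    players = list(range(n))
    for _ in range(runs):
        rng.shuffle(players)
        coalition = set()
        previous = v(coalition)
        for i in players:
            coalition.add(i)
            current = v(coalition)
            phi[i] += current - previous   # marginal contribution of i
            previous = current
    return [total / runs for total in phi]


def rmse(a, b):
    """Root mean squared error between two equal-length vectors."""
    return sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))


WEIGHTS = [1.0, 2.0, 0.5]  # hypothetical main effects


def toy_game(S):
    """Additive weights plus a bonus of 3 when variables 0 and 1 co-occur."""
    bonus = 3.0 if {0, 1} <= S else 0.0
    return sum(WEIGHTS[i] for i in S) + bonus


# The bonus is shared equally between variables 0 and 1.
EXACT = [2.5, 3.5, 0.5]

if __name__ == "__main__":
    for runs in (100, 1000, 20000):
        estimate = sampled_shapley(toy_game, 3, runs)
        print(runs, rmse(estimate, EXACT))
```

Each run costs n evaluations of the model, so the running time grows linearly with the number of runs while the estimation error shrinks only at the usual Monte Carlo rate.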

To compute the RMSE between the Shapley values computed by HarsanyiNet and the ground-truth Shapley values, use the following command:

python shapley.py --ground_truth=True

Sample notebooks

For the image datasets, we provide Jupyter notebooks that calculate Shapley values via HarsanyiNet on CIFAR-10 and MNIST, under notebooks/CIFAR-10.ipynb and notebooks/MNIST.ipynb, respectively.

For the tabular datasets, we provide a Jupyter notebook that calculates Shapley values via HarsanyiNet on the Census dataset, under notebooks/Census.ipynb.

Citations

@InProceedings{chen23,
  title = {HarsanyiNet: Computing Accurate Shapley Values in a Single Forward Propagation},
  author = {Lu, Chen and Siyu, Lou and Keyan, Zhang and Jin, Huang and Quanshi, Zhang},
  booktitle = {Proceedings of the 40th International Conference on Machine Learning},
  year = {2023}
}
