Recourse Explanation Library in JAX

Home Page: https://birkhoffg.github.io/ReLax/

License: Apache License 2.0

relax's Introduction

ReLax


Overview | Installation | Tutorials | Documentation | Citing ReLax

Important

📣 This repository has migrated to a new location: https://github.com/BirkhoffG/jax-relax.

Overview

ReLax (Recourse Explanation Library in Jax) is a library built on top of jax to generate counterfactual and recourse explanations for Machine Learning algorithms. By leveraging vectorization through vmap/pmap and just-in-time compilation in jax (a high-performance auto-differentiation library), ReLax offers massive speed improvements in generating individual (or local) explanations for predictions made by Machine Learning algorithms.
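As a toy illustration of this vmap/jit pattern (not ReLax's actual API; the `recourse_cost` function below is made up for the example):

```python
import jax
import jax.numpy as jnp

# A toy "recourse objective": squared distance between an input and its counterfactual.
def recourse_cost(x, cf):
    return jnp.sum((x - cf) ** 2)

# jax.vmap vectorizes the per-instance function over a batch of inputs,
# and jax.jit compiles the batched version once for fast repeated calls.
batched_cost = jax.jit(jax.vmap(recourse_cost))

xs = jnp.array([[0.0, 0.0], [1.0, 1.0]])
cfs = jnp.array([[1.0, 0.0], [1.0, 3.0]])
costs = batched_cost(xs, cfs)  # per-instance costs, computed in one batched call
```

The same pattern extends to whole recourse-generation loops: writing the per-instance logic once and batching it with vmap is what gives ReLax its speedups over Python-level loops.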

Some of the key features are as follows:

  • ๐Ÿƒ Fast recourse generation via jax.jit, jax.vmap/jax.pmap.

  • ๐Ÿš€ Accelerated over cpu, gpu, tpu.

  • ๐Ÿช“ Comprehensive set of recourse methods implemented for benchmarking.

  • ๐Ÿ‘ Customizable API to enable the building of entire modeling

  • and interpretation pipelines for new recourse algorithms.

Installation

The latest ReLax release can be installed directly from PyPI:

pip install jax-relax

or installed directly from the repository:

pip install git+https://github.com/BirkhoffG/ReLax.git 

To further unleash the power of accelerators (i.e., GPU/TPU), we suggest first installing this library via pip install jax-relax. Then, follow the steps in the official install guidelines to install the right jax version for GPU or TPU.

An Example of using ReLax

See Getting Started with ReLax.

Citing ReLax

To cite this repository:

@software{relax2023github,
  author = {Hangzhi Guo and Xinchang Xiong and Amulya Yadav},
  title = {{R}e{L}ax: Recourse Explanation Library in Jax},
  url = {http://github.com/birkhoffg/ReLax},
  version = {0.1.0},
  year = {2023},
}

relax's People

Contributors: birkhoffg, grahams-uncle

relax's Issues

Support hyper-parameter searching for CF explanation methods

Supporting hyper-parameter search enables us to properly benchmark the algorithms. This issue is a thread for discussing how to support hyper-parameter search in CF explanation methods.

In essence, this is a multi-objective optimization problem (i.e., minimizing both invalidity and cost).
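One simple way to make the two objectives searchable is to scalarize them. A hypothetical sketch (the weighting scheme and the candidate numbers are illustrative, not ReLax API):

```python
# Score a hyper-parameter configuration by a weighted sum of the two
# objectives mentioned above: invalidity and recourse cost. Lower is better.
def score_config(invalidity: float, cost: float, w: float = 0.5) -> float:
    """w trades off invalidity against cost; w=1 ignores cost entirely."""
    return w * invalidity + (1.0 - w) * cost

# Made-up (invalidity, cost) results for two candidate configurations.
candidates = {
    "lr=0.1": (0.10, 0.8),
    "lr=0.01": (0.05, 1.2),
}
best = min(candidates, key=lambda k: score_config(*candidates[k]))
```

A scalarized score like this can be plugged directly into any single-objective hyper-parameter search library; a Pareto-front approach would be the alternative if we want to avoid committing to a fixed weight.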

Some open-source libraries for hyper-parameter search:

Support aux arguments of `pred_fn` to be passed to `generate_cf_explanations`

Currently, we assume pred_fn is a function of only one input x. E.g., it is something like:

pred_fn = lambda x: 2 * x + 1

However, it is possible that user-defined pred_fn takes other arguments.

Hence, I propose:

def generate_cf_explanations(
    cf_module: BaseCFModule,
    datamodule: TabularDataModule,
    pred_fn: Callable[[jnp.DeviceArray], jnp.DeviceArray] = None,
    *,
    t_configs=None,
    pred_fn_args: dict=None
)

where inside, we call pred_fn as

pred_fn(x, **pred_fn_args)

This offers additional flexibility for models that are not implemented using our framework.
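A minimal sketch of the proposed calling convention (the names `pred_fn` and `pred_fn_args` come from the proposal above; the `call_pred_fn` helper and the `temperature` argument are made up for illustration):

```python
# A user-defined pred_fn that takes an extra keyword argument beyond x.
def pred_fn(x, temperature=1.0):
    return (2 * x + 1) / temperature

def call_pred_fn(pred_fn, x, pred_fn_args=None):
    """Forward any user-supplied keyword arguments through to pred_fn."""
    pred_fn_args = pred_fn_args or {}
    return pred_fn(x, **pred_fn_args)

assert call_pred_fn(pred_fn, 3) == 7.0                         # defaults only
assert call_pred_fn(pred_fn, 3, {"temperature": 7.0}) == 1.0   # extra args forwarded
```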

CI/CD takes too long

CI seems to run some unnecessary tests (e.g., training models) during testing.

Get rid of the Pytorch dependencies in `TabularDataModule`

Pytorch is only needed for loading data. Our library mainly handles tabular data, so data loading is unlikely to be a bottleneck in most scenarios; the Pytorch DataLoader is overkill for our project in most use cases.

Purpose

Write a drop-in NumpyLoader.

ToDo

Delete the Pytorch Dependency

https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/settings.ini#L18

https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/cfnet/datasets.py#L9

Next, modify the following code so that these classes no longer inherit from the Pytorch Dataset and DataLoader:

https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/cfnet/datasets.py#L12-L22

https://github.com/BirkhoffG/cfnet/blob/24783713dc787cd5b13e70aa483e455c4856198f/cfnet/datasets.py#L35-L51

Expected Functionalities

NumpyDataset should contain all the input data.

# x, y are jax.numpy arrays, such that len(x) == len(y)
dataset = NumpyDataset(x, y)

x, y = dataset[:]       # access all of the data in x, y
x_5, y_5 = dataset[:5]  # access the first five elements of x, y

NumpyLoader iterates over the NumpyDataset in mini-batches. See the Pytorch Docs.

batch_size = 128
dataloader = NumpyLoader(
    dataset,  # a `NumpyDataset`
    batch_size=batch_size,
    shuffle=True,    # if True, shuffle the data; else, return the data in order
    drop_last=False  # if True, discard the last incomplete batch (when len(dataset) % batch_size != 0); else, return it as a smaller batch
)

for x, y in dataloader:
    assert len(x) <= batch_size  # with drop_last=False, the last batch may be smaller
    assert len(x) == len(y)
    ...

Refactor util functions

Move

  • binary_cross_entropy in cfnet.methods.vanilla
  • grad_update, cat_normalize in cfnet.training_module

into cfnet.utils

Pass `seed` and `batch_size` to the dataloader functions in `TabularDataModule`

  1. Pass seed and batch_size to TabularDataModule.train_dataloader, TabularDataModule.val_dataloader, and TabularDataModule.test_dataloader.

  2. batch_size should also be an argument in TrainingConfigs
    https://github.com/BirkhoffG/cfnet/blob/2ee1a3203a9935e89b2ed8adf175ee1150fd9960/cfnet/train.py#L15

  3. Deprecate batch_size in DataConfigs

  4. Finally, pass appropriate arguments:
    https://github.com/BirkhoffG/cfnet/blob/2ee1a3203a9935e89b2ed8adf175ee1150fd9960/cfnet/train.py#L58-L59

Customize DataModule-dependent constraints

We use cat_normalize for encoding features, and clip continuous features to [0, 1]. This is because we use one-hot encoding for categorical features and a min-max scaler for continuous features.

If a user wants to use other encoding methods (e.g., standardized continuous features), our current way of handling normalized data is not applicable.
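To illustrate why the constraint must depend on the DataModule's encoding (a hypothetical sketch; the `project` helper is not ReLax API):

```python
import numpy as np

# The valid range of a continuous feature depends on how it was scaled.
# Clipping to [0, 1] is only correct for min-max scaling; a standardized
# (zero-mean, unit-variance) feature is unbounded, so a hard-coded clip
# would silently distort valid counterfactual values.
def project(cf, scaling="minmax"):
    if scaling == "minmax":
        return np.clip(cf, 0.0, 1.0)
    return cf  # e.g. standardized features: no fixed range to clip to

cf = np.array([-0.3, 0.5, 1.7])
minmax_cf = project(cf)               # clipped into [0, 1]
standard_cf = project(cf, "standard") # left unchanged
```

This suggests the constraint/projection step should be supplied by (or derived from) the DataModule rather than hard-coded in the recourse methods.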

Proposed features:

Provide Default Data Configs for `TabularDataModule`

Proposal

data_module = TabularDataModule('adult')

As such, TabularDataModule will automatically load data_configs of the adult dataset.

We should also allow TabularDataModule to pass user-defined configs (i.e., current argument data_configs: str | dict).
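A hypothetical dispatch sketch for this proposal (the registry contents are placeholders, not ReLax's actual configs):

```python
# Built-in default configs keyed by dataset name. The column names here are
# illustrative placeholders for the adult dataset.
DEFAULT_DATA_CONFIGS = {
    "adult": {"data_name": "adult", "continuous_cols": ["age", "hours_per_week"]},
}

def resolve_data_configs(data_configs):
    """Accept either a known dataset name (str) or a user-defined config dict."""
    if isinstance(data_configs, str):
        return DEFAULT_DATA_CONFIGS[data_configs]
    return data_configs  # user-defined dict passed through unchanged
```

TabularDataModule.__init__ could then call a resolver like this on its data_configs argument, keeping both the convenient string form and the current dict form working.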
