
ExplainableAI.jl



Explainable AI in Julia.

This package implements interpretability methods for black-box classifiers, with an emphasis on local explanations and attribution maps in input space. The only requirement for the model is that it is differentiable¹. It is similar to Captum and Zennit for PyTorch and iNNvestigate for Keras models.

Installation

This package supports Julia ≥1.6. To install it, open the Julia REPL and run

julia> ]add ExplainableAI
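Equivalently, via the Pkg API:

julia> using Pkg; Pkg.add("ExplainableAI")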

Example

Let's explain why an image of a castle is classified as such by a vision model:

using ExplainableAI

# Load model and input
model = ...                      # load classifier model
input = ...                      # input in batch-dimension-last format

# Run XAI method
analyzer = SmoothGrad(model)
expl = analyze(input, analyzer)  # or: analyzer(input)

# Show heatmap
heatmap(expl)

# Or analyze & show heatmap directly
heatmap(input, analyzer)

By default, explanations are computed for the class with the highest activation. We can also compute explanations for a specific class, e.g. the one at output index 5:

analyze(input, analyzer, 5)  # explanation for output index 5
heatmap(input, analyzer, 5)  # heatmap for output index 5

[Image table omitted: heatmaps for the classes "castle" and "street sign", one row per analyzer: InputTimesGradient, Gradient, SmoothGrad, and IntegratedGradients.]

Tip

The heatmaps shown above were created using a VGG-16 vision model from Metalhead.jl that was pre-trained on the ImageNet dataset.

Since ExplainableAI.jl can be used outside of Deep Learning models and Flux.jl, we have omitted specific models and inputs from the code snippet above. The full code used to generate the heatmaps can be found here.
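
As a self-contained sketch, the snippet below runs on a randomly initialized toy Flux model and random input. The model, shapes, and values are made up for illustration, so the resulting heatmap carries no meaning; it only exercises the API:

using ExplainableAI
using Flux

model = Chain(Flux.flatten, Dense(28 * 28 => 10))  # toy untrained classifier
input = rand(Float32, 28, 28, 1, 1)                # random WHCN input, batch size 1

analyzer = SmoothGrad(model)
expl = analyze(input, analyzer)  # explanation for the top class
heatmap(expl)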

Heatmapping defaults differ between methods: sensitivity-based methods (e.g. Gradient) default to a grayscale color scheme, whereas attribution-based methods (e.g. InputTimesGradient) default to a red-white-blue color scheme. Red indicates regions of positive relevance towards the selected class, blue indicates regions of negative relevance. More information on heatmapping presets can be found in the Julia-XAI documentation.
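
These defaults can be overridden through keyword arguments. The keyword names below follow the Julia-XAI heatmapping documentation but may differ between versions, so treat this as an unverified sketch:

heatmap(expl; colorscheme=:seismic, rangescale=:centered)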

Warning

ExplainableAI.jl used to contain Layer-wise Relevance Propagation (LRP). As of version v0.7.0, LRP is available as a separate package in the Julia-XAI ecosystem: RelevancePropagation.jl.

[Image table omitted: heatmaps for the classes "castle" and "street sign" produced by LRP with the EpsilonPlus, EpsilonPlusFlat, EpsilonAlpha2Beta1, EpsilonAlpha2Beta1Flat, and EpsilonGammaBox composites, as well as LRP with ZeroRule (discouraged).]

Video Demonstration

Check out our talk at JuliaCon 2022 for a demonstration of the package.

Methods

Currently, the following analyzers are implemented:

  • Gradient
  • InputTimesGradient
  • SmoothGrad
  • IntegratedGradients
  • GradCAM

One of the design goals of the Julia-XAI ecosystem is extensibility. To implement an XAI method, take a look at the common interface defined in XAIBase.jl.
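
As a rough illustration of that interface, the sketch below implements a trivial analyzer that spreads "relevance" uniformly over the input. The struct and its attribution logic are hypothetical, and the exact Explanation constructor signature should be verified against the XAIBase.jl documentation for your installed version:

using XAIBase

# Hypothetical analyzer: assigns identical relevance to every input entry.
struct UniformAttribution{M} <: AbstractXAIMethod
    model::M
end

function (method::UniformAttribution)(input, output_selector::AbstractOutputSelector)
    output = method.model(input)
    output_selection = output_selector(output)
    val = fill!(similar(input), 1 / length(input))  # uniform attribution values
    # Constructor arguments assumed from the XAIBase docs; verify before use:
    return Explanation(val, output, output_selection, :UniformAttribution, :attribution, nothing)
end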

Roadmap

In the future, we would like to include additional methods. Contributions are welcome!

Acknowledgements

Adrian Hill acknowledges support by the Federal Ministry of Education and Research (BMBF) for the Berlin Institute for the Foundations of Learning and Data (BIFOLD) (01IS18037A).

Footnotes

  1. More specifically, models currently have to be differentiable with Zygote.jl.
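
In practice, this requirement can be checked by asking Zygote for a gradient of a scalar function of the model output, e.g. (using the placeholder model and input from the example above):

using Zygote
grads = Zygote.gradient(x -> sum(model(x)), input)  # errors if the model is not Zygote-differentiable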

Contributors

adrhill, dependabot[bot], jeananness

Issues

Add docs

Add doc pages for:

  • Basic example on VGG model
  • Combining LRP rules
  • Custom LRP rules

Also:

  • Add images to readme

Add default composites

  • Add presets similar to those in Zennit, e.g.:
    • EpsilonGammaBox
    • EpsilonPlus
    • EpsilonAlpha2Beta1
    • EpsilonPlusFlat
    • EpsilonAlpha2Beta1Flat
  • Update README with examples from presets

Support input batches

This should return a Vector{Explanation}, which should also be supported by heatmap.
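
A sketch of the proposed usage, with hypothetical shapes and an arbitrary analyzer (this mirrors the proposal, not necessarily the final API):

batch = rand(Float32, 28, 28, 1, 16)  # 16 inputs in a single batch
expls = analyze(batch, analyzer)      # proposed: returns a Vector{Explanation}
heatmap.(expls)                       # heatmap broadcast over each explanation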

Refactor results struct

Update field names.
The field layerwise_relevances of the Explanation struct is too specific to LRP.
An extras field of type Union{Nothing, Dict} would be more flexible.
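
A sketch of the refactored struct this proposal points towards (field names and types are illustrative only):

# Illustrative only; mirrors the proposal above, not the shipped struct.
struct Explanation{V,O}
    val::V                        # attribution values in input space
    output::O                     # model output
    neuron_selection::Int         # index of the explained output neuron
    analyzer::Symbol              # name of the method that produced this
    extras::Union{Nothing,Dict}   # replaces the LRP-specific layerwise_relevances
end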

Test GPU support

GPU support is currently untested. In theory, GPU tests could be run on CI using the JuliaGPU Buildkite CI.

Locally, a first test of GPU support can be run by modifying the readme example to cast input to a GPU array.
It should be possible to run the following code in a fresh temp-environment:

using CUDA
using ExplainableAI
using Flux
using MLDatasets
using Downloads: download
using BSON: @load

model_url = "https://github.com/adrhill/ExplainableAI.jl/raw/master/docs/src/model.bson"
path = joinpath(@__DIR__, "model.bson")
!isfile(path) && download(model_url, path)
@load "model.bson" model

model = strip_softmax(model)
x, _ = MNIST.testdata(Float32, 10)
input = reshape(x, 28, 28, 1, :)

input_gpu = gpu(input) # cast input to GPU array
analyzer = LRP(model)
expl = analyze(input_gpu, analyzer)

Regression tests of all methods on VGG19

Metalhead disabled pretrained weights in 0.6.0 due to model inaccuracies.
These can technically still be loaded while they are being fixed:

model = VGG19()
Flux.loadparams!(model.layers, weights("vgg19"))

However, the VGG19 weights are a 548 MB download every time CI is run. It might therefore be more reasonable to use a smaller model. Currently, MetalheadWeights contains (in ascending size):

  • SqueezeNet (5 MB) -> requires Parallel for "fire" modules
  • GoogLeNet (27 MB) -> requires Parallel
  • Densenet121 (31 MB) -> requires SkipConnection
  • ResNet-50 (98 MB) -> requires Parallel, skip_identity
  • VGG-19 (548 MB)

An easy workaround would be to run the methods on randomly initialized parameters (with a fixed seed). The explanations w.r.t. this model should remain constant.
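
A sketch of that workaround (architecture, seeds, and tolerance are made up for illustration):

using ExplainableAI, Flux, Random

Random.seed!(123)  # fixed seed -> reproducible random weights
model = Chain(Conv((3, 3), 3 => 8, relu), Flux.flatten, Dense(8 * 30 * 30 => 10))
input = rand(MersenneTwister(42), Float32, 32, 32, 3, 1)

expl = analyze(input, Gradient(model))
# Compare expl.val against reference values stored from a previous run, e.g.:
# @test expl.val ≈ reference_val atol = 1e-5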

LRP rule coverage for Flux layers

This issue keeps track of which Flux layers in the model reference have LRP implementations.


Basic layers

  • Dense
  • flatten

Convolution

  • Conv
  • DepthwiseConv
  • ConvTranspose
  • CrossCor

Pooling layers

  • AdaptiveMaxPool
  • MaxPool
  • GlobalMaxPool
  • AdaptiveMeanPool
  • MeanPool
  • GlobalMeanPool

General purpose

  • Maxout
  • SkipConnection
  • Chain #119
  • Parallel #10
  • Bilinear
  • Diagonal
  • Embedding

Normalisation & regularisation

  • normalise
  • BatchNorm #129
  • dropout
  • Dropout
  • AlphaDropout
  • LayerNorm
  • InstanceNorm
  • GroupNorm

Upsampling layers

  • Upsample
  • PixelShuffle

Recurrent layers

  • RNN
  • LSTM
  • GRU
  • Recur

Update documentation for `v0.6.0` release

Currently, the following things can be improved or are missing documentation:

  • input augmentations: NoiseAugmentation, InterpolationAugmentation
  • usage of LayerMap and show_layer_indices introduced in #131
  • LRP keyword flatten and performance benefits
  • LRP model canonization
  • Update section "How it works internally" for #119
  • Update "Model checks for humans" for #119
  • Use DocumenterCitations.jl
  • #64

Add model canonization

Add function canonize(model) which merges BatchNorm layers into Dense and Conv layers with linear activation functions.
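
For reference, the standard BatchNorm fusion for a preceding Dense layer looks roughly like the sketch below (a hypothetical helper, not the package's canonize implementation):

# Fuse BatchNorm(γ, β, running mean μ, running variance σ², ε)
# into a preceding Dense layer with weights W and bias b.
function fuse_dense_batchnorm(W, b, γ, β, μ, σ², ε)
    s = γ ./ sqrt.(σ² .+ ε)    # per-output scaling factor
    W′ = s .* W                # scales row i of W by s[i]
    b′ = s .* (b .- μ) .+ β    # folds the normalization shift into the bias
    return W′, b′
end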

TagBot trigger issue

This issue is used to trigger TagBot; feel free to unsubscribe.

If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.

If you'd like for me to do this for you, comment TagBot fix on this issue.
I'll open a PR within a few hours, please be patient!

Add `ZPlusRule`

Could be implemented as the one-liner ZPlusRule() = AlphaBetaRule(1.0f0, 0.0f0),
but a large part of the computation could be skipped since β=0.

Fix randomness in gradients

Zygote's gradient appears to be non-deterministic on Metalhead's VGG19:

julia> a = gradient((in) -> model(in)[1], imgp)[1];

julia> b = gradient((in) -> model(in)[1], imgp)[1];

julia> isapprox(a, b; atol=1e-3)
false

julia> isapprox(a, b; atol=1e-2)
true

Check whether this is due to Dropout layers or Zygote.

Reduce allocations in LRP methods

  • Instead of having rules modify layer parameters, avoid allocations by implementing modified forward calls that can be diff'ed (see the sketch after this list).
  • Pre-allocate buffers
    • for activations on forward-pass
    • for relevances on backward-pass
  • For analysis of multiple output neurons, only run forward-pass once

This should speed things up a lot!
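
As a sketch of the first bullet, the LRP-γ weight modification (W + γ·W⁺) could be applied inside the forward call instead of allocating a modified copy of the layer (hypothetical helper, not package internals):

using Flux

# Apply the γ-rule weight modification on the fly during the forward pass.
modified_forward(d::Dense, aₖ, γ) = (d.weight .+ γ .* relu.(d.weight)) * aₖ .+ d.bias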

Add model checks for LRP

  • Check model for non-ReLU-like activations and unknown layers.
  • Display a summary as to why checks failed if they do, as well as references to the docs on how to fix these issues.
  • Make checks skippable via the LRP keyword argument skip_checks=true.

The goal should be to make ExplainabilityMethods transparent but extensible.
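
An illustrative version of such an activation check (hypothetical helper; real checks would need to cover far more layer types):

using Flux

# LRP assumes ReLU-like activations; warn about anything else.
is_relu_like(σ) = σ === relu || σ === identity

function check_activations(model::Chain)
    for layer in model.layers
        if hasproperty(layer, :σ) && !is_relu_like(layer.σ)
            @warn "Layer uses a non-ReLU-like activation: $(layer.σ)"
        end
    end
end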

Add LRP support for nested Chains

This will be the first step towards #10 by allowing nested model structures.

This requires the following changes:

  • new internal representation of rules, e.g. via LRPRulesChain / LRPRulesParallel [Edit: now called ChainTuple and ParallelTuple] to support graphs
  • treat chains the same way as layers by adding lrp!(Rₖ, r::AbstractLRPRule, c::Chain, aₖ, Rₖ₊₁), which can be called recursively. This will require a different approach to pre-allocating activations and relevances than the one currently used in the call to the analyzer.
  • Composite might require a refactoring
  • Update lrp/show.jl

Remove ImageNet preprocessing code

This should be handled by external packages such as DataAugmentations.jl.

  • document preprocessing with external packages

Since this is a breaking change, it should be implemented before a 1.0 release.

Move LRP into separate package

This would remove the dependency on Flux and make the package lighter for users that don't require LRP.

Since this requires a breaking release, this is a milestone for a future 1.0 release.
