julia-xai / ExplainableAI.jl
Explainable AI in Julia.
License: MIT License
Add a PassRule that implements Rₖ .= Rₖ₊₁.
This should speed things up a lot!
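Such a rule can be sketched in a few lines, assuming the in-place `lrp!(Rₖ, rule, layer, aₖ, Rₖ₊₁)` interface discussed elsewhere in these issues; this is an illustration, not the package implementation:

```julia
# Minimal sketch of a PassRule: relevance passes through unchanged,
# skipping the generic rule computation entirely.
struct PassRule end

function lrp!(Rₖ, ::PassRule, layer, aₖ, Rₖ₊₁)
    Rₖ .= Rₖ₊₁
    return Rₖ
end

# Usage on dummy data (the layer argument is ignored):
Rₖ = zeros(Float32, 4)
lrp!(Rₖ, PassRule(), identity, rand(Float32, 4), Float32[1, 2, 3, 4])
```

Because no weights are touched, this avoids all allocations beyond the pre-allocated relevance buffers.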
Support for nested Chains was introduced in #119, but canonization currently still requires flattening the model.
The type names of primitives currently end with the word Rule and could be mistaken for LRP rules.
Add a Shapley analyzer when ShapML is loaded. LazyModules.jl could be used for this purpose.
Refer to the Zennit implementation.
This will be the first step towards #10 by allowing nested model structures.
This requires the following changes:
- LRPRulesChain / LRPRulesParallel (ChainTuple and ParallelTuple) to support graphs
- lrp!(Rₖ, r::AbstractLRPRule, c::Chain, aₖ, Rₖ₊₁), which can be called recursively. This will require a different approach to pre-allocating activations and relevances than the one currently used in the call to the analyzer.
- Composite might require a refactoring
- lrp/show.jl
Run benchmarks on a representative model, e.g. VGG16.
Add a function canonize(model) which merges BatchNorm layers into Dense and Conv layers with linear activation functions.
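For the Dense case, the fusion amounts to folding the BatchNorm statistics into the weights and bias. A minimal sketch with a hypothetical `fuse_batchnorm` helper (not the package API):

```julia
# Fold BatchNorm parameters (γ, β, μ, σ²) into a preceding linear layer's
# weights W and bias b: y = γ ⊙ (Wx + b − μ) / √(σ² + ϵ) + β = W′x + b′.
function fuse_batchnorm(W, b, γ, β, μ, σ²; ϵ=1f-5)
    s = γ ./ sqrt.(σ² .+ ϵ)          # per-output-channel scale
    return s .* W, s .* (b .- μ) .+ β
end

# The fused layer matches Dense → BatchNorm on any input:
W, b = randn(Float32, 3, 2), randn(Float32, 3)
γ, β = rand(Float32, 3), randn(Float32, 3)
μ, σ² = randn(Float32, 3), rand(Float32, 3)
x = randn(Float32, 2)
W′, b′ = fuse_batchnorm(W, b, γ, β, μ, σ²)
y_fused = W′ * x .+ b′
y_seq = γ .* ((W * x .+ b) .- μ) ./ sqrt.(σ² .+ 1f-5) .+ β
```

The Conv case works the same way per output channel, with the scale broadcast over the kernel.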
GPU support is currently untested. In theory, GPU tests could be run on CI using the JuliaGPU Buildkite CI.
Locally, a first test of GPU support can be run by modifying the readme example to cast the input to a GPU array.
It should be possible to run the following code in a fresh temp environment:

```julia
using CUDA
using ExplainableAI
using Flux
using MLDatasets
using Downloads: download
using BSON: @load

model_url = "https://github.com/adrhill/ExplainableAI.jl/raw/master/docs/src/model.bson"
path = joinpath(@__DIR__, "model.bson")
!isfile(path) && download(model_url, path)
@load path model  # load from the downloaded path, not the working directory
model = strip_softmax(model)

x, _ = MNIST.testdata(Float32, 10)
input = reshape(x, 28, 28, 1, :)
input_gpu = gpu(input) # cast input to GPU array

analyzer = LRP(model)
expl = analyze(input_gpu, analyzer)
```
This should be handled via external packages such as DataAugmentation.jl.
Since this is a breaking change, it should be implemented before a 1.0 release.
This issue is used to trigger TagBot; feel free to unsubscribe.
If you haven't already, you should update your TagBot.yml to include issue comment triggers.
Please see this post on Discourse for instructions and more details.
If you'd like for me to do this for you, comment TagBot fix on this issue.
I'll open a PR within a few hours, please be patient!
This issue keeps track of which Flux layers in the model reference got LRP implementations.
Dense
flatten
Conv
DepthwiseConv
ConvTranspose
CrossCor
AdaptiveMaxPool
MaxPool
GlobalMaxPool
AdaptiveMeanPool
MeanPool
GlobalMeanPool
normalise
BatchNorm (#129)
dropout
Dropout
AlphaDropout
LayerNorm
InstanceNorm
GroupNorm
Upsample
PixelShuffle
RNN
LSTM
GRU
Recur
skip_checks=true.
The goal should be to make ExplainabilityMethods transparent but extendable.
Loading VGG and preprocessing the input currently takes up a lot of space in the docs.
Update README to reflect #157.
Using CIFAR10 in the docs and README would be more interesting than MNIST.
Replace Zygote dependency with AbstractDifferentiation.jl for backend agnostic AD.
Add the Threads.@threads macro to the for-loop in gradients_wrt_batch:
https://github.com/adrhill/ExplainableAI.jl/blob/f1b89ab9c784ff2a86de9997efa904580b2af6dd/src/gradient.jl#L5-L16
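The suggested change can be illustrated on a toy stand-in for the per-sample loop; `f` and `out` here are placeholders for the actual gradients_wrt_batch internals:

```julia
# Parallelize independent per-sample work with Threads.@threads.
# Each iteration writes only its own slot of the pre-allocated output,
# so there are no data races.
function map_batch!(out, f, xs)
    Threads.@threads for i in eachindex(xs)
        out[i] = f(xs[i])
    end
    return out
end

out = zeros(4)
map_batch!(out, x -> x^2, [1.0, 2.0, 3.0, 4.0])
```

The result is identical to the serial loop; speedup depends on `JULIA_NUM_THREADS` and on the per-sample cost dominating scheduling overhead.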
Using loops over pre-allocated arrays or StackViews.jl should speed up things and help with type stability.
All gradient analyzers currently use mapreduce:
https://github.com/adrhill/ExplainableAI.jl/blob/8641dfb101b63ed5a9876b32988c25e0c6b2191d/src/gradient.jl#L5-L13
Methods using InputAugmentation would also benefit from refactoring:
https://github.com/adrhill/ExplainableAI.jl/blob/8641dfb101b63ed5a9876b32988c25e0c6b2191d/src/input_augmentation.jl#L74-L83
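The pre-allocation idea mentioned above can be sketched as follows; the helper name is illustrative, not the package API:

```julia
# Replace mapreduce(f, hcat, eachcol(X)) with writes into a pre-allocated
# output, avoiding one intermediate array per column and keeping the
# result type concrete.
function map_columns!(out, f, X)
    for j in axes(X, 2)
        out[:, j] .= f(view(X, :, j))
    end
    return out
end

X = Float64[1 2; 3 4]
out = similar(X)
map_columns!(out, col -> 2 .* col, X)
```

Repeated hcat is O(n²) in total copies, while the loop writes each column exactly once; StackViews.jl would achieve the same without copying at all.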
Refer to quickstart example: https://lorenzoh.github.io/DataAugmentation.jl/dev/docs/literate/quickstart.md.html
This would remove the dependency on Flux and make the package lighter for users that don't require LRP.
Since this requires a breaking release, this is a milestone for a future 1.0 release.
Update field names. The field layerwise_relevances of the Explanation struct is too specific to LRP. An extras field of type Union{Nothing, Dict} would be more flexible.
Could be implemented as the one-liner ZPlusRule() = AlphaBetaRule(1.0f0, 0.0f0), but a large part of the computation could be skipped since β = 0.
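To illustrate why β = 0 allows skipping work, here is a toy z⁺ rule for a dense layer (assuming non-negative input activations; this is not the package implementation): only the positive part of the weights is ever needed, so the entire negative branch of the α/β computation drops out.

```julia
# z⁺ rule for a dense layer: Rⱼ = aⱼ Σₖ W⁺ⱼₖ Rₖ / zₖ with zₖ = Σⱼ W⁺ₖⱼ aⱼ.
function zplus_relevance(W, a, Rout; ϵ=1f-9)
    Wp = max.(W, 0)        # W⁺: negative contributions vanish when β = 0
    z = Wp * a .+ ϵ        # positive pre-activations (ϵ avoids division by 0)
    s = Rout ./ z
    return a .* (Wp' * s)  # redistribute relevance onto the inputs
end

# Deterministic example; every output has at least one positive weight:
W = Float32[1 -2 3 0.5; -1 2 -3 1; 0.5 1 -1 2]
a = Float32[0.1, 0.4, 0.2, 0.3]
Rout = Float32[0.5, 0.3, 0.2]
Rin = zplus_relevance(W, a, Rout)
```

Relevance is conserved: the input relevances sum to the output relevances up to the ϵ stabilizer.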
for more performant use of Tullio.jl.
Zygote's gradient appears to be non-deterministic on Metalhead's VGG19:

```julia
julia> a = gradient((in) -> model(in)[1], imgp)[1];

julia> b = gradient((in) -> model(in)[1], imgp)[1];

julia> isapprox(a, b; atol=1e-3)
false

julia> isapprox(a, b; atol=1e-2)
true
```
Check whether this is due to Dropout layers or Zygote.
ColorSchemes 3.18 adds a :centered option to ColorSchemes.get that makes _normalize redundant.
A visualisation method such as plot_text_heatmap from this iNNvestigate notebook would be useful for language tasks such as sentiment analysis.
EpsilonGammaBox
EpsilonPlus
EpsilonAlpha2Beta1
EpsilonPlusFlat
EpsilonAlpha2Beta1Flat
This would close #79 and be a major milestone towards a 1.0 release.
Wrap the attribution in a struct containing metadata such as the analyzer used, to dispatch on heatmap.
This struct could also contain the neuron selection and possibly the classifier output.
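A hypothetical sketch of such a wrapper; all field and function names here are assumptions for illustration, not the package API:

```julia
# Explanation wrapper carrying metadata alongside the raw attribution,
# so heatmap can dispatch on how the explanation was produced.
struct Explanation{V,O}
    val::V                 # the attribution values themselves
    output::O              # classifier output
    neuron_selection::Int  # index of the explained output neuron
    analyzer::Symbol       # e.g. :LRP or :Gradient
end

# heatmap could then pick sensible per-analyzer defaults, e.g. how to
# reduce the color channel:
default_reducer(e::Explanation) = e.analyzer === :LRP ? :sum : :norm

e = Explanation([0.1, 0.9], [0.2, 0.8], 2, :LRP)
```

Dispatching on a stored analyzer symbol (or analyzer type parameter) keeps heatmap a single user-facing function while letting each method choose its own defaults.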
Similar to Flux's printing of Chains.
and add tests on LRP rules.
Using keys obtained by chainkey. A user-friendly version of chainkey must be exported and documented.
For a more efficient implementation of LRP rules.
Metalhead disabled pretrained weights in 0.6.0 due to model inaccuracies.
These can technically still be loaded while they are being fixed:

```julia
model = VGG19()
Flux.loadparams!(model.layers, weights("vgg19"))
```
However, the VGG19 weights are a 548 MB download every time CI is run. It might therefore be more reasonable to use a smaller model. Currently, MetalheadWeights contains (in ascending size):
- Parallel for "fire" modules
- Parallel
- SkipConnection
- Parallel, skip_identity
An easy workaround would be to run the methods on randomly initialized parameters (with fixed seed). The explanations w.r.t. this model should still stay constant.
Add doc pages for:
Also:
Currently, the following things can be improved or are missing documentation:
- NoiseAugmentation, InterpolationAugmentation
- LayerMap and show_layer_indices introduced in #131
- flatten and performance benefits
This should return a Vector{Explanation}, which should also be supported by heatmap.
Automatic differentiation can be skipped for most known layers.
Add tests for a Dense wrapper to still test the AD fallback.