Comments (10)
@JoelCanteroGCO can you please send a full reproducible notebook example, and I can see if there is some way to fix it?
If you try to run the KernelExplainer or DeepExplainer directly, I wonder what error you will encounter.
Also, what version of tensorflow are you using? The latest version of shap has some issues with the latest tensorflow for DeepExplainer, so it may be that you will only be able to run KernelExplainer. In that case, it is just a question of why the wrap_model function in the ml-wrappers repo is not working correctly.
from interpret-community.
Thanks @imatiach-msft, I've attached a reproducible notebook example.
I tried to use the Kernel and Deep Explainers, but the error is the same:
Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 13 array(s), for inputs.
I am using TensorFlow 2.1, but I have also tried 2.7 and it throws the same error.
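For context, the style of validation behind this message can be mimicked with a toy check in plain Python (no TensorFlow; the function below is purely illustrative, not interpret-community or Keras code): a model with 13 input layers wants one array per input, so handing it a single structure instead of 13 separate arrays trips exactly this kind of count mismatch.

```python
# Toy sketch of multi-input validation (illustrative only, not Keras code):
# a model with 13 input layers expects a list of 13 arrays, one per input.
def check_inputs(inputs, n_expected=13):
    if len(inputs) != n_expected:
        raise ValueError(
            "Error when checking model input: "
            f"Expected to see {n_expected} array(s), got {len(inputs)}")
    return True

columns = [[0.0] * 5 for _ in range(13)]  # 13 per-feature arrays: accepted
check_inputs(columns)

matrix = [[0.0] * 13 for _ in range(5)]   # one 5x13 matrix: rejected
try:
    check_inputs([matrix])                # passes 1 array instead of 13
except ValueError as err:
    print(err)
```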
@JoelCanteroGCO thanks, I took a quick look and will look into this more this week. It seems related to the dict_values type (dict(x_test).values()), which currently isn't handled correctly, but I will investigate further to see how it can be fixed.
@JoelCanteroGCO I was able to make it work by defining the model as:
class WrappedModel(object):
    def __init__(self, model):
        self._model = model

    def predict(self, dataset):
        # Convert the dict-structured data to the list of inputs the model expects
        return self._model.predict(list(dict(dataset).values()))

wrapped_model = WrappedModel(model)
and then this worked:
from interpret.ext.blackbox import TabularExplainer
explainer = TabularExplainer(wrapped_model, x_train, features=boston.feature_names, model_task='regression')
global_explanation = explainer.explain_global(x_test)
Note it is using the KernelExplainer in this case.
I am still trying to figure out how I can just automatically solve this for the user so no wrapper needs to be written, but it's hard for me to understand how this case could be auto-detected and auto-resolved.
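One way such auto-detection might look, sketched in plain Python (the helper names here are hypothetical, not part of interpret-community or ml-wrappers): check whether the dataset is dict-like and, if so, reorder its columns to match the model's declared input order before predicting.

```python
# Hypothetical sketch: adapt a dict-of-columns dataset to a multi-input
# predict function by reordering the columns to the model's input order.
def wrap_predict(predict_fn, input_order):
    def predict(dataset):
        if hasattr(dataset, "items"):  # dict-like dataset: reorder by name
            columns = dict(dataset)
            return predict_fn([columns[name] for name in input_order])
        return predict_fn(dataset)     # already a list of arrays
    return predict

# Toy "model" that sums its two inputs positionally.
def toy_predict(cols):
    return [a + b for a, b in zip(*cols)]

predict = wrap_predict(toy_predict, input_order=["x1", "x2"])
print(predict({"x2": [10, 20], "x1": [1, 2]}))  # columns reordered: [11, 22]
```

The hard part, as noted above, is knowing `input_order` without user input; a real implementation would have to read it off the model's input layers.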
I'm wondering now if perhaps this conversion logic should live in the DatasetWrapper in the ml-wrappers repository.
@JoelCanteroGCO one thing that I don't understand at all: when I call predict on these two types of data in your notebook, I get completely different results. Why is that?
I would guess that something is really wrong with calling it like this:
model.predict(list(dict(x_train).values()))
as it seems that the predictions are way off.
Calling predict on the tf dataset directly seems to give much better results, considering what the labels look like.
If I understand correctly, this may actually be the better way to wrap the model:
class WrappedModel(object):
    def __init__(self, model):
        self._model = model

    def predict(self, dataset):
        # Rebuild a batched tf.data.Dataset so inputs can be matched by name
        inp = dict(dataset)
        inp_ds = tf.data.Dataset.from_tensor_slices(inp).batch(32)
        return self._model.predict(inp_ds)
as the other approach gives strange output results, but I am not sure why. Note I am using the latest tensorflow version, 2.8.0. Perhaps older versions behave differently.
Hello @imatiach-msft
Thank you very much for your response and your attention.
As you have noticed, when we call predict on this type of data:
model.predict(list(dict(x_train).values()))
the results are completely different. This is because when we convert the dictionary to a list, the order of the features may not correspond to the order of the model's input layers, so calling it that way is wrong. Thanks for letting me know about that :)
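A plain-Python illustration of this pitfall (no TensorFlow needed; the feature names below are from the Boston dataset used in the notebook): dict preserves insertion order, so list(dict(...).values()) depends on how the dict was built, not on the order of the model's input layers.

```python
model_input_order = ["CRIM", "ZN", "INDUS"]  # order the model's inputs expect
# Dict built in a different order than the model's inputs:
features = {"ZN": [18.0], "CRIM": [0.006], "INDUS": [2.3]}

flattened = list(dict(features).values())
print(flattened)   # [[18.0], [0.006], [2.3]] -- positions no longer match

by_name = [features[name] for name in model_input_order]
print(by_name)     # [[0.006], [18.0], [2.3]] -- paired by name, correct order
```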
Another correct way to call predict is to create a dictionary mapping column names to tf.constant values:
inp = {col: tf.constant(x_train[col]) for col in x_train.columns}
model.predict(inp)
Thank you very much @imatiach-msft for your solution; wrapping the model and redefining the predict function works :)
It would be awesome if the package could handle this for the user so that no wrapper class needs to be written.
Closing the issue, as it should now be resolved with PR #510 and the ml-wrappers PR it depended on, microsoft/ml-wrappers#39.
BatchDataset should now be supported in the DatasetWrapper and all interpretability algorithms should work with it automatically.
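As a rough mental model of what such support involves (a plain-Python sketch only; the real conversion lives in microsoft/ml-wrappers and is not reproduced here), a wrapper has to walk the batches of a dict-structured dataset and flatten them into an ordinary row-major feature matrix for the explainers:

```python
# Illustrative helper (hypothetical, not the ml-wrappers implementation):
# flatten an iterable of {feature: [values...]} batches into row vectors.
def batched_dict_to_rows(batches, feature_order):
    rows = []
    for batch in batches:
        n_rows = len(next(iter(batch.values())))  # rows in this batch
        for i in range(n_rows):
            rows.append([batch[feature][i] for feature in feature_order])
    return rows

batches = [
    {"age": [25, 32], "income": [40.0, 55.0]},  # batch of 2 rows
    {"age": [47], "income": [70.0]},            # final partial batch
]
print(batched_dict_to_rows(batches, ["age", "income"]))
# [[25, 40.0], [32, 55.0], [47, 70.0]]
```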