
Comments (10)

imatiach-msft avatar imatiach-msft commented on June 18, 2024

@JoelCanteroGCO can you please send a full reproducible notebook example, and I can see if there is some way to fix it?
If you try to run the KernelExplainer or DeepExplainer directly, I wonder what error you will encounter.
Also, what version of tensorflow are you using? The latest version of shap has some issues with the latest tensorflow for DeepExplainer, so it may be that you will only be able to run KernelExplainer. In that case, it is just a question of why the wrap_model function in the ml-wrappers repo is not working correctly.

from interpret-community.

JoelCanteroGCO avatar JoelCanteroGCO commented on June 18, 2024

Thanks @imatiach-msft, I've attached a reproducible notebook example.

I tried to use Kernel and Deep Explainers but the error is the same:
Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 13 array(s), for inputs.

I am using TensorFlow 2.1, but I have tried with 2.7 and it throws the same error.


imatiach-msft avatar imatiach-msft commented on June 18, 2024

@JoelCanteroGCO thanks, I took a quick look and will look into this more this week. It seems related to the dict_values type (dict(x_test).values()), which doesn't seem to be handled correctly at the moment - but I will look into it more to see how it can be fixed.
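For reference, `dict(x_test).values()` returns a `dict_values` view object rather than a list, so code paths that check for list or numpy-array inputs can fall through unhandled. A minimal illustration (the `features` dict here is made up for the example):

```python
# dict_values is a view object, not a list; code that dispatches on
# list/ndarray input types will not recognize it.
features = {"crim": [0.1], "zn": [18.0]}
vals = dict(features).values()

print(type(vals).__name__)     # dict_values
print(isinstance(vals, list))  # False

# An explicit conversion is needed before handing the data to APIs
# that expect a list of arrays:
as_list = list(vals)
print(as_list)                 # [[0.1], [18.0]]
```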


imatiach-msft avatar imatiach-msft commented on June 18, 2024

@JoelCanteroGCO I was able to make it work by defining the model as:

class WrappedModel(object):
    def __init__(self, model):
        self._model = model

    def predict(self, dataset):
        # Convert the data to multiple inputs
        return self._model.predict(list(dict(dataset).values()))

wrapped_model = WrappedModel(model)

and then this worked:

from interpret.ext.blackbox import TabularExplainer
explainer = TabularExplainer(wrapped_model, x_train, features=boston.feature_names, model_task='regression')
global_explanation = explainer.explain_global(x_test)

Note that this uses the KernelExplainer.
I am still trying to figure out how to handle this automatically for the user so no wrapper needs to be written, but it's hard to see how this case could be auto-detected and auto-resolved.
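One possible shape for such auto-detection, sketched purely for discussion: inspect the model for multiple named inputs and wrap it only in that case. Everything here (`auto_wrap`, `MultiInputWrapper`) is a hypothetical name, not part of interpret-community or ml-wrappers:

```python
# Hypothetical sketch: auto-wrap a multi-input model so each feature
# column is fed by name rather than positionally.
class MultiInputWrapper:
    def __init__(self, model, column_order):
        self._model = model
        self._columns = list(column_order)

    def predict(self, dataset):
        # Key inputs by column name so ordering never depends on
        # dict-to-list conversion quirks.
        inputs = {col: dataset[col] for col in self._columns}
        return self._model.predict(inputs)

def auto_wrap(model, sample):
    # Heuristic: a Keras-style model exposing more than one input
    # likely expects per-column inputs, not one 2D array.
    n_inputs = len(getattr(model, "inputs", []) or [])
    if n_inputs > 1:
        return MultiInputWrapper(model, sample.columns)
    return model
```

The hard part, as noted above, is making the heuristic reliable across model types, which is why this remains a sketch.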


imatiach-msft avatar imatiach-msft commented on June 18, 2024

I'm wondering now whether this conversion logic should live in the DatasetWrapper in the ml-wrappers repository.


imatiach-msft avatar imatiach-msft commented on June 18, 2024

@JoelCanteroGCO one thing that I don't understand at all: when I call predict on these two types of data, I get completely different results in your notebook - why is that?
[screenshot: predict output on the two data types]


imatiach-msft avatar imatiach-msft commented on June 18, 2024

I would guess that something is really wrong with calling it like this:
model.predict(list(dict(x_train).values()))
as the predictions seem way off.
Calling predict on the tf dataset directly gives much more plausible results, considering what the labels look like:
[screenshot: predictions compared with the labels]


imatiach-msft avatar imatiach-msft commented on June 18, 2024

If I understand correctly, this may actually be the better way to wrap the model:

import tensorflow as tf

class WrappedModel(object):
    def __init__(self, model):
        self._model = model

    def predict(self, dataset):
        # Convert the data to a dict of named columns and feed it as a
        # batched tf.data.Dataset so inputs are matched by name
        inp = dict(dataset)
        inp_ds = tf.data.Dataset.from_tensor_slices(inp).batch(32)
        return self._model.predict(inp_ds)

as the other approach seems to give strange output results, but I am not sure why. Note that I am using the latest TensorFlow version, 2.8.0. Perhaps older versions behave differently.


JoelCanteroGCO avatar JoelCanteroGCO commented on June 18, 2024

Hello @imatiach-msft

Thank you very much for your response and your attention.

As you have noticed, when we call predict on this type of data:
model.predict(list(dict(x_train).values()))

the results are completely different. This is because when we convert the dictionary to a list, the order of the features can change and no longer correspond to the order of the input layers. For this reason, calling it that way is wrong. Thanks for letting me know about that :)
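The failure mode can be shown without TensorFlow at all. In this toy sketch, a fake two-input "model" (the names `predict_positional`, `predict_keyed`, and the feature names are invented for illustration) silently swaps features when fed positionally from a dict whose column order differs from the input-layer order, while keyed inputs stay correct:

```python
# Simulated two-input model that expects inputs in layer_order.
def predict_positional(inputs, layer_order=("age", "income")):
    # Positional feeding: the caller must already match layer_order.
    age, income = inputs
    return age[0] + 0.001 * income[0]

def predict_keyed(inputs, layer_order=("age", "income")):
    # Keyed feeding: inputs are looked up by name, order-independent.
    age, income = (inputs[k] for k in layer_order)
    return age[0] + 0.001 * income[0]

# Columns arrive in a different order than the model's input layers.
data = {"income": [50000.0], "age": [30.0]}

wrong = predict_positional(list(data.values()))  # income fed as age
right = predict_keyed(data)                      # matched by name: 80.0
```

This is the same reason the dict-of-inputs wrapper above gives sensible predictions while the `list(dict(...).values())` call does not.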

Another correct way to call predict is to create a dictionary mapping the column names to tf.constant values:

inp = {col : tf.constant(x_train[col]) for col in x_train.columns}
model.predict(inp)

Thank you @imatiach-msft very much for your solution, it seems that wrapping the model and redefining the predict function works :)

It would be awesome if the package could solve this for the user so no wrapper class needs to be written.


imatiach-msft avatar imatiach-msft commented on June 18, 2024

Closing the issue, as it should now be resolved with PR:
#510
and the ml-wrappers PR it depended on:
microsoft/ml-wrappers#39

BatchDataset should now be supported in the DatasetWrapper and all interpretability algorithms should work with it automatically.

