Comments (15)
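import shap

# model and tokenizer are assumed to be loaded already (e.g. a Llama-2 or Mistral
# checkpoint via transformers' AutoModelForCausalLM / AutoTokenizer)
s = ["I enjoy walking with my cute dog"]

# Generation settings, picked up by shap from task_specific_params["text-generation"]
gen_dict = dict(
    max_new_tokens=100,
    num_beams=5,
    renormalize_logits=True,
    no_repeat_ngram_size=8,
)
model.config.task_specific_params = dict()
model.config.task_specific_params["text-generation"] = gen_dict

# Wrap the model so shap scores the generated tokens with teacher forcing
shap_model = shap.models.TeacherForcing(model, tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
# Pass the masker (not the tokenizer) so the mask settings above are actually used;
# passing the tokenizer makes shap build a default Text masker instead
explainer = shap.Explainer(shap_model, masker)
shap_values = explainer(s)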
This solution should solve your problem with Llama-2 and with Mistral.
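For intuition about the masker settings in this snippet: shap explains the output by re-running the model on partially masked versions of the input. The stdlib sketch below shows roughly what mask_token="..." with collapse_mask_token=True does to one such input, assuming (as the option name suggests) that a run of consecutive masked tokens is replaced by a single mask token. mask_tokens is a hypothetical helper, not shap's implementation, and shap's real masker may handle spacing differently.

```python
from typing import List, Sequence


def mask_tokens(
    tokens: Sequence[str],
    keep: Sequence[bool],
    mask_token: str = "...",
    collapse_mask_token: bool = True,
) -> List[str]:
    """Replace each token whose keep-flag is False with mask_token.

    With collapse_mask_token=True (assumed semantics), a run of consecutive
    masked tokens becomes one mask_token instead of one per token.
    """
    if len(tokens) != len(keep):
        raise ValueError("tokens and keep must have the same length")
    out: List[str] = []
    prev_masked = False
    for token, kept in zip(tokens, keep):
        if kept:
            out.append(token)
            prev_masked = False
        else:
            # Emit a mask token unless we are collapsing and already inside a masked run.
            if not (collapse_mask_token and prev_masked):
                out.append(mask_token)
            prev_masked = True
    return out


if __name__ == "__main__":
    tokens = ["I", " enjoy", " walking", " with", " my", " cute", " dog"]
    keep = [True, False, False, True, True, False, True]
    print("".join(mask_tokens(tokens, keep)))  # I... with my... dog
```

Without collapsing, the same input would keep two separate "..." tokens for " enjoy walking", which changes the text the model sees.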
from shap.
Thanks for the report. I've requested access to Llama 2 and will post here once there are any updates.
It's partially working now, but the values array is all zeros:
.values =
array([[[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.]]])   # shape (1, 9, 73), every entry 0.
.base_values =
array([[-0.52342548, -2.88410334, -0.36660795, 0.95767114, -0.24581238,
-1.7873968 , -0.03110165, 6.36779252, 0.35044359, 6.19219956,
2.58561948, 7.491578 , 3.81039219, -2.31506661, -0.42209612,
-0.62833234, -0.64307879, -1.80665099, -1.17742844, 0.66211251,
-0.27008128, 6.25576132, 4.52850612, -1.31015818, -0.45799231,
10.00574646, 0.50833428, -0.42837653, 4.29419062, 4.06030859,
-1.15709111, -0.20367953, 5.86984239, 4.13385361, 2.36138941,
-0.08768206, 3.2889124 , 0.68570033, -0.53387673, 0.55577215,
-0.35025047, 3.82609343, -0.75910988, 2.56892822, 2.24339371,
3.08884504, 0.6789584 , 0.73464042, 1.60391795, 2.63059456,
5.00190821, 7.1968913 , 1.43071471, 2.62828756, 3.7208354 ,
11.1741379 , 6.9844757 , 0.30599576, 2.32297348, 0.70061408,
1.50329472, 5.85171772, 0.6600345 , 1.56481051, 4.1168472 ,
6.36192085, -0.81794184, 2.52507464, 4.35465319, -0.32329904,
1.68587773, -1.32010292, 2.59065567]])
.data =
(array(['', 'I', ' enjoy', ' walking', ' with', ' my', ' c', 'ute', ' do
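Two quick checks help when values come back like this. By SHAP's additivity property, for each output token the base value plus the sum of that token's attributions over all input tokens reconstructs the model's score, so all-zero values mean the explanation attributes nothing to any input token and every score simply equals its base value. In the dump above, .values has shape (1, 9, 73), which appears to be (examples, input tokens, output tokens), and .base_values has shape (1, 73). The stdlib helpers below (hypothetical names, working on nested lists such as the output of ndarray.tolist()) compute the shape, test for all zeros, and do that reconstruction; the numbers in the usage lines are made up.

```python
from typing import List, Sequence


def shape_of(nested) -> tuple:
    """Shape of a rectangular nested list, e.g. from shap_values.values.tolist()."""
    shape = []
    while isinstance(nested, (list, tuple)):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def is_all_zero(nested, tol: float = 0.0) -> bool:
    """True if every number in the nested list has absolute value <= tol."""
    if isinstance(nested, (list, tuple)):
        return all(is_all_zero(x, tol) for x in nested)
    return abs(nested) <= tol


def reconstructed_scores(values: Sequence[Sequence[float]], base_values: Sequence[float]) -> List[float]:
    """Per output token j: base_values[j] + sum_i values[i][j].

    values is one example's (n_input_tokens, n_output_tokens) slice and
    base_values is that example's (n_output_tokens,) slice.
    """
    return [base_values[j] + sum(row[j] for row in values) for j in range(len(base_values))]


if __name__ == "__main__":
    values = [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]]  # made up: 1 example, 3 inputs, 2 outputs
    base = [[-0.5, 2.0]]
    print(shape_of(values))                          # (1, 3, 2)
    print(is_all_zero(values))                       # True
    print(reconstructed_scores(values[0], base[0]))  # [-0.5, 2.0]
```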
I think PR #3578 fixes this.
I just saw it, but I'm pretty new to this, and when I try to install this PR it tells me I need Python 3.9. Maybe I'm doing something wrong; how do I use this PR?
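Since the complaint is about a minimum interpreter version, it helps to confirm which Python the install actually runs under; the pip log later in this thread runs under /usr/lib/python3.8. A minimal stdlib sketch, assuming the (3, 9) minimum that the error message reported; meets_python_requirement is a hypothetical helper.

```python
import sys
from typing import Optional, Tuple


def meets_python_requirement(
    minimum: Tuple[int, int] = (3, 9),
    current: Optional[Tuple[int, ...]] = None,
) -> bool:
    """True if `current` (default: the running interpreter) is at least `minimum`."""
    if current is None:
        current = tuple(sys.version_info[:2])
    # Compare integer tuples: (3, 10) >= (3, 9) is True, whereas the
    # strings "3.10" >= "3.9" would wrongly compare as False.
    return tuple(current[:2]) >= tuple(minimum)


if __name__ == "__main__":
    print(sys.version.split()[0], meets_python_requirement())
```

Run it with the same interpreter that pip uses (for example `python3 check.py` next to `python3 -m pip`), since a system may have several Pythons installed.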
I haven't tested it, but according to Stack Overflow this should work:
pip install https://github.com/costrau/shap/archive/fix-transformers.zip
It doesn't work :(
pip install https://github.com/costrau/shap/archive/fix-transformers.zip
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting https://github.com/costrau/shap/archive/fix-transformers.zip
Downloading https://github.com/costrau/shap/archive/fix-transformers.zip
- 159.5 MB 5.7 MB/s 0:00:30
Installing build dependencies ... done
Getting requirements to build wheel ... error
error: subprocess-exited-with-error
× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> [119 lines of output]
Attempting to build SHAP: with_binary=True, with_cuda=True (Attempt 1)
NVCC ==> /usr/local/cuda/bin/nvcc
Compiling cuda extension, calling nvcc with arguments:
['/usr/local/cuda/bin/nvcc', '-allow-unsupported-compiler', 'shap/cext/_cext_gpu.cu', '-lib', '-o', 'build/lib_cext_gpu.a', '-Xcompiler', '-fPIC', '--include-path', '/usr/include/python3.8', '--std', 'c++14', '--expt-extended-lambda', '--expt-relaxed-constexpr', '-gencode=arch=compute_60,code=sm_60', '-gencode=arch=compute_70,code=sm_70', '-gencode=arch=compute_75,code=sm_75', '-gencode=arch=compute_75,code=compute_75', '-gencode=arch=compute_80,code=sm_80']
Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
WARNING: Could not compile cuda extensions.
Retrying SHAP build without cuda extension...
Attempting to build SHAP: with_binary=True, with_cuda=False (Attempt 2)
Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
WARNING: The C extension could not be compiled, sklearn tree models not supported.
Retrying SHAP build without binary extension...
Attempting to build SHAP: with_binary=False, with_cuda=False (Attempt 3)
Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
ERROR: Failed to build!
Traceback (most recent call last):
File "<string>", line 144, in try_run_setup
File "<string>", line 134, in run_setup
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
return distutils.core.setup(**attrs)
File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
_setup_distribution = dist = klass(attrs)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
_Distribution.__init__(self, dist_attrs)
File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
self.finalize_options()
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
ep(self)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
_assign_version(dist, config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
_version_missing(config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
raise LookupError(
LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<string>", line 144, in try_run_setup
File "<string>", line 134, in run_setup
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
return distutils.core.setup(**attrs)
File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
_setup_distribution = dist = klass(attrs)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
_Distribution.__init__(self, dist_attrs)
File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
self.finalize_options()
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
ep(self)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
_assign_version(dist, config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
_version_missing(config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
raise LookupError(
LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
main()
File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
json_out['return_val'] = hook(**hook_input['kwargs'])
File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 118, in get_requires_for_build_wheel
return hook(config_settings)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 325, in get_requires_for_build_wheel
return self._get_build_requires(config_settings, requirements=['wheel'])
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 295, in _get_build_requires
self.run_setup()
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 311, in run_setup
exec(code, locals())
File "<string>", line 165, in <module>
File "<string>", line 160, in try_run_setup
File "<string>", line 160, in try_run_setup
File "<string>", line 144, in try_run_setup
File "<string>", line 134, in run_setup
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
return distutils.core.setup(**attrs)
File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
_setup_distribution = dist = klass(attrs)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
_Distribution.__init__(self, dist_attrs)
File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
self.finalize_options()
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
ep(self)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
_assign_version(dist, config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
_version_missing(config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
raise LookupError(
LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error
× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> See above for output.
note: This error originates from a subprocess, and is likely not a problem with pip.
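The setuptools-scm message in this log names the cause: a GitHub archive .zip has no git metadata, so the version cannot be detected, and it suggests installing from a git+https URL instead. A small stdlib sketch of that rewrite, pinned to the same branch with pip's `@ref` syntax; archive_url_to_git_url is a hypothetical helper and only handles the simple `/archive/<ref>.zip` form.

```python
import re

# Matches https://github.com/<owner>/<repo>/archive/<ref>.zip
_ARCHIVE_RE = re.compile(
    r"^https://github\.com/(?P<owner>[^/]+)/(?P<repo>[^/]+)/archive/(?P<ref>.+)\.zip$"
)


def archive_url_to_git_url(url: str) -> str:
    """Turn a GitHub archive .zip URL into a pip-installable git+https URL for the same ref."""
    m = _ARCHIVE_RE.match(url)
    if m is None:
        raise ValueError(f"not a GitHub archive .zip URL: {url}")
    return f"git+https://github.com/{m['owner']}/{m['repo']}.git@{m['ref']}"


if __name__ == "__main__":
    print(archive_url_to_git_url("https://github.com/costrau/shap/archive/fix-transformers.zip"))
    # git+https://github.com/costrau/shap.git@fix-transformers
```

Cloning via git gives setuptools-scm the repository history it needs to compute a version, which is why the git+https form can build where the archive cannot.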
I moved to Python 3.10 and installed it with the following command:
pip3 install git+https://github.com/costrau/shap.git@fix-transformers
It finally worked!
BTW, do you know what mechanism shap uses to see the attentions?
Does this resolve your bug? And what do you mean by "see attentions"? For transformer models, shap basically just runs inference and extracts the logits from the output.
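On "extracts the logits": the TeacherForcing wrapper's docstring describes its scores as log odds of the output text. Assuming that means logit(p) = log(p / (1 - p)), where p is the softmax probability the model assigns to each teacher-forced output token given the input and the preceding output tokens, one token's score can be sketched in stdlib Python as follows. The helper names are hypothetical and this is not shap's code.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log of softmax(logits)."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def token_log_odds(logits: Sequence[float], target_id: int) -> float:
    """log(p / (1 - p)) where p = softmax(logits)[target_id]; requires p < 1."""
    log_p = log_softmax(logits)[target_id]
    # log(1 - p) computed as log1p(-p) to keep precision when p is small.
    return log_p - math.log1p(-math.exp(log_p))


if __name__ == "__main__":
    # Made-up next-token logits over a 3-token vocabulary; the teacher-forced
    # target is token 0, so p = e^2 / (e^2 + 2) and the log odds are 2 - ln 2.
    print(round(token_log_odds([2.0, 0.0, 0.0], 0), 4))  # 1.3069
```

Only the final logits enter this score, which is consistent with the point that shap never needs to look at the attention layers.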
Yep, what I wrote resolved my bug.
Yes, it extracts the logits, but does it just show the raw logits of the layers, averaged somehow? Llama 2 has 32 heads and 32 layers, so there are 32 × 32 = 1,024 different attention matrices.
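To make the 32 × 32 arithmetic concrete: each (layer, head) pair produces one seq_len × seq_len attention matrix, so 32 layers × 32 heads gives 1,024 of them per forward pass. In Hugging Face transformers they can be obtained by calling the model with output_attentions=True, which returns one (batch, heads, seq, seq) tensor per layer. The stdlib sketch below shows a plain element-wise average over all layers and heads, the kind of aggregation asked about here; it is only an illustration, and as the replies in this thread point out, shap does not read attention weights at all. Both helper names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def attention_matrix_count(num_layers: int, num_heads: int) -> int:
    """One seq_len x seq_len attention matrix per (layer, head) pair."""
    return num_layers * num_heads


def mean_attention(attn: List[List[Matrix]]) -> Matrix:
    """Element-wise mean of attn[layer][head] over every layer and head."""
    mats = [m for layer in attn for m in layer]
    n = len(mats)
    seq = len(mats[0])
    return [[sum(m[i][j] for m in mats) / n for j in range(seq)] for i in range(seq)]


if __name__ == "__main__":
    print(attention_matrix_count(32, 32))  # 1024
    # Toy input: 2 layers x 2 heads of 2x2 row-stochastic attention matrices.
    toy = [
        [[[1.0, 0.0], [0.5, 0.5]], [[1.0, 0.0], [0.0, 1.0]]],
        [[[1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [0.5, 0.5]]],
    ]
    print(mean_attention(toy))  # [[1.0, 0.0], [0.5, 0.5]]
```

A plain mean discards which head or layer did what, which is one reason attention-based explanations usually pick specific heads or use more structured aggregations.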
shap does not provide any information about model internals, only about how the model uses the input to generate the output. If you are interested in the internals, you may get some results from the pytorch/captum package. I once read a paper about saliency maps that does basically what you want, but Captum's implementation of it also seems to give you only attributions for the inputs.
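Vanilla gradient saliency, the simplest saliency-map method (Captum exposes it as captum.attr.Saliency), scores each input feature by the magnitude of the output's gradient with respect to that feature. The stdlib sketch below approximates the gradient with central finite differences on a made-up function toy_f instead of using autograd; it illustrates why such maps, like shap's output, are attributions over the inputs rather than a view of the model's internals. The helper names are hypothetical.

```python
from typing import Callable, List, Sequence


def saliency(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    eps: float = 1e-6,
) -> List[float]:
    """|df/dx_i| at x for every feature i, via central finite differences."""
    scores = []
    for i in range(len(x)):
        up = list(x)
        down = list(x)
        up[i] += eps
        down[i] -= eps
        scores.append(abs(f(up) - f(down)) / (2 * eps))
    return scores


def toy_f(x: Sequence[float]) -> float:
    # Made-up differentiable "model": its gradient at (1, 2, 3) is (6, -2, 1).
    return 3 * x[0] - 2 * x[1] + x[0] * x[2]


if __name__ == "__main__":
    print([round(s, 3) for s in saliency(toy_f, [1.0, 2.0, 3.0])])  # [6.0, 2.0, 1.0]
```

Real implementations compute the gradient in one backward pass instead of 2 × n forward calls, but the resulting scores have the same meaning.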
So if shap represents how the model uses the input, it's probably using the attentions, right?
Anyway, thanks for suggesting the Captum package. I didn't know it existed; I'll take a look.
In this case shap uses the model, and the model uses attention layers. But when you explain a transformer model, shap knows nothing about its internals.
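To illustrate the black-box point: a Shapley-value explainer only needs to call the model on masked versions of the input and compare the outputs; it never looks inside. The sketch below computes exact Shapley values by brute-force enumeration over a made-up three-token input and a made-up scoring function (TOKENS, toy_model, coalition_value and exact_shapley are all hypothetical). This is the textbook definition, exponential in the number of tokens; shap's explainers approximate it far more efficiently, so treat this as an illustration of the idea rather than shap's algorithm.

```python
from itertools import combinations
from math import factorial
from typing import Callable, FrozenSet, List

TOKENS = ["my", " cute", " dog"]  # hypothetical input


def toy_model(text: str) -> float:
    """Made-up black-box score: rewards 'dog', plus a bonus when 'cute' appears with it."""
    score = 0.0
    if "dog" in text:
        score += 1.0
        if "cute" in text:
            score += 0.5
    return score


def coalition_value(kept: FrozenSet[int]) -> float:
    # Mask every token not in `kept`, then query the model on the resulting text only.
    text = "".join(tok if i in kept else "..." for i, tok in enumerate(TOKENS))
    return toy_model(text)


def exact_shapley(value: Callable[[FrozenSet[int]], float], n: int) -> List[float]:
    """phi_i = sum over S not containing i of |S|!(n-|S|-1)!/n! * (v(S + i) - v(S))."""
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


if __name__ == "__main__":
    phi = exact_shapley(coalition_value, len(TOKENS))
    print([round(p, 6) for p in phi])  # [0.0, 0.25, 1.25]
```

The interaction bonus of 0.5 is split equally between " cute" and " dog", and the values sum to the full-input score minus the fully masked score (1.5 - 0.0), the same additivity that ties SHAP values to base values.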