
circuitsvis's Introduction

CircuitsVis


Mechanistic Interpretability visualizations that work in both Python (e.g. with Jupyter Lab) and JavaScript (e.g. React or plain HTML).

View them all at https://transformerlensorg.github.io/CircuitsVis

Use

Install

Python

pip install circuitsvis

React

yarn add circuitsvis

Add visualizations

You can use any of the components from the demo page. These show the source code for use with React; for Python, import the function of the same name in snake_case instead (e.g. ColoredTokens becomes colored_tokens).

# Python Example
from circuitsvis.tokens import colored_tokens
colored_tokens(["My", "tokens"], [0.123, -0.226])
// React Example
import { ColoredTokens } from "circuitsvis";

function Example() {
    return (
        <ColoredTokens
            tokens={["My", "tokens"]}
            values={[0.123, -0.226]}
        />
    );
}

Contribute

Development requirements

DevContainer

For a one-click setup of your development environment, this project includes a DevContainer. It can be used locally with VS Code or with GitHub Codespaces.

Manual setup

To create new visualizations you need Node (including yarn) and Python (with Poetry).

Once you have these, you need to install both the Node & Python packages (note that for Python we use the Poetry package management system).

cd react && yarn
cd ../python && poetry install --with dev

Jupyter install

If you want Jupyter as well, run poetry install --with jupyter or, if this fails due to a PyTorch bug on M1 MacBooks, run poetry run pip install jupyter.

Creating visualizations

React

You'll first want to create the visualization in React. To do this, you can copy the example from /react/src/examples/Hello.tsx. To view changes while editing it (in Storybook), run the following from the /react/ directory:

yarn storybook

Python

This project uses Poetry for package management. To install, run:

poetry install

Once you've created your visualization in React, you can then create a short function in the Python library to render it. You can see an example in /python/circuitsvis/examples.py.

Note that this example will render from the CDN, unless development mode is specified. Your visualization will only be available on the CDN once it has been released to the latest production version of this library.

Publishing a new release

When a new GitHub release is created, the codebase will be automatically built and deployed to PyPI.

Citation

Please cite this library as:

@misc{cooney2023circuitsvis,
    title = {CircuitsVis},
    author = {Alan Cooney and Neel Nanda},
    year = {2023},
    howpublished = {\url{https://github.com/TransformerLensOrg/CircuitsVis}},
}

circuitsvis's People

Contributors

alan-cooney, andyrdt, colah, danbraunai, dependabot[bot], dkamm, luciaquirke, neelnanda-io, nelhage, oliverbalfour, ufo-101


circuitsvis's Issues

No way to unlock focus on `cv.attention.attention_patterns`?

I'm looking at https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp / https://colab.research.google.com/drive/1w9zCWpE7xd1sDuMT_rsjARfFozeWiKF4, in particular the "Visualising Attention Heads" section, with the code

print(type(gpt2_cache))
attention_pattern = gpt2_cache["pattern", 0, "attn"]
print(attention_pattern.shape)
gpt2_str_tokens = gpt2_small.to_str_tokens(gpt2_text)

print("Layer 0 Head Attention Patterns:")
display(cv.attention.attention_patterns(
    tokens=gpt2_str_tokens,
    attention=attention_pattern,
    attention_head_names=[f"L0H{i}" for i in range(12)],
))

And it seems like once I click on a head and/or token to lock the focus, there's no way to unlock the focus and get back the averaged value. There should be a way to do this, and the visualization should signpost this.

(Also, "Tokens (click to focus)" should probably be "Tokens (hover to focus, click to lock)" much like "Head selector (hover to focus, click to lock)")

support for bidirectional attention

Currently, attention plots always mask the upper triangular region.

This behavior makes sense for models using causal attention, which seems to be most models. But I've encountered a few models that use bidirectional attention (for example, the toy model here), and I think it would be helpful to be able to visualize these attention patterns using the same tooling.

A simple solution is to add an optional boolean flag to the attention_pattern and attention_heads functions that toggles whether or not to mask the upper triangular region. It could be called mask_upper_tri and would default to True.

Can implement quickly if other folks think it'd be useful.
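A minimal sketch of what the proposed flag could do, assuming the pattern arrives as a NumPy array of shape [n_heads, dest, src]. `mask_attention` is a hypothetical helper and `mask_upper_tri` is the name suggested above; neither exists in the released library.

```python
import numpy as np


def mask_attention(pattern: np.ndarray, mask_upper_tri: bool = True) -> np.ndarray:
    """Hypothetical helper: hide future positions in a [n_heads, dest, src] pattern.

    mask_upper_tri=True (causal models): entries where src > dest become NaN.
    mask_upper_tri=False (bidirectional models): the full matrix is kept.
    """
    pattern = np.asarray(pattern, dtype=float)
    if not mask_upper_tri:
        return pattern.copy()
    n_dest, n_src = pattern.shape[-2:]
    # k=1 selects the strict upper triangle, i.e. src position > dest position
    upper = np.triu(np.ones((n_dest, n_src), dtype=bool), k=1)
    # The 2-D mask broadcasts across the leading head dimension
    return np.where(upper, np.nan, pattern)
```

NaN is used instead of 0 so a renderer could tell "masked" apart from "zero attention"; the real components may handle this differently.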

Need to manually import each visualization

In the latest version, as I understand it, I need to do import circuitsvis.attention to get the attention pattern visualizer, etc for each visualization. Previously, I could just do import circuitsvis to get all of them. Is this intentional? Feels like a worse user experience IMO.

devcontainer fails to install nvidia-cublas-cu11

When building the dev container locally (Mac M1), I get:

  RuntimeError

  Unable to find installation candidates for nvidia-cuda-nvrtc-cu11 (11.7.99)

  at ~/.local/share/pypoetry/venv/lib/python3.10/site-packages/poetry/installation/chooser.py:103 in choose_for
       99│ 
      100│             links.append(link)
      101│ 
      102│         if not links:
    → 103│             raise RuntimeError(f"Unable to find installation candidates for {package}")
      104│ 
      105│         # Get the best link
      106│         chosen = max(links, key=lambda link: self._sort_key(package, link))
      107│ 

  • Installing nvidia-cuda-runtime-cu11 (11.7.99): Failed

I'm guessing the build works on the GitHub Codespaces servers with GPUs, but it would be nice not to require a GPU to build.

Support for torch 2

Google Colaboratory currently has torch 2.0.0+cu118 installed by default. !pip install circuitsvis triggers a downgrade of torch to 1.13.1, which in turn downgrades nvidia_cuda and many other packages and slows setup considerably.

Is there much involved in supporting the later torch version - possibly dependent on the python version like numpy?

Add pre-built react scripts to python lib (remove CDN dependency)

  • Remove the dependency on the CDN, by building the scripts and adding to the python package in the CD
  • Sync versions of the python and node modules
  • Set the tests to run on both
  • Set the html dump from python to still use the CDN, but this time get the correct version

Colored Text is Hidden for Jupyter Notebooks


As shown in the image, the hover tooltip is hidden in Jupyter notebooks.

It works as expected in Google Colab. text_neuron_activations() also works in both Colab and Jupyter notebooks.

I installed with:
pip install circuitsviz
OS: Linux
Python version: 3.10.9

Props cannot be NumPy floats

convert_props({"a": np.float32(1.)}) gives the error TypeError: Object of type float32 is not JSON serializable

This is an issue if you take, e.g., max_value=array.max() for a float32 NumPy array, as this returns an np.float32 scalar rather than a Python float. This can likely be fixed with a quick hack that casts types like this to float. I may add this myself at some point.
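One way the "quick hack" could look is a custom JSON encoder that turns NumPy scalars and arrays into built-in Python types. `NumpyJSONEncoder` is a hypothetical name, and how convert_props serializes internally is not shown here, so wiring it in is an assumption. Note that np.float64 already subclasses Python float, which is why only types like np.float32 fail.

```python
import json

import numpy as np


class NumpyJSONEncoder(json.JSONEncoder):
    """Hypothetical encoder: converts NumPy scalars/arrays to built-in types."""

    def default(self, obj):
        if isinstance(obj, np.generic):
            # np.float32, np.int64, np.bool_, ... -> float, int, bool
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Defer to the base class, which raises TypeError for unknown types
        return super().default(obj)
```

Usage would be json.dumps(props, cls=NumpyJSONEncoder) wherever the props are serialized.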

Local development injects entire tensorflow.js code

When coding locally, html.local_src injects a very large amount of code each time I render an HTML. This mostly comes from injecting the source code of underlying libraries rather than a CDN (eg tensorflow.js). This makes my notebooks much larger, and makes the visualizations take several seconds to load. I'm always coding online, so I'd like to have the option to have the CDN for source code of libraries, but to put in the source code for the local code that I am using.

Remove outdated workaround specifying package requirements

https://github.com/alan-cooney/CircuitsVis/blob/c264c944204f955245209019376190f0d9242516/python/pyproject.toml#L34

...introduced a bug-fix for compatibility with PyTorch 2.1. This breaks compatibility with PyTorch 2.2. Could you remove the bug-fix? See the error message:

ERROR: Cannot install -r requirements.txt (line 1) and -r requirements.txt (line 8) because these package versions have conflicting dependencies.

The conflict is caused by:
    circuitsvis 1.43.2 depends on nvidia-nccl-cu12==2.18.1; platform_system == "Linux" and platform_machine == "x86_64"
    torch 2.2.2 depends on nvidia-nccl-cu12==2.19.3; platform_system == "Linux" and platform_machine == "x86_64"
    circuitsvis 1.43.2 depends on nvidia-nccl-cu12==2.18.1; platform_system == "Linux" and platform_machine == "x86_64"
    torch 2.2.1 depends on nvidia-nccl-cu12==2.19.3; platform_system == "Linux" and platform_machine == "x86_64"
    circuitsvis 1.43.2 depends on nvidia-nccl-cu12==2.18.1; platform_system == "Linux" and platform_machine == "x86_64"
    torch 2.2.0 depends on nvidia-nccl-cu12==2.19.3; platform_system == "Linux" and platform_machine == "x86_64"

Dev container fails to build

When trying to build the dev container through docker desktop on MacOS with M1, I get:

  • Installing ipython (7.34.0)
  • Installing jupyter-server (1.23.0)
  • Installing psutil (5.9.4): Failed

  RuntimeError

  Unable to find installation candidates for psutil (5.9.4)

  at /usr/local/lib/python3.10/site-packages/poetry/installation/chooser.py:103 in choose_for
       99│ 
      100│             links.append(link)
      101│ 
      102│         if not links:
    → 103│             raise RuntimeError(f"Unable to find installation candidates for {package}")
      104│ 
      105│         # Get the best link
      106│         chosen = max(links, key=lambda link: self._sort_key(package, link))
      107│ 

[18769 ms] postCreateCommand failed with exit code 1. Skipping any further user-provided commands.
Done. Press any key to close the terminal.

Not sure of cause of this since psutil installs fine with pip. Temp fix is to export to a requirements.txt and use pip :):

poetry export --without-hashes --format=requirements.txt > requirements.txt && pip install -r requirements.txt

Add utils to generate boilerplate

When developing a new feature, there's an annoying amount of boilerplate I need to write: creating mock data in JSON format, creating a stories.tsx file, writing the Props and function definition, and writing the Python function. I think this can mostly be automated, which would make it lower-friction to add features, e.g. by adding a generate_stub_files function to utils. Here's an attempt:

import torch


def snake_to_camel(s: str) -> str:
    # Split the string into a list of words
    words = s.split('_')

    # Capitalize the first letter of each word and join them
    camel = ''.join([word.capitalize() for word in words])

    # Return the first letter in lowercase
    return camel[0].lower() + camel[1:]

def snake_to_pascal(s: str) -> str:
    # Split the string into a list of words
    words = s.split('_')

    # Capitalize the first letter of each word and join them
    pascal = ''.join([word.capitalize() for word in words])
    return pascal
# %%
# Copy and paste to mock file (prompt, top_log_probs, top_tokens, etc. are assumed to come from earlier notebook cells)
mock_data = {
    "prompt": prompt,
    "top_k_log_probs": top_log_probs.tolist(),
    "top_k_tokens": top_tokens,
    "correct_tokens": correct_tokens,
    "correct_token_rank": correct_token_rank,
    "correct_token_log_prob": correct_token_log_prob.tolist(),
}
vis_name = "LogProbVis"

mock_types = {
    "prompt": "string[]",
    "top_k_log_probs": "number[][]",
    "top_k_tokens": "string[][]",
    "correct_tokens": "string[]",
    "correct_token_rank": "number[]",
    "correct_token_log_prob": "number[]",
}

print(mock_data)

# %%
s = []
for name in mock_data:
    data = mock_data[name]
    typ = mock_types[name]
    if isinstance(data, torch.Tensor):
        data = data.tolist()
    print(f"export const {snake_to_camel('mock_'+name)}: {typ} = {data};")
    print()
# %%
newline = "\n"
template = f"""import {{ ComponentStory, ComponentMeta }} from "@storybook/react";
import React from "react";
import {{ {", ".join(map(lambda name: snake_to_camel("mock_" + name), mock_data.keys()))} }} from "./mocks/{vis_name[0].lower() + vis_name[1:]}";
import {{ {vis_name} }} from "./{vis_name}";

export default {{
  component: {vis_name}
}} as ComponentMeta<typeof {vis_name}>;

const Template: ComponentStory<typeof {vis_name}> = (args) => (
  <{vis_name} {{...args}} />
);

export const SmallModelExample = Template.bind({{}});
SmallModelExample.args = {{
  {f",{newline}  ".join([f"{snake_to_camel(name)}: {snake_to_camel('mock_'+name)}" for name in mock_data])}
}};
"""
print(template)

# %%
func_defn = f"""
export function {vis_name}({{
  {f",{newline}  ".join(map(snake_to_camel, mock_data.keys()))}
}}: {vis_name}Props) {{
"""

interface = f"""
export interface {vis_name}Props {{
{''';
  /**
   */
   '''.join([f"{snake_to_camel(name)}: {mock_types[name]}" for name in mock_types])}
}}
"""
print(func_defn)
print()
print()
print()
print(interface)

RenderedHTML always generates local_src, is slow

When I create a CircuitsVis component, it can take several seconds because it always generates local_src, which involves running the bundle_source command. I often just want the cdn_src when I am not actively developing CircuitsVis (but using it elsewhere), but the components don't let me explicitly turn this off.

I may fix this myself at some point, but assigning to @alan-cooney in case there's a principled way of doing this - IIRC you used to have it another way and changed your mind?
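One possible shape for this, assuming render() keeps the structure visible in the yarn buildBrowser traceback (render_local and render_cdn feeding RenderedHTML): defer each source behind a zero-argument callable and compute it only on first access. `LazyRenderedHTML` is hypothetical, and having str() return the CDN version is an assumption.

```python
from functools import cached_property
from typing import Callable


class LazyRenderedHTML:
    """Hypothetical RenderedHTML variant that builds each source on first access."""

    def __init__(self, make_local: Callable[[], str], make_cdn: Callable[[], str]):
        self._make_local = make_local
        self._make_cdn = make_cdn

    @cached_property
    def local_src(self) -> str:
        # The expensive bundling step only runs if someone asks for local_src
        return self._make_local()

    @cached_property
    def cdn_src(self) -> str:
        return self._make_cdn()

    def __str__(self) -> str:
        # Assumed: standalone HTML output uses the CDN version
        return self.cdn_src
```

Inside render(), this would mean passing lambda: render_local(react_element_name, **kwargs) instead of calling it eagerly, so users who only need cdn_src never trigger bundle_source.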

yarn buildBrowser --dev error

When I run any CircuitsVis render (e.g. circuitsvis.examples.hello("Neel")), I get the following error:

---------------------------------------------------------------------------
CalledProcessError                        Traceback (most recent call last)
/tmp/ipykernel_1001422/1929088876.py in <module>
      1 import circuitsvis.examples
----> 2 circuitsvis.examples.hello("Help")

~/CircuitsVis/python/circuitsvis/examples.py in hello(name)
     16     return render(
     17         "Hello",
---> 18         name=name,
     19     )

~/CircuitsVis/python/circuitsvis/utils/render.py in render(react_element_name, **kwargs)
    170         Html: HTML for the visualization
    171     """
--> 172     local_src = render_local(react_element_name, **kwargs)
    173     cdn_src = render_cdn(react_element_name, **kwargs)
    174     return RenderedHTML(local_src, cdn_src)

~/CircuitsVis/python/circuitsvis/utils/render.py in render_local(react_element_name, **kwargs)
    103     if REACT_DIR.exists():
    104         install_if_necessary()
--> 105         bundle_source()
    106 
    107     # Load the JS

~/CircuitsVis/python/circuitsvis/utils/render.py in bundle_source(dev_mode)
     81                    capture_output=True,
     82                    text=True,
---> 83                    check=True
     84                    )
     85 

/opt/conda/lib/python3.7/subprocess.py in run(input, capture_output, timeout, check, *popenargs, **kwargs)
    510         if check and retcode:
    511             raise CalledProcessError(retcode, process.args,
--> 512                                      output=stdout, stderr=stderr)
    513     return CompletedProcess(process.args, retcode, stdout, stderr)
    514 

CalledProcessError: Command '['yarn', 'buildBrowser', '--dev']' returned non-zero exit status 1.

Logit visualiser

A logit visualizer, which shows text coloured by the prob/log-prob predicted for that token, and if you hover over a token it shows the top 5 logits/probs/log-probs (I was planning on trying to hack this together with your library this evening, thus my questions)
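The data such a visualizer needs could be computed roughly as below. This is a sketch under assumptions: logits come as a NumPy array of shape [seq_len, d_vocab] from some model, token_ids are the input ids, and `logit_vis_data` plus its output keys are hypothetical. The prediction for token t lives at position t-1, so the first token gets no colour.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the per-row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def logit_vis_data(logits: np.ndarray, token_ids: np.ndarray, k: int = 5) -> dict:
    """Hypothetical helper. logits: [seq_len, d_vocab]; token_ids: [seq_len]."""
    log_probs = log_softmax(logits)[:-1]  # predictions made at positions 0..seq_len-2
    targets = token_ids[1:]               # the token that actually came next
    correct_lp = log_probs[np.arange(len(targets)), targets]
    # Rank 0 means the actual token was the model's top prediction
    rank = (log_probs > correct_lp[:, None]).sum(axis=-1)
    # Top-k predictions per position, highest log-prob first (shown on hover)
    top_k_ids = np.argsort(-log_probs, axis=-1)[:, :k]
    top_k_lp = np.take_along_axis(log_probs, top_k_ids, axis=-1)
    return {
        "correct_token_log_prob": correct_lp,
        "correct_token_rank": rank,
        "top_k_token_ids": top_k_ids,
        "top_k_log_probs": top_k_lp,
    }
```

Mapping top_k_token_ids back to strings is left to whatever tokenizer produced token_ids.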

Create single attention pattern viewer

  • Show positive and negative values with different colors
  • Color blind friendly
  • Zoom functionality to see tokens (and highlight just one source/destination)
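The first two points could be met with a diverging colour scale that is white at zero and fades to one hue for positive and another for negative values. A minimal sketch: blue/orange is a commonly recommended colour-blind-friendly pairing, but the exact RGB endpoints and the `value_to_rgb` helper are illustrative assumptions.

```python
from typing import Tuple

RGB = Tuple[int, int, int]

# Illustrative endpoint colours (assumed): orange for positive, blue for negative
POSITIVE_RGB: RGB = (230, 97, 1)
NEGATIVE_RGB: RGB = (33, 102, 172)
WHITE: RGB = (255, 255, 255)


def value_to_rgb(value: float, max_abs: float) -> RGB:
    """Hypothetical helper: white at 0, fading to orange (+) or blue (-)."""
    if max_abs <= 0:
        return WHITE
    # Clamp to [-1, 1] so outliers saturate instead of overshooting
    t = max(-1.0, min(1.0, value / max_abs))
    target = POSITIVE_RGB if t >= 0 else NEGATIVE_RGB
    strength = abs(t)
    # Linear interpolation from white towards the target colour
    return tuple(round(w + (c - w) * strength) for w, c in zip(WHITE, target))
```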

Different key and query tokens?

Example use case: we have 3 tokens at the end of a prompt, and we want to see the attention probs from those back to all other tokens in the sequence. This could be done via something like

cv.attention.attention_patterns(
    attention = attention,
    src_tokens = tokens,
    dest_tokens = tokens[-3:],
)

Not sure how difficult this would be to implement.
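On the data side, the request amounts to keeping only the last few destination (query) rows of the pattern while leaving every source (key) column. A sketch assuming the pattern is a NumPy array of shape [n_heads, dest, src]; `select_dest_tokens` is hypothetical, and the harder part (rendering a non-square matrix) is not covered.

```python
from typing import List, Sequence, Tuple

import numpy as np


def select_dest_tokens(
    attention: np.ndarray, tokens: Sequence[str], n_dest: int
) -> Tuple[np.ndarray, List[str], List[str]]:
    """Hypothetical helper: keep the last `n_dest` destination rows.

    attention: [n_heads, dest, src]. Returns the sliced pattern
    ([n_heads, n_dest, src]) plus the matching dest and src token lists.
    """
    if not 0 < n_dest <= attention.shape[-2]:
        raise ValueError("n_dest must be between 1 and the number of positions")
    # Row i of the result is position len(tokens) - n_dest + i in the prompt
    sliced = attention[:, -n_dest:, :]
    return sliced, list(tokens[-n_dest:]), list(tokens)
```

Any causal masking would then need that position offset, since row i no longer corresponds to position i.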

500 Internal Server Error from unpkg

Version 1.38.0:
str(cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern)) produces HTML containing https://unpkg.com/circuitsvis@1.38.0/dist/cdn/esm.js. This URL returns HTTP 500 Internal Server Error as of today (I think it was working yesterday).

Variants of that URL, e.g. https://unpkg.com/circuitsvis or https://unpkg.com/circuitsvis@…/dist/cdn/esm.js, also return an Internal Server Error.

Other unpkg links work, e.g. https://unpkg.com/react@…/umd/react.production.min.js.

(I love this library by the way!)
