saedashboard's Introduction

SAEDashboard

This code is a detached fork of SAEVis and is a work in progress. Please bear with us while we develop it further.

TODO:

  • Set up a GPU CI server so we can test things like multi-GPU generation.
  • Profile code with multiple GPUs to improve efficiency.
  • Work out a way to parallelize feature generation across jobs so we can get this all moving much faster.

OLD README

This codebase was designed to replicate Anthropic's sparse autoencoder visualisations, which you can see here. The codebase provides 2 different views: a feature-centric view (which is like the one in the link, i.e. we look at one particular feature and see things like which tokens fire strongest on that feature) and a prompt-centric view (where we look at one particular prompt and see which features fire strongest on that prompt according to a variety of different metrics).

Install with pip install sae-vis. Link to PyPI page here.

Features & Links

Important note - this repo was significantly restructured in March 2024 (we'll remove this message at the end of April). The recent changes include:

  • The ability to view multiple features on the same page (rather than just one feature at a time)
  • D3-backed visualisations (which can do things like add lines to histograms as you hover over tokens)
  • More freedom to customize exactly what the visualisation looks like (we provide full customizability, rather than just being able to change certain parameters)

Here is a link to a Google Drive folder containing 3 files:

  • User Guide, which covers the basics of how to use the repo (the core essentials haven't changed much from the previous version, but there are significantly more features)
  • Dev Guide, which we recommend for anyone who wants to understand how the repo works (and make edits to it)
  • Demo, which is a Colab notebook that gives a few examples

In the demo Colab, we show the two different types of vis which are supported by this library:

  1. Feature-centric vis, where you look at a single feature and see e.g. which sequences in a large dataset this feature fires strongest on.

  2. Prompt-centric vis, where you input a custom prompt and see which features score highest on that prompt, according to a variety of possible metrics.
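
To make the difference between the two views concrete, here is a minimal sketch in plain Python. It does not use this library's API: the toy activation table, the function names, and the choice of "max activation over the prompt" as the scoring metric are all assumptions made for illustration.

```python
from typing import List, Tuple

# Toy data (made up): one row per token, one column per feature.
TOKENS = ["the", "cat", "sat", "on", "mat"]
ACTS = [
    [0.0, 0.2, 0.0],
    [0.9, 0.0, 0.1],
    [0.1, 0.7, 0.0],
    [0.0, 0.1, 0.0],
    [0.8, 0.0, 0.3],
]


def top_tokens_for_feature(feature: int, k: int) -> List[Tuple[str, float]]:
    """Feature-centric view: the k tokens on which one feature fires strongest."""
    scored = [(tok, row[feature]) for tok, row in zip(TOKENS, ACTS)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


def top_features_for_prompt(
    prompt_acts: List[List[float]], k: int
) -> List[Tuple[int, float]]:
    """Prompt-centric view: the k features scoring highest on one prompt.

    The score used here (max activation over the prompt's tokens) is an
    assumed metric; it is only one of many possible choices.
    """
    n_features = len(prompt_acts[0])
    scores = [(f, max(row[f] for row in prompt_acts)) for f in range(n_features)]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:k]


if __name__ == "__main__":
    print(top_tokens_for_feature(0, k=2))
    print(top_features_for_prompt(ACTS[1:3], k=2))
```

The feature-centric function fixes a feature and ranks tokens, while the prompt-centric function fixes a prompt and ranks features.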

Citing this work

To cite this work, you can use this bibtex citation:

@misc{sae_vis,
    title  = {{SAE Visualizer}},
    author = {Callum McDougall},
    howpublished    = {\url{https://github.com/callummcdougall/sae_vis}},
    year   = {2024}
}

Contributing

This project uses Poetry for dependency management. After cloning the repo, install dependencies with poetry install.

This project uses Ruff for formatting and linting, Pyright for type-checking, and Pytest for tests. If you submit a PR, make sure that your code passes all checks. You can run all checks with make check-ci.

Version history (recording started at 0.2.9)

  • 0.2.9 - added table for pairwise feature correlations (not just encoder-B correlations)
  • 0.2.10 - fix some anomalous characters
  • 0.2.11 - update PyPI with longer description
  • 0.2.12 - fix height parameter of config, add videos to PyPI description
  • 0.2.13 - add to dependencies, and fix SAELens section
  • 0.2.14 - fix mistake in dependencies
  • 0.2.15 - refactor to support eventual scatterplot-based feature browser, fix ’ HTML
  • 0.2.16 - allow disabling buffer in feature generation, fix demo notebook, fix sae-lens compatibility & type checking
  • 0.2.17 - use main branch of sae-lens

saedashboard's People

Contributors

callummcdougall, hijohnnylin, jbloomaus, chanind, lucyfarnik, wllgrnt, arthurconmy, jordansauce

saedashboard's Issues

[Bug] Multi-GPU / CUDA requires a tweak

I should create a PR for this but I don't want to lose track of this in the meantime:

When running on CUDA with multi-GPU, SAEDashboard/sae_dashboard/feature_data_generator.py needs a tweak:

Line 129
Before

model_acts = self.model.forward(minibatch_tokens, return_logits=False)

After

model_acts = self.model.forward(minibatch_tokens.to("cuda"), return_logits=False)
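
A possible variation on the tweak above, offered only as a sketch: instead of hard-coding "cuda", look up the device of the model's first parameter and move the tokens there. The helper names are hypothetical, and the assumption that the first parameter lives on the device where inputs belong may not hold for every multi-GPU layout. The helpers rely only on the standard `parameters()`, `.device`, and `.to()` interface of PyTorch modules and tensors.

```python
def model_input_device(model):
    """Return the device of the model's first parameter.

    Assumption: inputs should live on the same device as the first
    parameter. Check this against your own multi-GPU setup.
    """
    return next(iter(model.parameters())).device


def move_to_model_device(tokens, model):
    """Move a batch of tokens to the device returned by model_input_device."""
    return tokens.to(model_input_device(model))
```

With these in place, the call site could read `self.model.forward(move_to_model_device(minibatch_tokens, self.model), return_logits=False)`.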

[Bug] OOM on first batch when generating dashboards

SAEDashboard saves cached activations to disk the first time that we generate dashboards. This allows us to just load from disk in subsequent generations of the dashboards.

The issue:
AFTER it has successfully saved the cached activations, there's an OOM when it first tries to use these cached activations. This does not happen on the second and later batches - on those batches where we are just loading the cached activations from disk, there is no OOM and the files are generated correctly.

Workaround:
Rerun the first batch, which causes it to load from disk and not have the OOM.

Potential Cause:
Something about caching the activations to disk the first time is using up memory that isn't being freed.
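
If the suspected cause is right, one thing to try is explicitly releasing memory between saving the cache and first using it. The sketch below is an untested idea, not a confirmed fix: `release_unused_memory` is a hypothetical helper name, and where to call it in the pipeline is left open. It only uses `gc.collect()` and PyTorch's `torch.cuda.empty_cache()`, which can free only memory that no live tensor still references, so any lingering references need to be dropped first (for example with `del`).

```python
import gc

try:
    import torch  # optional, so the sketch also runs without PyTorch installed
except ImportError:
    torch = None


def release_unused_memory() -> int:
    """Hypothetical helper: run garbage collection, then release cached CUDA memory.

    Returns the number of unreachable objects found by the garbage collector.
    """
    collected = gc.collect()
    if torch is not None and torch.cuda.is_available():
        # Hands cached, unused blocks back to the GPU driver; tensors that
        # are still referenced are not affected.
        torch.cuda.empty_cache()
    return collected
```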

Log:

/home/ubuntu/saed/.venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [04:00<00:00, 120.15s/it]
WARNING:root:You are not using LayerNorm, so the writing weights can't be centered! Skipping
Loaded pretrained model mistral-7b into HookedTransformer
Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 315.37it/s]
Skipping sparsity because sparsity_threshold was set to 1
Tokens don't exist, making them.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4096/4096 [00:27<00:00, 147.72it/s]
  0%|                                                                                                                         | 0/512 [00:00<?, ?it/s]========== Running Batch #1 ==========
  0%|                                                                                                                         | 0/512 [24:41<?, ?it/s]
Traceback (most recent call last):is: 100%|█████████████████████████████████████████████████████████████████████████| 384/384 [24:41<00:00,  3.82s/it]
  File "/home/ubuntu/saed/sae_dashboard/neuronpedia/make_batch.py", line 40, in <module>                                      | 0/128 [00:00<?, ?it/s]
    runner.run()
  File "/home/ubuntu/saed/sae_dashboard/neuronpedia/neuronpedia_runner.py", line 361, in run
    feature_data = SaeVisRunner(feature_vis_config_gpt).run(
  File "/home/ubuntu/saed/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/saed/sae_dashboard/sae_vis_runner.py", line 112, in run
    feature_stats = FeatureStatistics.create(
  File "/home/ubuntu/saed/sae_dashboard/utils_fns.py", line 526, in create
    quantile_data = torch.quantile(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.50 GiB. GPU 
Extracting vis data from cached data:   0%|                                                                                   | 0/128 [24:42<?, ?it/s]
Forward passes to cache data for vis: 100%|█████████████████████████████████████████████████████████████████████████| 384/384 [24:42<00:00,  3.86s/it]


╭─────────────────────────────────────────────────────────── Running Command for Batch #2 ───────────────────────────────────────────────────────────╮
│                                                                                                                                                    │
│ python make_batch.py res-je /home/ubuntu/Mistral-7B-Residual-Stream-SAEs/mistral_7b_layer_8                                                        │
│ /home/ubuntu/saed/sae_dashboard/neuronpedia/../../neuronpedia_outputs/mistral-7b_res-je_blocks.8.hook_resid_pre 1 4096 24576 128 128 2 2           │
│                                                                                                                                                    │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
/home/ubuntu/saed/.venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.10it/s]
WARNING:root:You are not using LayerNorm, so the writing weights can't be centered! Skipping
Loaded pretrained model mistral-7b into HookedTransformer
Resolving data files: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 321.29it/s]
Skipping sparsity because sparsity_threshold was set to 1
Tokens exist, loading them.
  0%|                                                                                                                         | 0/512 [00:00<?, ?it/s]========== Running Batch #2 ==========
