saelens's Introduction

SAE Lens

License: MIT

SAELens exists to help researchers:

  • Train sparse autoencoders.
  • Analyse sparse autoencoders / research mechanistic interpretability.
  • Generate insights which make it easier to create safe and aligned AI systems.

Please refer to the documentation for information on how to:

  • Download and analyse pre-trained sparse autoencoders.
  • Train your own sparse autoencoders.
  • Generate feature dashboards with the SAE-Vis Library.

SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to safeguard humanity from risks posed by artificial intelligence.

This library is maintained by Joseph Bloom and David Chanin.

Loading Pre-trained SAEs.

Pre-trained SAEs for various models can be imported via SAE Lens. See this page in the readme for a list of all SAEs.

Tutorials

Join the Slack!

Feel free to join the Open Source Mechanistic Interpretability Slack for support!

Citation

Please cite the package as follows:

@misc{bloom2024saetrainingcodebase,
   title = {SAELens},
   author = {Joseph Bloom and David Chanin},
   year = {2024},
   howpublished = {\url{https://github.com/jbloomAus/SAELens}}
}

saelens's People

Contributors

benw8888, bmillwood, canrager, chanind, ckkissane, curt-tigges, dtch1997, evanhanders, hijohnnylin, hufy-dev, ianand, jbloomaus, joshengels, lewington-pitsos, lucyfarnik, neelnanda-io, nelsong-c, nlpet, oli-clive-griffin, phylliida, robertzk, roganinglis, schmatz, shehper, slavachalnev, themachinefan, tommcgrath, weissercn, wllgrnt

saelens's Issues

Support OthelloGPT

I did a hack weekend with Andy, Neel and others where we trained OthelloGPT SAEs fairly easily.

I think it should be fairly easy to provide OthelloGPT support based on their code here. The main trick will be to add an argument that tells the activation store to skip the tokenization-related steps.
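Below is a minimal sketch of what such an argument could look like. It assumes the activation store prepares one sequence of token ids per dataset row. Both the prepare_sequence function and the skip_tokenization flag are hypothetical and are not existing SAELens APIs. When the flag is set, rows are used as-is instead of being passed through a tokenizer. OthelloGPT games encoded as move indices are an example of such rows.

```python
from typing import Callable, List, Optional, Sequence, Union

Row = Union[str, Sequence[int]]


def prepare_sequence(
    row: Row,
    context_size: int,
    tokenize: Optional[Callable[[str], List[int]]] = None,
    skip_tokenization: bool = False,
) -> List[int]:
    """Turn one dataset row into at most `context_size` token ids.

    Hypothetical helper: with skip_tokenization=True the row is assumed to
    already be a sequence of integer ids (e.g. OthelloGPT move indices).
    """
    if skip_tokenization:
        ids = [int(t) for t in row]
    else:
        if tokenize is None:
            raise ValueError("a tokenizer is required unless skip_tokenization=True")
        if not isinstance(row, str):
            raise TypeError("expected raw text when tokenizing")
        ids = tokenize(row)
    # Truncate so every sequence fits the model's context window.
    return ids[:context_size]
```

This sketch keeps each row as its own sequence rather than packing several rows into one context window. That is presumably what OthelloGPT needs, since every game starts from the same board position.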

Use logging instead of print in codebase

Currently, logging info is output using print(), which makes it hard to control the verbosity of the output, and it cannot distinguish debugging / info statements from warnings / errors. For instance, in sae_training/config.py, creating a config object prints a lot of statements that are a mix of debugging / info / warnings, and would be better suited to the Python logging module.
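Here is a sketch of the pattern with the standard logging module. ExampleRunnerConfig is a hypothetical, trimmed-down stand-in for the real config class, and the learning-rate threshold is only illustrative. Routine details become INFO records, which are hidden at Python's default WARNING level. Genuinely suspicious settings become WARNING records.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class ExampleRunnerConfig:
    """Hypothetical, trimmed-down stand-in for the real config class."""

    d_in: int
    expansion_factor: int
    lr: float

    def __post_init__(self) -> None:
        d_sae = self.d_in * self.expansion_factor
        # Routine details are INFO, so they are hidden at the default WARNING level.
        logger.info("Training SAE with d_sae=%d", d_sae)
        # Suspicious settings are WARNINGs (the threshold is only illustrative).
        if self.lr > 1e-2:
            logger.warning("Learning rate %g is unusually high", self.lr)
```

Users who want the old verbose behaviour could then call logging.basicConfig(level=logging.INFO), and everyone else only sees warnings.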

Error running SAE in evaluation notebook

Problem: When running the evaluating_your_sae.ipynb I get an error when trying to do the L0 test in the 3rd code block.

More specifically, the error occurs when running sparse_autoencoder(cache[sparse_autoencoder.cfg.hook_point]).

Since I didn't have a chance to train the model myself and am using the
"sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144:v2/1200001024_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144.pt"
checkpoint, I wanted to flag this in case someone else ran into this issue.

If I missed something obvious, I apologize! Thank you for open-sourcing this research :)

The last line in the screenshot says 'NoneType' object has no attribute 'sum'.
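For reference, the L0 check in that notebook measures how many SAE features are active per token. The sketch below computes the metric in plain Python. It assumes feature activations are given as a list of per-token lists, and mean_l0 is a hypothetical helper name. With a PyTorch tensor of feature activations, the equivalent is roughly (feature_acts > 0).float().sum(-1).mean().

```python
from typing import Sequence


def mean_l0(feature_acts: Sequence[Sequence[float]]) -> float:
    """Average number of active (> 0) SAE features per token."""
    if not feature_acts:
        raise ValueError("need activations for at least one token")
    active = sum(sum(1 for a in token_acts if a > 0) for token_acts in feature_acts)
    return active / len(feature_acts)
```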

Publish on PyPI

Could this library be published on PyPI so it could be installed via pip install instead of needing to clone the git repo? It would make it much easier to install and use this library, as publishing on PyPI is usually the preferred way of sharing Python code.

The quickest path to this would be to do the following:

  1. Move the library code into a folder with the same name as the library. If you want to call the library mats_sae_training so you can pip install mats_sae_training, then create a mats_sae_training folder and the sae_training and sae_analysis folders should be moved into that folder. Imports will then be from mats_sae_training.sae_training import ....
  2. Use a modern packaging tool (Poetry, PDM, or Hatch are all good choices).
  3. Remove the requirements.txt and instead split dependencies between the main dependencies and dev dependencies in pyproject.toml.
  4. Remove .gitmodules - it seems this isn't used anyway.

From there, you just need to run the publish command for the packaging tool you chose (e.g. poetry publish for Poetry, pdm publish for PDM, or hatch publish for Hatch), and you're basically done.

You can also set up automatic deployment with GitHub Actions using semantic-release, so deployment happens automatically when committing to main, but that's not strictly required.

I can set up a PR with these changes if it's helpful.

Any suggestions for tuning the training parameters?

Thank you for providing an open-source dictionary learning repository. I have trained on other models similar to GPT-2. Despite numerous parameter adjustments, the training results still show a large number of dead features.
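When diagnosing dead features, it helps to measure them consistently across runs. The sketch below assumes feature activations for a window of tokens are available as per-token lists. It computes each feature's firing frequency and reports the fraction of features at or below min_freq. The helper names and the min_freq parameter are hypothetical. SAELens's own dead-feature bookkeeping (dead_feature_window, dead_feature_threshold) may define "dead" differently.

```python
from typing import List, Sequence


def firing_frequencies(feature_acts: Sequence[Sequence[float]]) -> List[float]:
    """Fraction of tokens on which each feature is active (> 0)."""
    n_tokens = len(feature_acts)
    counts = [0] * len(feature_acts[0])
    for token_acts in feature_acts:
        for i, a in enumerate(token_acts):
            if a > 0:
                counts[i] += 1
    return [c / n_tokens for c in counts]


def dead_feature_fraction(
    feature_acts: Sequence[Sequence[float]], min_freq: float = 0.0
) -> float:
    """Fraction of features whose firing frequency is at or below `min_freq`."""
    freqs = firing_frequencies(feature_acts)
    return sum(f <= min_freq for f in freqs) / len(freqs)
```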

Type-checking is not enforced

It looks like this library does make good use of type hints, which is excellent, but these hints are not enforced and thus could be incorrect. Ideally, these types can be checked as part of the linting process using a type checker. Some good choices are mypy or Pyright. I can set up a PR implementing either of these if it's helpful.

Add toy models architecture built on HookedRootModule

Per discussion in Slack, SAELens has stale code for toy models. It would be great to bring that code back to life.

Toy models should be built on top of a class that itself is a subclass of HookedRootModule from transformer_lens, so that toy models can be run_with_cache and code for toy models and language models actually looks pretty similar.

For starters, the Anthropic ReLU output model is the only toy model we need working, but keeping generality for other types of models to be included in future PRs would be great!
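This note assumes "the Anthropic ReLU output model" refers to the model from Anthropic's Toy Models of Superposition. That model compresses n features into m < n hidden dimensions with a single weight matrix W. It reuses the transpose of W to decode, then applies a bias and a ReLU, and it is trained with an importance-weighted reconstruction loss:

```latex
\begin{aligned}
h &= W x, \qquad x \in \mathbb{R}^{n},\; W \in \mathbb{R}^{m \times n},\; m < n \\
x' &= \operatorname{ReLU}\!\left(W^{\top} h + b\right) = \operatorname{ReLU}\!\left(W^{\top} W x + b\right) \\
\mathcal{L} &= \mathbb{E}_{x}\Big[\sum_{i=1}^{n} I_i \,(x_i - x'_i)^2\Big]
\end{aligned}
```

In a HookedRootModule-based version, the hidden vector h would be the natural place for a hook point, so that an SAE can be trained on it. The name of that hook would be a new design choice.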

Note: I'm going to be working on this this week (week of Apr 15), and wanted to make sure that it was visible that this is being thought about and worked on.

Test on Python 3.12 in CI

It looks like flake8 gives slightly different output in 3.12 than in 3.11 and 3.10, so we should add 3.12 to the CI matrix for consistency.

SAEGroup config typing is incorrect

Currently, SAEGroup is configured using the same config object as individual SAEs, but the type hints are incorrect. An SAEGroup config allows a list for each field and creates one SAE config for every combination of the listed values, whereas a single SAE's config does not allow lists. This means that configuring an SAEGroup currently reports typing errors.

I can envision two possible solutions:

Solution 1: Create a dedicated LanguageModelSAEGroupRunnerConfig class

One solution is for the SAEGroup and SparseAutoencoder classes to each have their own dedicated config class, likely with similar fields but with SAEGroup allowing lists of items. This means being less DRY (don't repeat yourself) in the config code, but I feel it's justified in this case for the simplicity and clarity it gives the user. These two configs correspond to separate concepts in the code, so it makes sense to have separate classes. We can add a test case to ensure these classes have the same keys, so we don't accidentally add a key to one class but not the other.

Solution 2: Create SAEGroup from a list of SAEConfigs, with helpers for generating permutations

We could have SAEGroup take in a list of SAEConfig objects directly. This would solve the typing issue, since there would be no need for a separate SAEGroup config. It would also give the user more direct control over which specific combinations of params to run if they don't want every combination. We could also create helper functions that generate this list of SAEConfigs from combinations of values, to enable the current behavior too. Something like:

# Create based on combinations of params on a base config, to match current behavior
base_config = LanguageModelSAEGroupRunnerConfig(...)
sae_group = SAEGroup.from_combinations(base_config, {
    'l1_coefficient': [1e-3, 1e-4, 1e-5],
    'd_sae': [1000, 5000],
})  # 3 x 2 = 6 SAEs

# Create directly from a list of SAEConfigs
sae_configs = [
    LanguageModelSAEGroupRunnerConfig(...),
    LanguageModelSAEGroupRunnerConfig(...),
    LanguageModelSAEGroupRunnerConfig(...),
]
sae_group = SAEGroup(sae_configs)

I like solution 2 best, as it more directly addresses what we want a SAEGroup to be: simply a collection of SAEs with different params. It allows more flexibility than is currently possible, because you can pick a few specific combinations you want to try instead of every possible combination of options, while still allowing all combinations if you'd like. It also avoids the DRY issue of creating a separate, nearly identical config class. Finally, the SAEGroup.from_combinations() method makes it obvious which param combinations are being tried, since it lists them explicitly instead of requiring the reader to skim the whole config looking for items in lists.
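The from_combinations helper proposed in Solution 2 is essentially a Cartesian product over the listed values. The sketch below implements it with itertools.product and dataclasses.replace, assuming a frozen dataclass config. SAEConfig and its fields are hypothetical stand-ins for the real config. A real SAEGroup.from_combinations classmethod could simply wrap configs_from_combinations.

```python
import itertools
from dataclasses import dataclass, replace
from typing import Any, List, Mapping, Sequence


@dataclass(frozen=True)
class SAEConfig:
    """Hypothetical, trimmed-down SAE config."""

    d_in: int = 768
    d_sae: int = 6144
    l1_coefficient: float = 1e-3
    lr: float = 4e-4


def configs_from_combinations(
    base: SAEConfig, grid: Mapping[str, Sequence[Any]]
) -> List[SAEConfig]:
    """Return one config per element of the Cartesian product of `grid` values."""
    keys = list(grid)
    return [
        replace(base, **dict(zip(keys, values)))
        for values in itertools.product(*(grid[k] for k in keys))
    ]


configs = configs_from_combinations(
    SAEConfig(),
    {"l1_coefficient": [1e-3, 1e-4, 1e-5], "d_sae": [1000, 5000]},
)
print(len(configs))  # 6
```

A useful side effect of dataclasses.replace is error handling. A misspelled key in the grid raises a TypeError instead of being silently ignored.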

Colons in filepaths on Windows

Windows cannot deal with colons in filepaths, so I can't clone the repo:

error: invalid path 'artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144:v2/1200001024_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_6144.pt'
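The :v2 suffix in that path looks like a W&B artifact version. One way to keep such names off disk in future is to sanitize path components before writing them. Below is a sketch with a hypothetical windows_safe_path helper. It covers the characters Windows forbids in file names, but not reserved names such as CON or trailing dots and spaces. For the repository itself, the existing directory would still need to be renamed (or removed from version control) before Windows users can clone it.

```python
import re

# Characters that Windows does not allow in file or directory names,
# plus ASCII control characters.
_WINDOWS_INVALID_CHARS = re.compile(r'[<>:"|?*\x00-\x1f]')


def windows_safe_path(path: str, replacement: str = "_") -> str:
    """Replace Windows-invalid characters in each '/'-separated component of `path`."""
    return "/".join(
        _WINDOWS_INVALID_CHARS.sub(replacement, part) for part in path.split("/")
    )
```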

Proposal: Add helper function to apply a SAE or SAEDict to a model

One of the most common use cases for SAELens is to apply a pre-trained SAE listed on Neuronpedia to a model in order to analyze activations. Currently we don't provide any helper for this. The user has to extract the hook point from the SAE, run the model, collect activations, and apply the SAE to the collected activations. Since this is likely a very common use case, we should make it super easy to do.

One possibility would be to have a context manager on SparseAutoencoder and SparseAutoencoderDict, enabling usage like:

with sae.track_activations(model) as sae_cache:
    model(...)

inputs = sae_cache.inputs
feature_acts = sae_cache.feature_activations
outputs = sae_cache.outputs

with sae_dict.track_activations(model) as sae_dict_cache:
    model(...)

# dict outputs are the same as above, except in a dict for each SAE
sae1_inputs = sae_dict_cache['sae1_key'].inputs
...
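To make the proposed control flow concrete, here is a stdlib-only sketch of the context-manager option. Everything in it is hypothetical. ToyHookedModel stands in for a model with named hook points, and ToySAE stands in for an SAE. track_activations is written as a free function rather than the proposed sae.track_activations method. In TransformerLens, the registration step would go through its own hook machinery rather than a plain dict. The important design point is the try/finally, which guarantees the hook is removed on exit so later forward passes are unaffected.

```python
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterator, List

Vector = List[float]


class ToyHookedModel:
    """Hypothetical stand-in for a model with named hook points."""

    def __init__(self) -> None:
        self.hook_fns: Dict[str, List[Callable[[Vector], None]]] = {}

    def __call__(self, x: Vector) -> Vector:
        resid = [2.0 * v for v in x]  # pretend this is a transformer block
        for fn in self.hook_fns.get("blocks.0.hook_resid_pre", []):
            fn(resid)
        return resid


@dataclass
class ToySAE:
    """Hypothetical SAE: ReLU encoder, identity decoder."""

    hook_point: str

    def encode(self, x: Vector) -> Vector:
        return [max(0.0, v) for v in x]

    def decode(self, f: Vector) -> Vector:
        return list(f)


@dataclass
class SAECache:
    inputs: List[Vector] = field(default_factory=list)
    feature_activations: List[Vector] = field(default_factory=list)
    outputs: List[Vector] = field(default_factory=list)


@contextmanager
def track_activations(sae: ToySAE, model: ToyHookedModel) -> Iterator[SAECache]:
    """Record SAE inputs, features and outputs at sae.hook_point while active."""
    cache = SAECache()

    def hook(acts: Vector) -> None:
        feats = sae.encode(acts)
        cache.inputs.append(list(acts))
        cache.feature_activations.append(feats)
        cache.outputs.append(sae.decode(feats))

    model.hook_fns.setdefault(sae.hook_point, []).append(hook)
    try:
        yield cache
    finally:
        # Always detach the hook, even if the forward pass raised.
        model.hook_fns[sae.hook_point].remove(hook)
```

Because the hook only observes activations, the model's output inside the with block is the same as without it.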

Alternatively, we could add an apply_and_run_with_cache() method which both attaches the SAE and calls run_with_cache() on the underlying model, returning both the model cache and the SAE cache:

model_cache, sae_cache = sae.apply_and_run_with_cache(model, ...)

# likewise, for SAEDict, return a dict of SAE caches
model_cache, sae_dict_cache = sae_dict.apply_and_run_with_cache(model, ...)

Some other things to think about:

  • Do we want a separate method to inject the SAE into the model computation vs just tracking activations on the side?
  • If we have a method to inject the SAE into model computation, should we have a way of adding in the SAE error term so we can attribute gradients to the SAE vs error term while ensuring the model computation is unchanged?
  • Should we just always inject the SAE plus the error term into the model computation by default, so the SAE gets gradients and tracks features without modifying the model output?
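The error-term idea in these questions fits in a few lines: substitute the activation x with sae(x) + (x - sae(x)). The sketch below is a plain-Python illustration. splice_sae_with_error is a hypothetical helper, and a toy ReLU stands in for the SAE. In an autograd framework you would still have to decide whether the error term is its own node in the graph (so it receives its own gradient) or is detached, which is exactly the open question above.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def splice_sae_with_error(
    x: Vector, sae_forward: Callable[[Vector], Vector]
) -> Tuple[Vector, Vector, Vector]:
    """Replace x by sae(x) + error, where error = x - sae(x).

    The spliced value equals x (up to floating-point rounding), so the model's
    output is unchanged, while the SAE reconstruction and the error term are
    separate quantities that gradients or attributions can be assigned to.
    """
    x_hat = sae_forward(x)
    error = [xi - xh for xi, xh in zip(x, x_hat)]
    spliced = [xh + e for xh, e in zip(x_hat, error)]
    return spliced, x_hat, error
```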

[Proposal] Implement gated SAEs

Proposal

Implement gated SAEs as described here:
https://arxiv.org/abs/2404.16014

Motivation

Gated SAEs are reported to be a Pareto improvement over regular SAEs: better reconstruction fidelity at a given level of sparsity.

Pitch

Implement a new GatedSparseAutoencoder class, or whatever way is best
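For reference, here is a summary of the gated encoder as I read the linked paper; check it against the paper before implementing. A gate path decides which features are active, and a magnitude path estimates their values. The two paths share weights up to a learned per-feature rescaling r_mag. Compared with a standard SAE, this adds the parameters r_mag and b_mag, with W_gate and b_gate playing the role of the usual encoder weights and bias:

```latex
\begin{aligned}
\pi_{\text{gate}}(x) &= W_{\text{gate}}\,(x - b_{\text{dec}}) + b_{\text{gate}} \\
(W_{\text{mag}})_{ij} &= \exp\!\big((r_{\text{mag}})_i\big)\,(W_{\text{gate}})_{ij} \\
\pi_{\text{mag}}(x) &= W_{\text{mag}}\,(x - b_{\text{dec}}) + b_{\text{mag}} \\
\tilde{f}(x) &= \mathbf{1}\!\left[\pi_{\text{gate}}(x) > 0\right] \odot \operatorname{ReLU}\!\big(\pi_{\text{mag}}(x)\big) \\
\hat{x}(x) &= W_{\text{dec}}\,\tilde{f}(x) + b_{\text{dec}}
\end{aligned}
```

The indicator function has zero gradient almost everywhere, so the paper trains with a modified loss that includes an auxiliary reconstruction term. That loss is not shown here.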

Checklist

  • I have checked that there is no similar issue in the repo (required)

[Bug Report] Unsupported scheduler error when training SAE

Describe the bug
The example SAE training code fails in sae-lens==0.7.0 with "ValueError: Unsupported scheduler: constantwithwarmup".

Code example

import os

import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_lens import LanguageModelSAERunnerConfig, language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point="blocks.2.hook_resid_pre",
    hook_point_layer=2,
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    # SAE Parameters
    expansion_factor=64,
    b_dec_init_method="geometric_median",
    # Training Parameters
    lr=0.0004,
    l1_coefficient=0.00008,
    lr_scheduler_name="constantwithwarmup",
    train_batch_size=4096,
    context_size=128,
    lr_warm_up_steps=5000,
    # Activation Store Parameters
    n_batches_in_buffer=128,
    training_tokens=1_000_000 * 300,
    store_batch_size=32,
    # Dead Neurons and Sparsity
    use_ghost_grads=True,
    feature_sampling_window=1000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-6,
    # WANDB
    log_to_wandb=True,
    wandb_project="gpt2",
    wandb_entity=None,
    wandb_log_frequency=100,
    # Misc
    device="cuda",
    seed=42,
    n_checkpoints=10,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)

Error message:

Traceback (most recent call last):
  File "/home/daniel/ml_workspace/SAELens/scripts/train_gated_sae.py", line 51, in <module>
    sparse_autoencoder = language_model_sae_runner(cfg)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/lm_runner.py", line 34, in language_model_sae_runner
    sparse_autoencoder = train_sae_on_language_model(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/train_sae_on_language_model.py", line 92, in train_sae_on_language_model
    return train_sae_group_on_language_model(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/train_sae_on_language_model.py", line 132, in train_sae_group_on_language_model
    train_contexts = {
                     ^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/train_sae_on_language_model.py", line 133, in <dictcomp>
    name: _build_train_context(sae, total_training_steps)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/train_sae_on_language_model.py", line 302, in _build_train_context
    scheduler = get_scheduler(
                ^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/optim.py", line 38, in get_scheduler
    main_scheduler = _get_main_scheduler(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/daniel/ml_workspace/SAELens/sae_lens/training/optim.py", line 98, in _get_main_scheduler
    raise ValueError(f"Unsupported scheduler: {scheduler_name}")
ValueError: Unsupported scheduler: constantwithwarmup
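For context, "constant with warmup" usually means the learning rate ramps up linearly over lr_warm_up_steps and then stays at its base value. The traceback shows get_scheduler building a separate "main" scheduler via _get_main_scheduler. This suggests warmup may be handled apart from the scheduler name. If so, a supported main-scheduler name (perhaps "constant") combined with lr_warm_up_steps may be what is intended; I haven't verified which names are accepted. Below is a stdlib sketch of the multiplier itself. constant_with_warmup is a hypothetical helper, not the SAELens implementation.

```python
def constant_with_warmup(step: int, warmup_steps: int) -> float:
    """Learning-rate multiplier: linear ramp over `warmup_steps`, then 1.0.

    Step 0 uses 1/warmup_steps rather than 0 so the first update is not wasted.
    """
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, (step + 1) / warmup_steps)
```

In PyTorch this could be plugged in as torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: constant_with_warmup(step, 5000)).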

System Info

  • Installed from source via poetry install
  • Ubuntu 22.04
  • Python 3.11.9

Checklist

  • I have checked that there is no similar issue in the repo (required)

[Proposal] add the loading SAE from Huggingface utils to TransformerLens

Proposal, Motivation, Pitch

The great utils here (download_sae_from_hf) would really help TransformerLensOrg/TransformerLens#584 work for arbitrary SAEs on HuggingFace (as suggested by @danbraunai).

@jbloomAus, how would you feel about me adding (a version of) that util to TransformerLens and then making the SAELens code use it instead?

Alternatives

Copy and paste the code into TransformerLens.

Additional context

N/A

Checklist

  • I have checked that there is no similar issue in the repo (required)

[Proposal] Current session loading process can be very confusing and should be simpler

Not a real and complete proposal, mostly a note from a chat with @jbloomAus.

Proposal

There should be a simple and documented way to load a (model, sae, dataset).

Ideally, there should be a way to load (model, sae) or model then sae, and a way to load a dataset that is more straightforward (or just documented) than ActivationsStore (which I wouldn't have guessed was the main dataloader).

Motivation

The version from the docs is broken, and somewhat confusing (e.g. it's hard to know how to use another dataset with that API).

Checklist

  • [ x ] I have checked that there is no similar issue in the repo

Clarify batch size units

Currently 'batch size' refers to two distinct units in different places: tokens and sequences. The referenced comment suggests a naming scheme that would distinguish this.

          @tomMcGrath We previously had "train_batch_size", which is a token batch size for the SAE, and "store_batch_size", which was a prompt batch size for generating activations. The latter was then reused during evaluations, but we've now separated them out. I think the following names might be best:
  • train_batch_size_tokens
  • store_batch_size_prompts
  • eval_batch_size_prompts
    I generally like longer names if they are sufficiently clear, but I'll admit these names are long and this could be annoying.

Originally posted by @jbloomAus in #128 (comment)
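To make the unit distinction concrete: a prompt-denominated batch converts to tokens by multiplying by context_size. The sketch below uses the proposed names inside a hypothetical BatchSizes class. The values mirror the training configs on this page (store_batch_size=32, context_size=128, train_batch_size=4096), and the eval value is arbitrary. With these values, one store batch of 32 prompts × 128 tokens yields exactly one SAE training batch of 4096 tokens.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BatchSizes:
    """Hypothetical container using the proposed, unit-explicit names."""

    train_batch_size_tokens: int   # activation vectors per SAE optimizer step
    store_batch_size_prompts: int  # prompts per model forward pass
    eval_batch_size_prompts: int   # prompts per evaluation forward pass
    context_size: int              # tokens per prompt

    @property
    def store_batch_size_tokens(self) -> int:
        # Each prompt yields one activation vector per token at the hook point.
        return self.store_batch_size_prompts * self.context_size

    @property
    def sae_steps_per_store_batch(self) -> float:
        return self.store_batch_size_tokens / self.train_batch_size_tokens


sizes = BatchSizes(
    train_batch_size_tokens=4096,
    store_batch_size_prompts=32,
    eval_batch_size_prompts=16,  # arbitrary illustrative value
    context_size=128,
)
print(sizes.store_batch_size_tokens, sizes.sae_steps_per_store_batch)  # 4096 1.0
```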

[Proposal] Move to version 1.0.0 to fix projects that depend on SAELens using caret ^ with semver

Proposal

Move the version number to 1.0.0. This does not propose any code changes.

Motivation

sae_vis depends on saelens and uses caret notation ^0.5.0 to pick up SAELens. However, SAELens is now 0.6.0 and it is not picking up 0.6.0. This is because the caret works differently for projects that are lower than v1.0.0 - the caret will allow only 0.5.0 to 0.5.x, not anything 0.6.0 or above.
This means that sae_vis and other projects that depend on SAELens will not pick up 0.6.0 and later unless they specifically update their package dependency version to ^0.6.0.
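The caret rule described above is mechanical. ^X.Y.Z allows versions from X.Y.Z up to, but excluding, the next bump of the left-most non-zero component. The stdlib sketch below implements that rule; the function names are mine. It handles plain MAJOR.MINOR.PATCH versions only, without pre-release tags or the shorter ^0.0 / ^1 forms.

```python
from typing import Tuple

Version = Tuple[int, int, int]


def parse_version(v: str) -> Version:
    major, minor, patch = (int(part) for part in v.split("."))
    return major, minor, patch


def caret_upper_bound(spec: str) -> Version:
    """Exclusive upper bound of ^spec: bump the left-most non-zero component."""
    major, minor, patch = parse_version(spec)
    if major > 0:
        return (major + 1, 0, 0)
    if minor > 0:
        return (0, minor + 1, 0)
    return (0, 0, patch + 1)


def caret_allows(spec: str, version: str) -> bool:
    """True if `version` satisfies the caret requirement ^spec."""
    return parse_version(spec) <= parse_version(version) < caret_upper_bound(spec)


print(caret_allows("0.5.0", "0.5.3"), caret_allows("0.5.0", "0.6.0"))  # True False
print(caret_allows("1.0.0", "1.6.0"))  # True
```

This is why ^0.5.0 rejects 0.6.0, while a ^1.0.0 requirement would keep accepting new minor releases after a move to 1.0.0.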

Pitch

Move the version number to 1.0.0, and use the convention of major.minor.patch where MAJOR = backward incompatible changes, MINOR = backward compatible feature updates, PATCH = backward compatible fixes.

Alternatives

An alternative is to stick with version <1.0.0, but make sure that:

  • Only bump from v0.x.y to v0.(x+1).0 when the change is actually backward incompatible; otherwise bump to v0.x.(y+1) instead.
  • Document what the backward-incompatible change is whenever v0.x.y is bumped to v0.(x+1).0.

Additional context

Additional context here (it's on a Node.js blog, but semver works the same way), which addresses the debate in general.

The main reason to not move to 1.0.0 is that the maintainer may not think it's "finished" or "ready enough" in a way that others might think a 1.0.0 might require. This is pretty subjective, I leave that up to the maintainers to decide this. I think at the very least, we should do the alternative mentioned above.

Checklist

  • I have checked that there is no similar issue in the repo (required)

`DashboardRunner` errors

Describe the bug

Attempt 1 (https://github.com/joelburget/SAELens/blob/1f59808ca9fd76e0e946ea6e631c65a4f8482240/eval_polar.py):

runner = DashboardRunner(sae_path=sae_path)
runner.run()

Result:

Forward passes to cache data for vis: 100%|█████████████████████████████████████████████████████████████████████████████████| 24/24 [03:04<00:00,  7.71s/it]
Extracting vis data from cached data: 100%|█████████████████████████████████████████████████████████████████████████████| 1024/1024 [03:04<00:00,  5.54it/s]
  0%|                                                                                                                                 | 0/8 [03:05<?, ?it/s]
Traceback (most recent call last):
  File "/workspace/SAELens/eval_polar.py", line 16, in <module>
    runner.run()
  File "/workspace/SAELens/sae_lens/analysis/dashboard_runner.py", line 376, in run
    feature_data.save_feature_centric_vis(
  File "/workspace/SAELens/.venv/lib/python3.10/site-packages/sae_vis/data_storing_fns.py", line 1073, in save_feature_centric_vis
    assert self.model is not None
AssertionError

This hack (joelburget@08bddf0) seems to fix the immediate issue.

Attempt 2 (with hack) result:

Forward passes to cache data for vis: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [03:10<00:00,  7.92s/it]
Extracting vis data from cached data: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [03:10<00:00,  5.39it/s]
  0%|                                                                                                                                                                                 | 0/8 [03:10<?, ?it/s]
Traceback (most recent call last):
  File "/workspace/SAELens/eval_polar.py", line 15, in <module>
    runner.run()
  File "/workspace/SAELens/sae_lens/analysis/dashboard_runner.py", line 376, in run
    feature_data.save_feature_centric_vis(
  File "/workspace/SAELens/.venv/lib/python3.10/site-packages/sae_vis/data_storing_fns.py", line 1085, in save_feature_centric_vis
    html_obj = feature_data._get_html_data_feature_centric(
  File "/workspace/SAELens/.venv/lib/python3.10/site-packages/sae_vis/data_storing_fns.py", line 860, in _get_html_data_feature_centric
    html_obj += component._get_html_data(
  File "/workspace/SAELens/.venv/lib/python3.10/site-packages/sae_vis/data_storing_fns.py", line 259, in _get_html_data
    assert cfg.n_rows <= len(self.bottom_logits)
AssertionError

System Info
I'm running within that SAELens repo, dependencies installed via poetry.

(sae-lens-py3.10) root@e8e8a5a21111:/workspace/SAELens# python3 --version
Python 3.10.12
(sae-lens-py3.10) root@e8e8a5a21111:/workspace/SAELens# uname -a
Linux e8e8a5a21111 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux

Additional context
It would be helpful to see an example of how to use DashboardRunner. It would also be useful to see an example of how to use the wandb integration (joelburget@1f59808).

Checklist

  • I have checked that there is no similar issue in the repo (seems unrelated to #72)

[Bug Report] scaling_factor broke sae_vis compatibility

Describe the bug
I'm using this notebook on an SAE I created: basic_loading_and_analysing.ipynb. I get an error that appears to be because the scaling_factor was added to the SparseAutoencoder class, which sae_vis is not expecting.

Code example

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[21], line 17
     14 print(type(sparse_autoencoder))
     15 print(sparse_autoencoder.state_dict().keys())
---> 17 sae_vis_data_gpt = SaeVisData.create(
     18     encoder=sparse_autoencoder,
     19     model=model,
     20     tokens=all_tokens,  # type: ignore
     21     cfg=feature_vis_config_gpt,
     22 )

File /opt/conda/lib/python3.10/site-packages/sae_vis/data_storing_fns.py:1017, in SaeVisData.create(cls, encoder, model, tokens, cfg, encoder_B)
   1014 # If encoder isn't an AutoEncoder, we wrap it in one
   1015 if not isinstance(encoder, AutoEncoder):
   1016     assert (
-> 1017         set(encoder.state_dict().keys()) == {"W_enc", "W_dec", "b_enc", "b_dec"}
   1018     ), "If encoder isn't an AutoEncoder, it should have weights 'W_enc', 'W_dec', 'b_enc', 'b_dec'"
   1019     d_in, d_hidden = encoder.W_enc.shape
   1020     device = encoder.W_enc.device

AssertionError: If encoder isn't an AutoEncoder, it should have weights 'W_enc', 'W_dec', 'b_enc', 'b_dec'

When running this cell:

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

test_feature_idx_gpt = list(range(10)) + [14057]

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=test_feature_idx_gpt,
    batch_size=2048,
    minibatch_size_tokens=128,
    verbose=True,
)

sae_vis_data_gpt = SaeVisData.create(
    encoder=sparse_autoencoder,
    model=model,
    tokens=all_tokens,  # type: ignore
    cfg=feature_vis_config_gpt,
)

This print statement shows a scaling_factor key that isn't in the above assert:

print(type(sparse_autoencoder))
print(sparse_autoencoder.state_dict().keys())

<class 'sae_lens.training.sparse_autoencoder.SparseAutoencoder'>
odict_keys(['W_enc', 'b_enc', 'W_dec', 'b_dec', 'scaling_factor'])

System Info

  • Installation method: pip
  • OS: Ubuntu
  • Python version: 3.10

Checklist

  • I have checked that there is no similar issue in the repo (required)

Use flake8 `extend-select` instead of `select`

It looks like flake8 linting is configured using select = E9, F63, F7, F82, but this overrides all the default flake8 rules and runs ONLY those 4 rules. Most likely, the intention was to use extend-select = E9, F63, F7, F82, so that all the default flake8 rules still apply, with those 4 rules added. More info on extend-select vs select: https://flake8.pycqa.org/en/latest/user/options.html#cmdoption-flake8-extend-select

It also looks like the GitHub Action for flake8 sets options manually as well, and runs flake8 twice. Why is this the case, rather than simply running flake8 . in the action using the settings specified in .flake8?

FIX: neuronpedia_runner and tutorial notebook

It's broken after some recent refactoring.

  • correlated_neurons_l1 and frac_nonzero seem to be missing or relocated
  • should reduce redundancy in outputs: logits bg values, freq_bar values

generating_sae_dashboards.ipynb is not working

Two issues with generating_sae_dashboards.ipynb causing it to fail:

  1. eindex is missing.
    Potential fix (my current local workaround): add eindex back.
    I'm not creating a pull request because the correct fix may be for eindex to be a dependency of sae_vis rather than SAELens.

  2. Error running "Use runner" cell with no modifications: ModuleNotFoundError: No module named 'sae_training'

Full error:

{
	"name": "ModuleNotFoundError",
	"message": "No module named 'sae_training'",
	"stack": "---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[2], line 6
      3 FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt\"
      4 path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
----> 6 obj = torch.load(path, map_location=device)
      7 state_dict = obj[\"state_dict\"]
      8 assert set(state_dict.keys()) == {\"W_enc\", \"b_enc\", \"W_dec\", \"b_dec\"}

File ~/Documents/Projects/SAELens/.venv/lib/python3.12/site-packages/torch/serialization.py:1026, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
   1024             except RuntimeError as e:
   1025                 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-> 1026         return _load(opened_zipfile,
   1027                      map_location,
   1028                      pickle_module,
   1029                      overall_storage=overall_storage,
   1030                      **pickle_load_args)
   1031 if mmap:
   1032     raise RuntimeError(\"mmap can only be used with files saved with \"
   1033                        \"`torch.save(_use_new_zipfile_serialization=True), \"
   1034                        \"please torch.save your checkpoint with this option in order to use mmap.\")

File ~/Documents/Projects/SAELens/.venv/lib/python3.12/site-packages/torch/serialization.py:1438, in _load(zip_file, map_location, pickle_module, pickle_file, overall_storage, **pickle_load_args)
   1436 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1437 unpickler.persistent_load = persistent_load
-> 1438 result = unpickler.load()
   1440 torch._utils._validate_loaded_sparse_tensors()
   1441 torch._C._log_api_usage_metadata(
   1442     \"torch.load.metadata\", {\"serialization_id\": zip_file.serialization_id()}
   1443 )

File ~/Documents/Projects/SAELens/.venv/lib/python3.12/site-packages/torch/serialization.py:1431, in _load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
   1429         pass
   1430 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1431 return super().find_class(mod_name, name)

ModuleNotFoundError: No module named 'sae_training'"
}
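The traceback shows torch.load's unpickler failing to import sae_training. This suggests the checkpoint pickles objects that referenced the pre-refactor module path. A general workaround is an Unpickler subclass that rewrites old module names before the import happens. The stdlib sketch below shows the technique with fake old_pkg/new_pkg modules; RenamingUnpickler is a hypothetical name. torch.load takes a pickle_module argument (visible in the traceback), so a module exposing a class like this as Unpickler might be usable there. I haven't tested that, and which module sae_training should now map to is an assumption you'd need to check. As always with pickle, only do this for files you trust.

```python
import io
import pickle
import sys
import types
from typing import Dict


class RenamingUnpickler(pickle.Unpickler):
    """Unpickler that rewrites module paths of classes that have since moved."""

    def __init__(self, file, module_renames: Dict[str, str]):
        super().__init__(file)
        self.module_renames = module_renames

    def find_class(self, module: str, name: str):
        for old, new in self.module_renames.items():
            # Match the package itself and any of its submodules.
            if module == old or module.startswith(old + "."):
                module = new + module[len(old):]
                break
        return super().find_class(module, name)


def demo_rename_roundtrip() -> int:
    """Pickle an object from a fake "old_pkg", then load it after a rename."""

    class Config:
        pass

    Config.__module__ = "old_pkg"
    Config.__qualname__ = "Config"
    old_pkg = types.ModuleType("old_pkg")
    old_pkg.Config = Config
    sys.modules["old_pkg"] = old_pkg

    cfg = Config()
    cfg.d_in = 768
    data = pickle.dumps(cfg)

    # Simulate the refactor: "old_pkg" disappears, the class now lives in "new_pkg".
    del sys.modules["old_pkg"]
    new_pkg = types.ModuleType("new_pkg")
    new_pkg.Config = Config
    sys.modules["new_pkg"] = new_pkg

    restored = RenamingUnpickler(io.BytesIO(data), {"old_pkg": "new_pkg"}).load()
    return restored.d_in
```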

[Bug Report] make check-ci fails one test on CUDA >= 10.2

Describe the bug
I ran make check-ci on an AWS EC2 instance and one test fails:

make check-ci

>           return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
E           RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

../.cache/pypoetry/virtualenvs/sae-lens-ODzxvHul-py3.10/lib/python3.10/site-packages/torch/functional.py:385: RuntimeError

FAILED tests/unit/training/test_cache_activations_runner.py::test_cache_activations_runner - RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithm...
================================================ 1 failed, 94 passed, 4 skipped, 19 warnings in 46.62s ================================================

I followed its suggestion, and running
CUBLAS_WORKSPACE_CONFIG=:4096:8 make check-ci
makes all tests pass.

Not sure if there is something to fix here, but Joseph asked me to make sure make check-ci passes on the new workflow I'm building.
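One option would let make check-ci pass on GPU machines without remembering the prefix. The idea is to set the variable from the test suite itself before any CUDA work happens, for example at the top of a pytest conftest.py. This is a sketch that has not been verified against this repo's test setup. os.environ.setdefault leaves any value the user has already exported untouched.

```python
import os

# Must run before PyTorch initializes cuBLAS (e.g. at the top of a pytest
# conftest.py); deterministic cuBLAS needs one of ":4096:8" or ":16:8".
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
```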

System Info
Ubuntu on AWS EC2
AMI: Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.2.0 (Ubuntu 20.04) 20240507
AMI ID: ami-02a07d31009cc8717

Checklist

  • I have checked that there is no similar issue in the repo (required)

Move to safetensors to save autoencoders and activations

Currently saving and loading uses torch.save(), but this allows malicious code to be executed while loading. Huggingface developed a format called safetensors (https://huggingface.co/docs/safetensors/en/index) which cannot be used to load executable code, and is also much faster than torch.save/torch.load for loading tensors (supposedly 76x faster). I think it also allows reading the size of the stored tensors without loading the entire file, which would be a big win when initializing ActivationsStore.
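The header claim is easy to check. In the safetensors format, a file starts with an 8-byte little-endian unsigned integer giving the header length. That integer is followed by a JSON header mapping each tensor name to its dtype, shape and byte offsets, plus an optional __metadata__ entry. The stdlib sketch below reads just the shapes, so the number of stored activations could be computed without loading any tensor data. read_safetensors_shapes is a hypothetical helper, not part of the safetensors package.

```python
import json
import struct
from typing import BinaryIO, Dict, List


def read_safetensors_shapes(f: BinaryIO) -> Dict[str, List[int]]:
    """Return {tensor_name: shape} by reading only the safetensors header."""
    (header_len,) = struct.unpack("<Q", f.read(8))
    header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional string-to-string metadata
    return {name: info["shape"] for name, info in header.items()}
```

Usage on a real file would look like: with open(path, "rb") as f: shapes = read_safetensors_shapes(f).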

Question about `.float()` calls in codebase

The codebase has several places where `.float()` is called on tensors, for instance here and here. Are these calls necessary? They make the tensor float32 even if the rest of the calculation uses a different dtype, such as float16.

Support Attn Out SAEs

We should be able to fairly trivially support the SAEs described here. Validating this seems important.

KeyError in `geometric_medians` in training

Running the TinyStories training script gives the following error:

"mats_sae_training/sae_training/train_sae_on_language_model.py", line 90, in train_sae_on_language_model
    geometric_medians[sae_layer_id].append(median)
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
KeyError: 0

For now, this can be worked around by setting `b_dec_init_method="mean"`.

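code:

    import torch

    from sae_training.config import LanguageModelSAERunnerConfig
    from sae_training.lm_runner import language_model_sae_runner

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    cfg = LanguageModelSAERunnerConfig(
        # Data Generating Function (Model + Training Distribution)
        model_name="tiny-stories-2L-33M",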
        hook_point="blocks.1.mlp.hook_post",
        hook_point_layer=1,
        d_in=4096,
        dataset_path="roneneldan/TinyStories",
        is_dataset_tokenized=False,
        # SAE Parameters
        expansion_factor=4,
        # Training Parameters
        lr=1e-4,
        l1_coefficient=3e-4,
        train_batch_size=4096,
        context_size=128,
        # Activation Store Parameters
        n_batches_in_buffer=128,
        total_training_tokens=1_000_000 * 10,  # want 500M eventually.
        store_batch_size=32,
        # Resampling protocol
        feature_sampling_window=2500,  # Doesn't currently matter.
        dead_feature_window=1250,
        dead_feature_threshold=0.0005,
        # Misc
        device=device,
        seed=42,
        n_checkpoints=0,
        checkpoint_path="checkpoints",
        dtype=torch.float32,
        # Wandb
        log_to_wandb=True,
        wandb_project="mats_sae_training_language_benchmark_tests",
        wandb_entity=None,
        wandb_log_frequency=10,
    )

    sparse_autoencoder = language_model_sae_runner(cfg)
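The traceback suggests that `geometric_medians` is a plain dict with no entry for layer 0 at the point where `.append` is called. Below is a minimal reproduction of that pattern and one way to avoid it with `collections.defaultdict`. It assumes the training script builds the dict this way; the function names and inputs are hypothetical, not the script's actual code.

```python
from collections import defaultdict


def collect_medians_plain(pairs):
    # Mirrors the failing pattern: indexing a missing key before appending.
    geometric_medians = {}
    for sae_layer_id, median in pairs:
        geometric_medians[sae_layer_id].append(median)  # KeyError on first use
    return geometric_medians


def collect_medians_fixed(pairs):
    # defaultdict(list) creates an empty list the first time a layer id is seen.
    geometric_medians = defaultdict(list)
    for sae_layer_id, median in pairs:
        geometric_medians[sae_layer_id].append(median)
    return dict(geometric_medians)
```

Explicitly pre-populating the dict with one empty list per layer id would work equally well.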

Error loading pre-trained SAEs from Huggingface

In the README, the specified way to load the pre-trained SAE is as follows:

import torch
from sae_training.utils import LMSparseAutoencoderSessionloader
from huggingface_hub import hf_hub_download

layer = 8 # pick a layer you want.
REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path = path
)
sparse_autoencoder.eval()

However, running this code verbatim results in the following error:

AttributeError                            Traceback (most recent call last)
<ipython-input-4-96a01d6d9844> in <cell line: 9>()
      7 FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
      8 path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
----> 9 model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
     10     path = path
     11 )

/content/mats_sae_training/sae_training/utils.py in load_session_from_pretrained(cls, path)
     49 
     50         sparse_autoencoders = SAEGroup.load_from_pretrained(path)
---> 51         model, _, activations_loader = cls(sparse_autoencoders.cfg).load_session()
     52 
     53         return model, sparse_autoencoders, activations_loader

AttributeError: 'dict' object has no attribute 'cfg'

This is likely related to the change that made SAEGroup, rather than a single SAE, the core unit of the library.
