Language-Model-SAEs

This repo aims to provide a general codebase for conducting dictionary-learning-based mechanistic interpretability research on Language Models (LMs). It powers a configurable pipeline for training and evaluating GPT-2 dictionaries, and provides a set of tools (mainly a React-based webpage) for analyzing and visualizing the learned dictionaries.

The design of the pipeline (including the configuration and some training details) is highly inspired by the mats_sae_training project and relies heavily on the TransformerLens library. We thank the authors for their great work.

Getting Started with Mechanistic Interpretability and Dictionary Learning

If you are new to mechanistic interpretability and dictionary learning, we recommend starting with the following paper:

Furthermore, to dive deeper into the inner activations of LMs, it's recommended that you get familiar with the TransformerLens library.

Installation

Currently, the codebase uses pdm (an alternative to poetry) to manage dependencies. To install the required packages, just install pdm and run the following command:

pdm install

This will install all the required packages for the core codebase. Note that if you're in a conda environment, pdm will directly take the current environment as the virtual environment for the current project, and will remove all packages that are not listed in the pyproject.toml file. So make sure to create a new conda environment (or simply deactivate conda, in which case pdm will use virtualenv by default) before running the above command. A forked version of TransformerLens is also included in the dependencies to provide the necessary tools for analyzing features.
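For example, you can set up a fresh environment before installing (the environment name and Python version below are only illustrative):

conda create -n lm-saes python=3.10
conda activate lm-saes
pdm install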

If you want to use the visualization tools, you also need to install the required packages for the frontend, which uses bun for dependency management. Follow the instructions on the bun website to install it, and then run the following commands:

cd ui
bun install

It's worth noting that bun is not well-supported on Windows, so you may need to use WSL or another Linux-based solution to run the frontend, or consider using a different package manager such as pnpm or yarn.
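For instance, with pnpm (assuming pnpm is already installed), the equivalent install would be:

cd ui
pnpm install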

Launch an Experiment

We provide both a programmatic and a configuration-based way to launch an experiment. The configuration-based way is recommended for most users: you can find the configuration files in the examples/configuration directory and modify them to fit your needs. The programmatic way is better suited to advanced users who want to customize the training process; example scripts can be found in the examples/programmatic directory.

To start a basic training process, you can run the following command:

lm-saes train examples/configuration/train.toml

which will start the training process using the configuration file examples/configuration/train.toml.
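In practice, you will typically copy an example configuration and edit it before launching (my_train.toml below is just an illustrative name):

cp examples/configuration/train.toml my_train.toml
# edit my_train.toml to fit your setup, then:
lm-saes train my_train.toml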

To analyze a trained dictionary, you can run the following command:

lm-saes analyze examples/configuration/analyze.toml --sae <path_to_sae_model>

which will start the analysis process using the configuration file examples/configuration/analyze.toml. The analysis process requires a trained SAE model, which can be obtained from the training process. You may also need to launch a MongoDB server to store the analysis results; the MongoDB settings can be modified in the configuration file.
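If you don't already have a MongoDB server running, one quick way to start one locally is via Docker (the container name and data directory below are just examples):

docker run -d --name lm-saes-mongo -p 27017:27017 -v $(pwd)/mongo-data:/data/db mongo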

Generally, our configuration-based pipeline uses outer-level settings as defaults for inner-level settings. This makes it easy to build deeply nested configurations in which common sub-settings (such as device and dtype) are specified once and reused. More details will be provided in the future.
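As a rough illustration of that idea (the field names below are hypothetical, not the actual configuration schema), an outer-level setting can serve as the default for nested sub-configurations:

# hypothetical sketch only -- not the real train.toml schema
device = "cuda:0"
dtype = "bfloat16"

[sae]
# inherits device and dtype from the outer level unless overridden here
expansion_factor = 32

[lm]
# can still override the outer default if needed
dtype = "float16"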

Visualizing the Learned Dictionary

The analysis results will be saved using MongoDB, and you can use the provided visualization tools to visualize the learned dictionary. First, start the FastAPI server by running the following command:

uvicorn server.app:app --port 24577
# You may want to copy server/.env.example to server/.env, adjust the environment settings there, and run with them:
# uvicorn server.app:app --port 24577 --env-file server/.env

Then, copy the ui/.env.example file to ui/.env, modify VITE_BACKEND_URL to match your server settings (by default, it's http://localhost:24577), and start the frontend by running the following commands:

cd ui
bun dev --port 24576
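For reference, with the default ports above, the key setting in ui/.env is just the backend URL:

VITE_BACKEND_URL=http://localhost:24577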

That's it! You can now go to http://localhost:24576 to visualize the learned dictionary and its features.

Development

We highly welcome contributions to this project. If you have any questions or suggestions, feel free to open an issue or a pull request. We are looking forward to hearing from you!

TODO: Add development guidelines

Citation

Please cite this library as:

@misc{Ge2024OpenMossSAEs,
    title  = {OpenMoss Language Model Sparse Autoencoders},
    author = {Xuyang Ge and Fukang Zhu and Junxuan Wang and Wentao Shu and Lingjie Chen and Zhengfu He},
    url    = {https://github.com/OpenMOSS/Language-Model-SAEs},
    year   = {2024}
}


language-model-saes's Issues

[Proposal] Documentation coverage and static documentation site

It is much easier for people (who may be new to mechanistic interpretability) to get started with detailed tutorials and documentation. Currently, this project lacks documentation and comments in many modules. We should raise documentation coverage to ensure that every part of the library is explained in detail.

Furthermore, we should consider building a static documentation site with tools like MkDocs. This helps people get an overview of the usage of the library without actually downloading it.

[Proposal] Server app & frontend need optimization

I have noticed two problems in the current frontend + backend service.

  • [urgent] byte_decoder does not work with the Llama-3 8B tokenizer. What is a workaround?

AttributeError: 'PreTrainedTokenizerFast' object has no attribute 'byte_decoder'.

  • [low priority] Larger models have longer context windows. Will this make visualization messier?

Simple truncation might not work; it may fail to capture longer-range dependencies.

I propose to make some changes in the frontend. What about a preview (with shorter local context) that can be expanded to the full context? @dest1n1s

[Proposal] Support from_pretrained

We aim to build a really useful infrastructure for SAE research.

Maybe we need to open-source our SAEs with a Hugging Face-style interface. This may require some sort of cloud service for storage, or something like that? @dest1n1s

This has lower priority, since this feature is only useful once we have trained reasonably good SAEs on larger language models.

[Proposal] Accelerate Inference in TransformerLens

The main bottleneck of SAE training lies in activation generation, which can be annoying when we try to work with larger models.

Try to accelerate TL inference, especially the attention forward pass. What are some possible options? FlashAttention-2, vLLM, or something else?

Since we usually do not cache Q, K, and V, the attention forward pass can be replaced with faster alternatives.

  • Support FlashAttn-2 in TL

[Proposal] Add Automatic (Unit) Testing and CI Workflows

Automatic testing is fundamental to keeping a collaboratively developed project from accumulating bugs that break modules that previously worked. For a deep learning library, always running the whole training or analysis process end to end consumes a lot of time and computational resources, and minor bugs may not be triggered under a fixed training setting. Thus, it is necessary to test at different levels to ensure proper functioning as much as possible.

I propose adding the following four categories of testing:

  • Unit testing: Tests whether every innermost method works with mock data, e.g. a single forward pass through a minimal SAE, or a single generation of activations. Unit tests should cover almost all parts of the library, so every single test is required to run fast.
  • Integration testing: Tests whether low-level modules work with one another properly, e.g. getting feature activations directly from text input (which requires the transformer and the SAEs to work together), a single training pass, or loading pretrained SAEs from Hugging Face. These tests should cover the common usage of the library at a rather high level. They should also have an acceptable time cost (perhaps no more than a few seconds) and should not depend on GPUs if possible.
  • Acceptance testing: Tests whether modules meet performance targets (loss, memory allocated, time cost), e.g. whether a pretrained SAE gives a reasonable loss. Some of these tests may require GPUs to run. Failure of these tests may be acceptable in some situations.
  • Benchmarks: Measure the time usage of a complete process and of some bottleneck modules.

Continuous Integration (CI) with GitHub workflows should also be added to run these tests on every push/PR. PRs should not be merged unless all tests pass.
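As a rough sketch of how these levels could be separated locally (the directory layout and the gpu marker below are hypothetical, not the repository's current structure):

pdm run pytest tests/unit                  # fast unit tests, no GPU required
pdm run pytest tests/integration           # cross-module tests, still CPU-only
pdm run pytest tests/acceptance -m gpu     # performance checks, may require GPUs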

[Proposal] Optimize dataset loading and activation store

The current activation store implementation has some drawbacks. Maybe we need to add some new features for a streaming activation store and make some optimizations. Below I list some details.

  1. Text Dataset Collate Config
    We need to support SAE training on both pretraining and SFT data, unlike Anthropic's Scaling Monosemanticity, in which only pretraining data is used to train SAEs on a supervised fine-tuned model.

    IMO, pretraining data should be packed, while SFT data should be sorted by length and batched with post-padding. Activations in the residual stream of padding tokens should be ignored in SAE training. I believe this fits the real-world distribution better.

    We need to add a configuration option for this.

  • Support two types of activation generation

  2. Shuffle
    When training SAEs with data from multiple distributions, shuffling should be an option to add diversity of information within a batch. This can be implemented by filling the activation buffer from random sources.

  • Support buffer filling from multiple sources

[Proposal] Support DDP for activation generation and SAE training.

A natural approach to faster SAE training is data parallelism. Maybe we can simply use DDP to make 8 copies of the TL model to generate activations and synchronize the SAE gradients. This may help accelerate activation generation, which is the speed bottleneck for larger LMs.

This may not work for larger models, say 70B models. Maybe the ultimate solution is a producer-consumer design pattern. Let's leave this for later.

  • Support DDP

[Proposal] Publish on PyPI

We can publish this library on PyPI so that people can use this package simply by running pip install lm-saes! However, before that, we should first get this library well-tested and well-documented.

Besides, even if we publish this library, people may still need to clone this repository to use the visualization tools. Perhaps we can publish Docker images of the visualization backend & frontend.

[Proposal] Support early stop & partial loading model weights

  1. Early stop @dest1n1s

TransformerLens already supports stopping at a given layer. We can utilize this feature to remove unused computation. This can be applied in any case that uses run_with_cache.

  2. Partial loading of model weights @Frankstein73

Similar to early stop: what if we do not load the unused weights at all? This may save GPU memory when training SAEs on early and middle layers.

  • Early stop
  • Partial loading model weights

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.