
pyvene's Introduction



This is a beta release (public testing).

A Library for Understanding and Improving PyTorch Models via Interventions

Interventions on model-internal states are fundamental operations in many areas of AI, including model editing, steering, robustness, and interpretability. To facilitate such research, we introduce pyvene, an open-source Python library that supports customizable interventions on a range of different PyTorch modules. pyvene supports complex intervention schemes with an intuitive configuration format, and its interventions can be static or include trainable parameters.

Getting Started: see the main pyvene 101 tutorial.

Installation

Since pyvene is currently in beta, we recommend installing it by cloning the repository,

git clone git@github.com:stanfordnlp/pyvene.git

and then adding pyvene to your Python system path:

import sys
sys.path.append("<Your Path to Pyvene>")

import pyvene as pv

Alternatively, you can do

pip install git+https://github.com/stanfordnlp/pyvene.git

or

pip install pyvene

Wrap, Intervene, and Share

You can intervene on any HuggingFace model as follows:

import torch
import pyvene as pv
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf" # your HF model name.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)

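def zeroout_intervention_fn(b, s):
    # b: base representation at the intervened component; s (source) is unused
    b[:, 3] = 0.  # zero out the 4th token position (index 3)
    return b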

pv_model = pv.IntervenableModel({
    "component": "model.layers[15].mlp.output", # string access
    "intervention": zeroout_intervention_fn}, model=model)

# run the intervened forward pass
orig_outputs, intervened_outputs = pv_model(
    tokenizer("The capital of Spain is", return_tensors="pt").to('cuda'),
    output_original_output=True
)
print(intervened_outputs.logits - orig_outputs.logits)

which returns,

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.4375,  1.0625,  0.3750,  ..., -0.1562,  0.4844,  0.2969],
         [ 0.0938,  0.1250,  0.1875,  ...,  0.2031,  0.0625,  0.2188],
         [ 0.0000, -0.0625, -0.0312,  ...,  0.0000,  0.0000, -0.0156]]],
       device='cuda:0')

IntervenableModel Loaded from HuggingFace Directly

The following code block reproduces the honest_llama-2 chat model from the paper Inference-Time Intervention: Eliciting Truthful Answers from a Language Model. The added activations take up only ~0.14MB on disk!

# others can download from huggingface and use it directly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pyvene as pv

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")

pv_model = pv.IntervenableModel.load(
    "zhengxuanzenwu/intervenable_honest_llama2_chat_7B", # the activation diff ~0.14MB
    model,
)

print("llama-2-chat loaded with interventions:")
q = "What's a cure for insomnia that always works?"
prompt = tokenizer(q, return_tensors="pt").to("cuda")
_, iti_response_shared = pv_model.generate(prompt, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(iti_response_shared[0], skip_special_tokens=True))

With this, once you discover a clever intervention scheme, you can quickly share it with others without sharing the base LM itself or the intervention code!

IntervenableModel as Regular nn.Module

You can also use an IntervenableModel such as pv_gpt2 (the wrapped GPT-2 created in the path patching example below) just like a regular torch module inside another model or pipeline:

import torch
import torch.nn as nn
from typing import Dict, List, Optional

class ModelWithIntervenables(nn.Module):
    def __init__(self, pv_gpt2):
        super().__init__()
        self.pv_gpt2 = pv_gpt2       # an IntervenableModel wrapping GPT-2
        self.relu = nn.ReLU()
        self.fc = nn.Linear(768, 1)  # 768 = hidden size of GPT-2 (small)
        # Your other downstream components go here

    def forward(
        self,
        base,
        sources: Optional[List] = None,
        unit_locations: Optional[Dict] = None,
        activations_sources: Optional[Dict] = None,
        subspaces: Optional[List] = None,
    ):
        # the intervenable model returns (original, intervened) outputs
        _, counterfactual_x = self.pv_gpt2(
            base,
            sources,
            unit_locations,
            activations_sources,
            subspaces,
        )
        return self.fc(self.relu(counterfactual_x.last_hidden_state))

Complex Intervention Schema as an Object

One key abstraction that pyvene provides is the encapsulation of the intervention schema. While this abstraction keeps the user interface simple, pyvene can still support relatively complex intervention schemas. The following helper function generates the schema configuration for path patching on individual attention heads at the output of the OV circuit (i.e., analyzing the causal effect of each individual component):

import pyvene as pv

def path_patching_config(
    layer, last_layer,
    component="head_attention_value_output", unit="h.pos",
):
    # group 0: the attention-head component we patch
    intervening_component = [
        {"layer": layer, "component": component, "unit": unit, "group_key": 0}]
    # group 1: downstream components that are restored
    restoring_components = []
    if not component.startswith("mlp_"):
        restoring_components += [
            {"layer": layer, "component": "mlp_output", "group_key": 1}]
    for i in range(layer + 1, last_layer):
        restoring_components += [
            {"layer": i, "component": "attention_output", "group_key": 1},
            {"layer": i, "component": "mlp_output", "group_key": 1},
        ]
    intervenable_config = pv.IntervenableConfig(
        intervening_component + restoring_components)
    return intervenable_config

Then you can wrap a model with the config generated by this function. After running your intervention, you can save the path patching setup and share it with others:

_, tokenizer, gpt2 = pv.create_gpt2()

pv_gpt2 = pv.IntervenableModel(
    path_patching_config(4, gpt2.config.n_layer),
    model=gpt2
)
# saving the path
pv_gpt2.save(
    save_directory="./your_gpt2_path/"
)
# loading the path (from the same directory it was saved to)
pv_gpt2 = pv.IntervenableModel.load(
    "./your_gpt2_path/",
    model=gpt2)

Selected Tutorials

Level | Tutorial | Description
Beginner | pyvene 101 | Introduces the basics of pyvene
Intermediate | ROME Causal Tracing | Reproduces ROME's results on factual associations with GPT2-XL
Intermediate | Intervention v.s. Probing | Illustrates how to run trainable interventions and probing with pythia-6.9B
Advanced | Trainable Interventions for Causal Abstraction | Illustrates how to train an intervention to discover causal mechanisms of a neural model

Contributing to This Library

Please see our guidelines about how to contribute to this repository.

Pull requests, bug reports, and all other forms of contribution are welcomed and highly encouraged! :octocat:

A Little Guide for Causal Abstraction: From Interventions to Gain Interpretability Insights

Basic interventions are fun, but on their own they do not let us make systematic causal claims. To gain real interpretability insights, we want to measure the counterfactual behaviors of a model in a data-driven fashion. In other words, if the model responds systematically to your interventions, you can begin to associate certain regions of the network with a high-level concept. We call this process an alignment search over model internals.

Understanding Causal Mechanisms with Static Interventions

Here is a more concrete example,

def add_three_numbers(a, b, c):
    var_x = a + b
    return var_x + c

The function solves a three-number sum problem. Suppose we trained a neural network to solve this problem perfectly. Can we find the representation of (a + b) in the neural network? We can use this library to answer this question as follows:

  • Step 1: Form an Interpretability (Alignment) Hypothesis: We hypothesize that a set of neurons N aligns with (a + b).
  • Step 2: Counterfactual Testing: If our hypothesis is correct, then swapping the values of N between examples should produce the expected counterfactual behaviors. For instance, if we swap N between (1+2)+3 and (2+3)+4, the output should be (2+3)+3 = 8 or (1+2)+4 = 7, depending on the direction of the swap.
  • Step 3: Rejection Sampling of Hypotheses: Run the tests multiple times and aggregate statistics on how well the counterfactual behaviors match. Propose a new hypothesis based on the results.
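To make Step 2 concrete, here is a minimal plain-Python sketch. It derives the expected counterfactual labels from the high-level causal model and adds a simple interchange-accuracy score that could feed Step 3. The helper names counterfactual_label and interchange_accuracy are hypothetical. They are not pyvene APIs.

```python
def add_three_numbers(a, b, c):
    var_x = a + b
    return var_x + c


def counterfactual_label(base, source):
    """Expected output when var_x = a + b is computed from `source`
    while c still comes from `base` (the high-level interchange)."""
    a_src, b_src, _ = source
    _, _, c_base = base
    var_x = a_src + b_src
    return var_x + c_base


def interchange_accuracy(intervened_outputs, pairs):
    """Fraction of (base, source) pairs for which the intervened network
    produced the counterfactual label predicted by the hypothesis."""
    hits = sum(
        int(out == counterfactual_label(base, source))
        for out, (base, source) in zip(intervened_outputs, pairs)
    )
    return hits / len(pairs)


pairs = [((1, 2, 3), (2, 3, 4)), ((2, 3, 4), (1, 2, 3))]
print([counterfactual_label(b, s) for b, s in pairs])  # [8, 7]
```

Suppose the network's intervened outputs for these pairs were [8, 7]. Then interchange_accuracy would return 1.0 and support the hypothesis about N. A low score would be grounds to reject it.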

With the library, these steps translate into a single API call:

intervenable.eval_alignment(
    train_dataloader=test_dataloader,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)

Here you provide test data (essentially, interventional inputs paired with the counterfactual behavior you are looking for) along with your metrics function. The library then evaluates the alignment using the intervention specified in the config.


Understanding Causal Mechanisms with Trainable Interventions

The alignment search process outlined above can be tedious when your neural network is large. For a single hypothesized alignment, you need to set up different intervention configs targeting different layers and positions to verify your hypothesis. Instead of this brute-force search, you can turn alignment search into an optimization problem, which brings other benefits as well, such as finding distributed alignments.

At its core, the idea is to train an intervention that produces our desired counterfactual behaviors. If we can indeed train such an intervention, we claim that causally relevant information lives in the intervened representations. Below is one type of trainable intervention, models.interventions.RotatedSpaceIntervention:

class RotatedSpaceIntervention(TrainableIntervention):
    """Intervention in the rotated space."""

    def forward(self, base, source):
        # rotate both representations into the learned basis
        rotated_base = self.rotate_layer(base)
        rotated_source = self.rotate_layer(source)
        # interchange the first `interchange_dim` rotated dimensions
        rotated_base[..., :self.interchange_dim] = \
            rotated_source[..., :self.interchange_dim]
        # rotate back with the inverse (transpose) of the orthogonal rotation
        output = torch.matmul(rotated_base, self.rotate_layer.weight.T)
        return output

Instead of swapping activations in the original representation space, we first rotate them, then do the swap, and finally un-rotate the intervened representation. We use SGD to learn a rotation that produces the expected counterfactual behavior. If we can find such a rotation, we claim there is an alignment. The tutorial notebook If the cost is between X and Y.ipynb covers this with an advanced version of distributed alignment search, Boundless DAS. Recent work has also outlined potential limitations of distributed alignment search.
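Here is a standalone sketch of the rotate, swap, un-rotate idea in plain PyTorch. It assumes torch.nn.utils.parametrizations.orthogonal keeps the rotation orthogonal. RotatedSwapSketch is a hypothetical name. This is not pyvene's RotatedSpaceIntervention, which adds its own bookkeeping.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal


class RotatedSwapSketch(nn.Module):
    """Standalone rotated-space interchange (not pyvene's class)."""

    def __init__(self, embed_dim, interchange_dim):
        super().__init__()
        # nn.Linear is only a container; the parametrization keeps .weight orthogonal
        self.rotate = orthogonal(nn.Linear(embed_dim, embed_dim, bias=False))
        self.interchange_dim = interchange_dim

    def forward(self, base, source):
        R = self.rotate.weight                      # (d, d) with R @ R.T == I
        rotated_base = base @ R
        rotated_source = source @ R
        k = self.interchange_dim
        # take the first k rotated coordinates from the source, the rest from the base
        mixed = torch.cat([rotated_source[..., :k], rotated_base[..., k:]], dim=-1)
        return mixed @ R.T                          # un-rotate


torch.manual_seed(0)
base, source = torch.randn(2, 8), torch.randn(2, 8)
swap = RotatedSwapSketch(embed_dim=8, interchange_dim=3)
out = swap(base, source)
```

The rotation's parameters are ordinary trainable parameters. An optimizer such as torch.optim.SGD(swap.parameters(), lr=...) could fit them to a counterfactual loss. With interchange_dim set to 0 the module returns the base unchanged, and with the full dimension it returns the source.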

You can now also make a single API call to train your intervention,

intervenable.train_alignment(
    train_dataloader=train_dataloader,
    compute_loss=compute_loss,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)

Here you pass in a trainable dataset along with your customized loss and metrics functions. The trained interventions can later be saved to disk. You can also use intervenable.evaluate() to evaluate your interventions against customized objectives.

Citation

If you use this repository, please consider citing our library paper:

@article{wu2024pyvene,
  title={pyvene: A Library for Understanding and Improving {P}y{T}orch Models via Interventions},
  author={Wu, Zhengxuan and Geiger, Atticus and Arora, Aryaman and Huang, Jing and Wang, Zheng and Goodman, Noah D. and Manning, Christopher D. and Potts, Christopher},
  journal={arXiv:2403.07809},
  url={https://arxiv.org/abs/2403.07809},
  year={2024}
}

Related Works in Discovering Causal Mechanism of LLMs

If you would like to read more work in this area, here is a list of papers that try to align or discover the causal mechanisms of LLMs.

pyvene's People

Contributors

amirzur, aryamanarora, atticusg, eltociear, explanare, frankaging, jiudingsun01, khoomeik, pinetreepantry, sungmincho, zhengpeterwang

pyvene's Issues

[P1] Initial support of stateful model architectures such as RNN

Description:
This PR includes only the minimal changes needed to make the library extensible to all types of RNN models. The key difference is that the hook function has to be stateful, i.e., it needs to know which "step" it is at when it is called. This is because RNN-based models, as well as transformer models when generating sequences, act as stateful models.

This change will ideally include a simple tutorial on how to intervene on a very simple RNN at any step. The input fields should stay the same, but the hook needs to keep track of extra fields in memory to perform stateful interventions.
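Here is a minimal sketch of the stateful bookkeeping in plain PyTorch rather than pyvene. A hypothetical hook object counts its calls on an nn.RNNCell and intervenes only at one target step. In this example the intervention zeroes the hidden state.

```python
import torch
import torch.nn as nn


class StepCountingHook:
    """Hypothetical stateful forward hook: intervenes only at `target_step`."""

    def __init__(self, target_step, intervention_fn):
        self.step = 0
        self.target_step = target_step
        self.intervention_fn = intervention_fn

    def __call__(self, module, inputs, output):
        current = self.step
        self.step += 1                      # bookkeeping across calls
        if current == self.target_step:
            return self.intervention_fn(output)
        return None                         # None leaves the output unchanged

    def reset(self):
        self.step = 0                       # call before each new sequence


torch.manual_seed(0)
cell = nn.RNNCell(input_size=4, hidden_size=3)
hook = StepCountingHook(target_step=2, intervention_fn=torch.zeros_like)
handle = cell.register_forward_hook(hook)

xs = torch.randn(5, 1, 4)                   # 5 time steps, batch size 1
h = torch.zeros(1, 3)
states = []
for x in xs:
    h = cell(x, h)                          # one call per step
    states.append(h)
handle.remove()
```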

[P1] Streamlining trainable intervention artifacts saving and sharing

Description:
After training, the intervention's artifacts live only in memory. There is no good way to save them to disk along with other metadata, or to share them on the HuggingFace Hub. This change will provide a smooth way to save and share interventions trained by users.

The key is serializing metadata into a shareable format, and both serialization and deserialization need to be tested. Sharing parties will still need to know how the counterfactual dataset was generated. That is less a problem for this library than for sharing the dataset itself, and dataset sharing could be a separate process outside this library.

This change should also consider sharing interventions that contain a vector store (e.g., a truthful direction).
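Here is a minimal sketch of save/load round-tripping. It assumes the weights go in a torch state_dict and the metadata in a JSON file. The file names (intervention.pt, config.json), the AddDirection vector-store intervention, and both helpers are hypothetical, not pyvene's format.

```python
import json
import os
import tempfile

import torch
import torch.nn as nn


class AddDirection(nn.Module):
    """Hypothetical vector-store intervention: adds a learned direction."""

    def __init__(self, embed_dim):
        super().__init__()
        self.direction = nn.Parameter(torch.zeros(embed_dim))

    def forward(self, base, source=None):
        return base + self.direction


def save_intervention(intervention, metadata, directory):
    """Write weights and JSON-serializable metadata side by side."""
    os.makedirs(directory, exist_ok=True)
    torch.save(intervention.state_dict(), os.path.join(directory, "intervention.pt"))
    with open(os.path.join(directory, "config.json"), "w") as f:
        json.dump(metadata, f, indent=2)


def load_intervention(build_fn, directory):
    """Rebuild the module from metadata, then restore its weights."""
    with open(os.path.join(directory, "config.json")) as f:
        metadata = json.load(f)
    intervention = build_fn(metadata)
    state = torch.load(os.path.join(directory, "intervention.pt"), map_location="cpu")
    intervention.load_state_dict(state)
    return intervention, metadata


torch.manual_seed(0)
original = AddDirection(16)
with torch.no_grad():
    original.direction.normal_()            # pretend this was trained
meta = {"type": "AddDirection", "embed_dim": 16, "layer": 15, "component": "mlp_output"}

with tempfile.TemporaryDirectory() as tmp:
    save_intervention(original, meta, tmp)
    restored, restored_meta = load_intervention(lambda m: AddDirection(m["embed_dim"]), tmp)
```

Keeping the metadata as plain JSON makes deserialization testable without loading any model.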

Testing Done:

  • Local Test Log:
.Removing testing dir ./test_output_dir_prefix-d9080f
Removing testing dir ./test_output_dir_prefix-dff621
Removing testing dir ./test_output_dir_prefix-9227e2
Removing testing dir ./test_output_dir_prefix-6cb8c4
Removing testing dir ./test_output_dir_prefix-67cd73

----------------------------------------------------------------------
Ran 25 tests in 4.280s

OK
  • New tutorial added: tutorials/basic_tutorials/Load_Save_and_Share_Interventions.ipynb.

[P1] Support more huggingface (transformer-based) models

Descriptions:
Ideally, all the models listed here can be supported by this library without exposing the model details to the users of this library.

This requires setting up model folders for all model types and writing config metadata for each one that annotates where interventions can be applied, which is a lot of effort. This PR tracks progress toward supporting as many models as we can.

Each model should take less than an hour to (1) configure and (2) write simple unit tests.

Here is the list of models that are in the pipeline to support (in order):

  • BERT-family
    • RoBERTa
    • DeBERTa
    • ELECTRA
  • xlm (multilingual model)
  • t5
  • Mistral
  • Mixtral (MoE, MixtralForCausalLM)
  • Phi
  • Mamba (but need to support recurrent interventions, not just layerwise interventions)
  • backpack-gpt2
  • please feel free to suggest other new models to support!
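Here is a sketch of the per-model metadata this issue describes. It uses a hypothetical mapping from component names to module path templates, plus a resolver that walks attribute and index paths on a torch model. TOY_TYPE_TO_MODULE, resolve_module, and ToyModel are illustrative only. pyvene's real per-model mappings may be structured differently.

```python
import re

import torch.nn as nn

# Hypothetical per-model metadata: component name -> module path template,
# where %s is filled with the layer index.
TOY_TYPE_TO_MODULE = {
    "block_output": "layers[%s]",
    "mlp_output": "layers[%s].mlp",
}


def resolve_module(model, template, layer):
    """Follow a path such as 'layers[2].mlp' through attributes and indices."""
    module = model
    for part in (template % layer).split("."):
        match = re.fullmatch(r"(\w+)\[(\d+)\]", part)
        if match:
            module = getattr(module, match.group(1))[int(match.group(2))]
        else:
            module = getattr(module, part)
    return module


class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.mlp(x)


class ToyModel(nn.Module):
    def __init__(self, dim=8, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
mlp_2 = resolve_module(model, TOY_TYPE_TO_MODULE["mlp_output"], 2)
```

Under this design, supporting a new model mostly means writing a new mapping dict, which fits the "less than an hour per model" estimate.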

[P1] Supports simple use case of intervention (fixed location intervention for all examples)

Descriptions:
The library is built to support flexible intervention schemes, which comes at the cost of ease of onboarding. To make the library easier for simpler use cases, we want to support more fields in the intervenable config for "more static" intervention schemes.

For instance, suppose that for all examples we want to intervene on the 4th token of the input at the output of the 6th transformer block. Then we don't need to provide intervention locations at runtime; we can simply provide a pair of base and source examples.
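Here is a hypothetical helper showing how a static location could be expanded into per-batch unit locations, so that callers pass only base and source inputs. It assumes the nested [intervention][example][positions] layout of the {"sources->base": ([[[4]]], [[[4]]])} argument used in the tutorial issue further down. The layer itself would live in the config. Note that the 4th token is index 3 under zero-based indexing.

```python
def static_unit_locations(position, batch_size, num_interventions=1):
    """Hypothetical helper: broadcast one fixed token position to every
    example and every intervention, for both sources and base."""
    def build():
        # fresh nested lists each time, so source and base never alias
        return [[[position] for _ in range(batch_size)]
                for _ in range(num_interventions)]
    return {"sources->base": (build(), build())}


print(static_unit_locations(3, batch_size=2))
# {'sources->base': ([[[3], [3]]], [[[3], [3]]])}
```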

[P1] Introduce grouped interventions where each group is associated with the same source input

Description:
Currently, we assume there is one source example associated with one intervention. Often, we want to reuse the same source example for multiple interventions, where these interventions are grouped for a purpose.

We want to support the concept of "grouping" in both the intervention configuration and the source inputs. This also lets you flexibly skip groups you don't need.

Changes need to be made to the common config classes, as well as to how the alignable class consumes the source inputs.
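Here is a plain-Python sketch of grouping. It assumes each intervention config carries a group_key, as in the path patching helper earlier in this document, and that sources are indexed by group. group_interventions and assign_sources are hypothetical helpers.

```python
from collections import defaultdict


def group_interventions(configs):
    """Collect intervention indices by their group_key."""
    groups = defaultdict(list)
    for idx, cfg in enumerate(configs):
        groups[cfg["group_key"]].append(idx)
    return dict(groups)


def assign_sources(configs, sources):
    """Give every intervention the source input of its group.
    sources[g] is the input for group g; None means skip that group."""
    plan = {}
    for key, indices in group_interventions(configs).items():
        if sources[key] is None:
            continue                        # skipped group: no interventions run
        for idx in indices:
            plan[idx] = sources[key]
    return plan


configs = [
    {"layer": 4, "component": "head_attention_value_output", "group_key": 0},
    {"layer": 4, "component": "mlp_output", "group_key": 1},
    {"layer": 5, "component": "attention_output", "group_key": 1},
]
print(assign_sources(configs, ["source A", "source B"]))
# {0: 'source A', 1: 'source B', 2: 'source B'}
```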

LlamaForCausalLM (case sensitive issue?)

https://github.com/frankaging/align-transformers/blob/6f2b7d7ad8203d263cd20b6da7a6be6171d1e690/models/modelings_alignable.py#LL11C2-L11C2

Could you maybe add "LlamaForCausalLM" here? (It's not captured by "LLaMAForCausalLM")

My Alpaca weights, recovered following https://github.com/tatsu-lab/stanford_alpaca, have the following config for some reason:

LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.29.2",
  "use_cache": true,
  "vocab_size": 32001
}

BTW, thanks for this awesome project!
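One way to avoid this class of bug is a case-insensitive lookup. The registry below is a hypothetical stand-in for the mapping in modelings_alignable.py linked above.

```python
# Hypothetical registry standing in for the mapping in modelings_alignable.py.
SUPPORTED_ARCHITECTURES = {
    "GPT2LMHeadModel": "gpt2",
    "LLaMAForCausalLM": "llama",
}
# normalize keys once so lookups ignore capitalization
_NORMALIZED = {name.lower(): kind for name, kind in SUPPORTED_ARCHITECTURES.items()}


def lookup_model_type(architecture):
    """Case-insensitive lookup, so LlamaForCausalLM and LLaMAForCausalLM match."""
    try:
        return _NORMALIZED[architecture.lower()]
    except KeyError:
        raise ValueError(f"Unsupported architecture: {architecture}") from None
```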

[P2] Mac chip MPS mode support

Descriptions:
Currently, the matrix exponential fails under the MPS backend on Apple M1 chips for rotation-based interventions. We need another way to handle this so that M1 machines can also run pyvene with trainable interventions such as DAS.
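Here is one possible workaround, assuming the failing op is torch.matrix_exp. torch.nn.utils.parametrizations.orthogonal also accepts orthogonal_map="householder" or "cayley", and neither of these uses the matrix exponential. This sketch only checks orthogonality on CPU. Whether the ops behind these maps are implemented on MPS depends on the PyTorch version, so that still has to be verified.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal


def make_rotation(dim, orthogonal_map="householder"):
    """Square orthogonal (rotation) layer whose parametrization avoids matrix_exp.
    For square matrices PyTorch defaults to "matrix_exp"; here we pick
    "householder" or "cayley" explicitly."""
    layer = nn.Linear(dim, dim, bias=False)
    return orthogonal(layer, orthogonal_map=orthogonal_map)


rot = make_rotation(16, "householder")
W = rot.weight
print(torch.allclose(W @ W.T, torch.eye(16), atol=1e-5))
```

On a Mac, the next step would be to check torch.backends.mps.is_available(), move the layer with rot.to("mps"), and test whether a forward and backward pass succeed there.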

[P1] Support multi-variable alignment with a single trainable basis intervention

Description:
One limitation of the current codebase is that, for DAS or similar methods, we only support aligning one causal variable at a time. This breaks the assumption behind using DAS: that we are finding a new basis in which orthogonal causal variables can be interpreted along different axes of that learned basis. Learning a separate basis for each causal variable breaks this core assumption.

We need to have a new type of intervention that supports multi-variable alignments. And this is more like a special need for DAS, or other basis-respect interventions.

Impacted files may include the intervention and the alignable config, as well as the input fields needed to indicate the causal variable index.
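Here is a sketch of multi-variable alignment with one shared learned basis, in which each causal variable owns a disjoint range of rotated coordinates. The class name, the partition format, and the variable names "a+b" and "c" are hypothetical.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal


class MultiVariableRotatedSwap(nn.Module):
    """One shared orthogonal basis; each causal variable owns a disjoint
    [start, end) range of rotated coordinates."""

    def __init__(self, embed_dim, partition):
        super().__init__()
        spans = sorted(partition.values())
        for (_, end), (start, _) in zip(spans, spans[1:]):
            if end > start:
                raise ValueError("variable subspaces must not overlap")
        self.partition = partition  # e.g. {"a+b": (0, 4), "c": (4, 6)}
        self.rotate = orthogonal(nn.Linear(embed_dim, embed_dim, bias=False))

    def forward(self, base, source, variables):
        R = self.rotate.weight
        rotated_base, rotated_source = base @ R, source @ R
        # mark the rotated coordinates belonging to the requested variables
        mask = torch.zeros(R.shape[0], dtype=torch.bool, device=base.device)
        for name in variables:
            start, end = self.partition[name]
            mask[start:end] = True
        mixed = torch.where(mask, rotated_source, rotated_base)
        return mixed @ R.T                  # back to the original space


torch.manual_seed(0)
swap = MultiVariableRotatedSwap(8, {"a+b": (0, 4), "c": (4, 6)})
base, source = torch.randn(3, 8), torch.randn(3, 8)
out = swap(base, source, ["a+b"])          # interchange only the "a+b" subspace
```

Because all variables share one rotation, training on interchanges of different variables shapes a single basis. That is the property the issue asks to preserve.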

[P2] Refactor the AlignableConfig to take in intervention not as type but actual instance

Description:
Currently, we have,

class AlignableConfig(PretrainedConfig):
    def __init__(
        self,
        alignable_model_type="gpt2",
        alignable_representations=[
            # we do distributed search over elements in the sublist.
            AlignableRepresentationConfig()
        ],
        alignable_interventions_type=VanillaIntervention,
        alignable_low_rank_dimension=None,
        mode="parallel",
        **kwargs
    ):

Currently, we must specify the intervention as a class type rather than an instance, which causes trouble. It would be better to accept alignable interventions as a list of actual instances, e.g., alignable_interventions = [VanillaIntervention()]. This would allow more specifications for customizable interventions.
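Here is a sketch of a config that accepts instances while still tolerating classes for backward compatibility. InterventionConfigSketch and VanillaSwapSketch are hypothetical stand-ins for AlignableConfig and VanillaIntervention.

```python
from dataclasses import dataclass, field
from typing import List, Type, Union

import torch.nn as nn


class VanillaSwapSketch(nn.Module):
    """Hypothetical stand-in for VanillaIntervention: returns the source."""

    def __init__(self, embed_dim=None):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, base, source):
        return source


@dataclass
class InterventionConfigSketch:
    # accepts ready-made instances, or classes for backward compatibility
    interventions: List[Union[nn.Module, Type[nn.Module]]] = field(default_factory=list)

    def build(self, **default_kwargs):
        built = []
        for item in self.interventions:
            if isinstance(item, type):      # a class: instantiate with defaults
                built.append(item(**default_kwargs))
            else:                           # already an instance: keep as-is
                built.append(item)
        return built


custom = VanillaSwapSketch(embed_dim=64)
config = InterventionConfigSketch(interventions=[custom, VanillaSwapSketch])
built = config.build(embed_dim=768)
```

Instances keep whatever custom arguments the user gave them. Classes still receive model-derived defaults such as the embedding dimension.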

[P2] Probe training with interventions

Description:
Interventions are currently trained by backpropagating the model loss. In other use cases, the intervention itself could be supervised by other objectives, such as probes attached to the intervention site.

We can add a basic classification head on top of the intervention object so that additional gradients can be backprop through the probe to the interventions. This could allow this library to support basic probing experiments with desired class labels.
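Here is a minimal sketch of supervising an intervention through a probe. A hypothetical additive intervention carries a linear classification head, and the probe's cross-entropy loss backpropagates into the intervention parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProbedIntervention(nn.Module):
    """Additive intervention with a linear probe on the intervened representation."""

    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(embed_dim))   # the intervention
        self.probe = nn.Linear(embed_dim, num_classes)      # classification head

    def forward(self, base):
        intervened = base + self.delta
        return intervened, self.probe(intervened)


torch.manual_seed(0)
intervention = ProbedIntervention(embed_dim=16, num_classes=3)
hidden = torch.randn(4, 16)                 # stand-in for activations at the site
labels = torch.tensor([0, 2, 1, 0])         # desired probe classes

intervened, logits = intervention(hidden)
probe_loss = F.cross_entropy(logits, labels)
# in practice this could be combined with the model loss,
# e.g. model_loss + probe_weight * probe_loss
probe_loss.backward()
```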

Bug in BoundlessRotatedSpaceIntervention

In the new version, self.embed_dim on line 262 of interventions.py is initialized as None and never assigned a value. This causes Boundless_DAS.ipynb to fail:

Traceback (most recent call last):
  File "/work/frink/sun.jiu/function_vectors/src/compute_rotational_subspace.py", line 300, in <module>
    intervenable = IntervenableModel(intervenable_config, model)
  File "/work/frink/sun.jiu/miniconda3/envs/fv/lib/python3.10/site-packages/pyvene/models/intervenable_base.py", line 111, in __init__
    intervention = intervention_function(
  File "/work/frink/sun.jiu/miniconda3/envs/fv/lib/python3.10/site-packages/pyvene/models/interventions.py", line 265, in __init__
    torch.arange(0, self.embed_dim), requires_grad=False
TypeError: arange() received an invalid combination of arguments - got (int, NoneType), but expected one of:
 * (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, *, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, Number step, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

Simply changing the line into:

self.embed_dim = embed_dim

Would solve the issue.

[P0] Adding more intervention support in generate mode

Description:
Currently, in generate mode, the code only works if you intervene on the first prompt token and on every decoding token, because of the activation caching mechanism of the HuggingFace library.

We want to provide more generic support: 1) intervene on different tokens in the prompt; 2) every decoding step.
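Here is a sketch of the bookkeeping involved, in plain PyTorch. With HuggingFace's cache, the first forward pass covers the whole prompt and each later pass covers only the newest token. The hypothetical GenerationAwareHook uses its call count to tell the two phases apart. It intervenes on chosen prompt positions first, then on each decoded token.

```python
import torch
import torch.nn as nn


class GenerationAwareHook:
    """Hypothetical forward hook for cached generation: call 0 is the prompt
    pass, every later call sees a single new token."""

    def __init__(self, prompt_positions, intervene_on_decode, fn):
        self.prompt_positions = prompt_positions
        self.intervene_on_decode = intervene_on_decode
        self.fn = fn
        self.calls = 0

    def __call__(self, module, inputs, output):
        is_prompt_pass = self.calls == 0
        self.calls += 1
        output = output.clone()
        if is_prompt_pass:
            pos = self.prompt_positions
            output[:, pos] = self.fn(output[:, pos])   # chosen prompt tokens
        elif self.intervene_on_decode:
            output[:, -1] = self.fn(output[:, -1])     # the single new token
        return output


torch.manual_seed(0)
layer = nn.Linear(4, 4)
hook = GenerationAwareHook([1, 3], intervene_on_decode=True, fn=torch.zeros_like)
handle = layer.register_forward_hook(hook)
prompt_out = layer(torch.randn(1, 5, 4))    # prompt pass: 5 tokens
decode_out = layer(torch.randn(1, 1, 4))    # one decoding step
handle.remove()
```

Counting calls, rather than checking whether the sequence length is 1, keeps a one-token prompt from being mistaken for a decoding step. The counter has to be reset before each new generate call.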

[P0] enable `use_fast` option in the alignable to hyper boost training speed in case intervention locations (for position+subspace) are fixed in a batch

Description:
Currently, the library prioritizes input flexibility and assumes a small training batch size when the intervention is trainable. For instance, we assume each example in the batch can have different intervention locations as well as different intervention subspaces, allowing more flexible configurations.

This is not ideal when the batch size is large and the intervention location does not change within a batch. Suppose we want to localize (a+b) in a simple NN that solves (a+b)*c, using DAS with a fixed dimensionality of 16. The intervention location then stays the same for every example. However, the current code performs the intervention at the example level, not at the batch level. See,

for batch_i, locations in enumerate(unit_locations):
    tensor_input[
        batch_i, locations, start_index:end_index
    ] = replacing_tensor_input[batch_i]

this can be,

tensor_input[
    :, locations, start_index:end_index
] = replacing_tensor_input[:]

Similarly, the subspace intervention,

if subspaces is not None:
    for example_i in range(len(subspaces)):
        # render subspace as column indices
        sel_subspace_indices = []
        for subspace in subspaces[example_i]:
            sel_subspace_indices.extend(
                [
                    i for i in range(
                        subspace_partition[subspace][0],
                        subspace_partition[subspace][1]
                    )
                ])
        if mode == "interchange":
            base[example_i, ..., sel_subspace_indices] = \
                source[example_i, ..., sel_subspace_indices]
        elif mode == "add":
            base[example_i, ..., sel_subspace_indices] += \
                source[example_i, ..., sel_subspace_indices]
        elif mode == "subtract":
            base[example_i, ..., sel_subspace_indices] -= \
                source[example_i, ..., sel_subspace_indices]

can be,

if subspaces is not None:
    if subspace_partition is None:
        sel_subspace_indices = subspaces[0]
    else:
        sel_subspace_indices = []
        for subspace in subspaces[0]:
            sel_subspace_indices.extend(
                [
                    i for i in range(
                        subspace_partition[subspace][0], 
                        subspace_partition[subspace][1]
                    )
                ])
    if mode == "interchange":
        base[..., sel_subspace_indices] = \
            source[..., sel_subspace_indices]
    elif mode == "add":
        base[..., sel_subspace_indices] += \
            source[..., sel_subspace_indices]
    elif mode == "subtract":
        base[..., sel_subspace_indices] -= \
            source[..., sel_subspace_indices]
else:
    base[..., :interchange_dim] = source[..., :interchange_dim]

We should add a use_fast flag to the alignable config, along with a validation check that fails fast during the forward call.

This PR tracks the use_fast effort for position-based intervention as well as subspace-based intervention. It does not cover head-based or head+position-based yet. Will cover the latter one in a separate PR.
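Here is a simplified sketch that checks whether the batched assignment matches the per-example loop when locations are shared. It also includes the fail-fast check. The sketch copies directly from full source activations rather than from a pre-gathered replacing_tensor_input, and all function names are hypothetical.

```python
import torch


def check_static_locations(unit_locations):
    """Fail fast if use_fast is requested but locations vary within the batch."""
    first = unit_locations[0]
    if any(loc != first for loc in unit_locations[1:]):
        raise ValueError("use_fast=True requires identical intervention locations in a batch")
    return first


def loop_interchange(base, source, unit_locations, start, end):
    """Per-example version, mirroring the current loop."""
    out = base.clone()
    for batch_i, locations in enumerate(unit_locations):
        out[batch_i, locations, start:end] = source[batch_i, locations, start:end]
    return out


def fast_interchange(base, source, unit_locations, start, end):
    """Batched version: one indexed assignment for the whole batch."""
    locations = check_static_locations(unit_locations)
    out = base.clone()
    out[:, locations, start:end] = source[:, locations, start:end]
    return out


torch.manual_seed(0)
base, source = torch.randn(32, 6, 16), torch.randn(32, 6, 16)
unit_locations = [[2, 4]] * 32          # same positions for every example
fast = fast_interchange(base, source, unit_locations, 0, 8)
slow = loop_interchange(base, source, unit_locations, 0, 8)
```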

Testing Done:

  • writing additional integration tests (4)
  • log:
In case multiple location tags are passed only the first one will be considered
testing stream: value_output with a single position
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.

In case multiple location tags are passed only the first one will be considered
.
----------------------------------------------------------------------
Ran 18 tests in 30.117s

OK

[P1] Optionally remove the dependency of the config file

Descriptions:
Currently, every model needs a config, and that config must inherit from the Transformers library's config class. The model config is only used in

intervention = intervention_function(
    embed_dim=get_dimension(
        get_internal_model_type(model), model.config, representation
    ), **other_medata
)

to get the component's dimension.

However, we should also allow config-less models, where the dimension is either read directly from a config dict or figured out dynamically with helper functions.

  • Case 1:
    Hard-code the dimensions:
mlp_type_to_dimension_mapping = {
    "block_input": (32,),
    "block_output": (32,),
    "mlp_activation": (32,),
}
  • Case 2:
    pyvene dynamically figures out the input and output dimensions of all the modules (see torch.compile).

Meanwhile, since the intervening component can be an arbitrary model component string, e.g., model.h[2].attn.c_proj.output, we can figure out the component dynamically.
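Here is a sketch of the dynamic route without torch.compile. It runs one dummy forward pass with temporary forward hooks and records every submodule's output shape. discover_output_dims is hypothetical. It only records tensor outputs, so modules that return tuples (as many transformer blocks do) would need extra handling.

```python
import torch
import torch.nn as nn


def discover_output_dims(model, example_input):
    """Run one forward pass with temporary hooks and record each submodule's
    output shape (batch dimension dropped)."""
    dims, handles = {}, []
    for name, module in model.named_modules():
        if name == "":                      # skip the root module itself
            continue

        def record(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor):
                dims[name] = tuple(output.shape[1:])

        handles.append(module.register_forward_hook(record))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:              # always clean up the hooks
            handle.remove()
    return dims


mlp = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 32))
print(discover_output_dims(mlp, torch.randn(1, 10)))
# {'0': (32,), '1': (32,), '2': (32,)}
```

This yields the same kind of information as the hard-coded mapping in Case 1, at the cost of one extra forward pass.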

[P1] Supporting Mistral with Unit Tests (#46)

Descriptions:
The Mistral model shares its architecture with other models, e.g., GPT-2 and Mixtral. Supporting Mistral along with unit tests will help us support the models in this family.

[P2] Refactor utils in models to smaller files

Description:
Right now everything lives in the util file, including common model configs, imports, and various hook helpers. It would be better to split them into smaller files to improve readability and extensibility.

[P2] Support argument name based intervention

Description:
When using hooks, we can now support kwargs-based inputs by reading the input as a dictionary. However, we always assume the dictionary contains a single input (e.g., the hidden representations), and this assumption can easily break. Instead, we should specify in the model config which part of the inputs to intervene on.

Note that this will still result in coupled code with the Transformers library. Multiple PRs are required to move towards this direction.
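Here is a sketch of argument-name-based intervention in plain PyTorch. It needs version 2.0 or later for with_kwargs=True. A forward pre-hook rewrites only the named input. The positional call shows why the name alone is not enough: a config-level specification, or binding arguments to the forward signature, would also be needed. KwargBlock and make_named_input_hook are hypothetical.

```python
import torch
import torch.nn as nn


class KwargBlock(nn.Module):
    """Stand-in for a transformer block that takes several keyword inputs."""

    def forward(self, hidden_states, attention_mask=None):
        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)
        return hidden_states * 2


def make_named_input_hook(arg_name, fn):
    """Intervene on exactly one named input, leaving the other kwargs untouched."""
    def hook(module, args, kwargs):
        if arg_name not in kwargs:
            return None                     # e.g. passed positionally: no change
        kwargs = dict(kwargs)
        kwargs[arg_name] = fn(kwargs[arg_name])
        return args, kwargs
    return hook


block = KwargBlock()
handle = block.register_forward_pre_hook(
    make_named_input_hook("hidden_states", torch.zeros_like), with_kwargs=True)
h, mask = torch.ones(1, 3, 4), torch.ones(1, 3)
named = block(hidden_states=h, attention_mask=mask)   # hook applies
positional = block(h, mask)                           # hook does not apply
handle.remove()
```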

[P2] Having a simple tutorial on how to use intervenable model inside another torch module, or pipeline (e.g., HF pipeline)

Descriptions:
The IntervenableModel is a torch.nn.Module. So, this can be used inside another torch model, or even pipeline object (e.g., Huggingface pipeline). Here is a quick code snippet,

import pyvene
from pyvene import IntervenableRepresentationConfig, IntervenableConfig, IntervenableModel

# provided wrapper for huggingface gpt2 model
_, tokenizer, gpt2 = pyvene.create_gpt2()

# turn gpt2 into intervenable_gpt2
intervenable_gpt2 = IntervenableModel(
    intervenable_config = IntervenableConfig(
        intervenable_representations=[
            IntervenableRepresentationConfig(
                0,            # intervening layer 0
                "mlp_output", # intervening mlp output
                "pos",        # intervening based on positional indices of tokens
                1             # maximally intervening one token
            ),
        ],
    ), 
    model = gpt2
)

import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union, Dict

class ModelWithIntervenables(nn.Module):
    def __init__(self):
        super(ModelWithIntervenables, self).__init__()
        self.intervenable_gpt2 = intervenable_gpt2
        self.relu = nn.ReLU()
        self.fc = nn.Linear(768, 1)
        # Your other downstream components go here

    def forward(
        self, 
        base,
        sources: Optional[List] = None,
        unit_locations: Optional[Dict] = None,
        activations_sources: Optional[Dict] = None,
        subspaces: Optional[List] = None,
    ):
        _, counterfactual_x = self.intervenable_gpt2(
            base,
            sources,
            unit_locations,
            activations_sources,
            subspaces
        )
        counterfactual_x = counterfactual_x.last_hidden_state
        
        counterfactual_x = self.relu(counterfactual_x)
        counterfactual_x = self.fc(counterfactual_x)
        return counterfactual_x

and then you can run forward as usual,

model = ModelWithIntervenables()

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [
    tokenizer("The capital of Italy is", return_tensors="pt"),
]

model(
    base, sources, {"sources->base": ([[[4]]], [[[4]]])}
)

which returns,

tensor([[[2.7027],
         [6.3036],
         [6.1785],
         [6.4302],
         [8.0921]]], grad_fn=<ViewBackward0>)

[P1] Unit test with hand-crafted models

Description:
Currently, we don't have a systematic way of unit testing the library. A good way is just to create some hand-crafted models (i.e., with fixed weights), do interventions, and check counterfactual behaviors.

Probably not hand-crafted transformers, but just simple MLPs with 3 hidden states, for instance.
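Here is a sketch of such a test, using the (a + b) + c task from the causal abstraction guide. A hypothetical MLP with fixed weights computes a + b in its first hidden neuron. An interchange on that neuron is then checked against the counterfactual label. The sketch uses plain forward hooks rather than pyvene.

```python
import torch
import torch.nn as nn


class HandCraftedMLP(nn.Module):
    """Fixed weights: hidden[0] = a + b, hidden[1] = c, output = hidden[0] + hidden[1]."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(3, 2, bias=False)
        self.layer2 = nn.Linear(2, 1, bias=False)
        with torch.no_grad():
            self.layer1.weight.copy_(torch.tensor([[1., 1., 0.], [0., 0., 1.]]))
            self.layer2.weight.copy_(torch.tensor([[1., 1.]]))

    def forward(self, x):
        return self.layer2(self.layer1(x))


def interchange(model, base, source, neurons):
    """Run `base`, but with the listed layer1 neurons copied from the `source` run."""
    cache = {}

    def save_source(module, inputs, output):
        cache["source"] = output.detach()

    handle = model.layer1.register_forward_hook(save_source)
    model(source)
    handle.remove()

    def swap(module, inputs, output):
        output = output.clone()
        output[..., neurons] = cache["source"][..., neurons]
        return output

    handle = model.layer1.register_forward_hook(swap)
    try:
        return model(base)
    finally:
        handle.remove()


model = HandCraftedMLP()
base = torch.tensor([[1., 2., 3.]])
source = torch.tensor([[2., 3., 4.]])
print(interchange(model, base, source, [0]).item())  # 8.0 == (2 + 3) + 3
```

Because the weights are fixed, every counterfactual outcome is known in advance. This makes the check an exact unit test rather than a statistical one.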

[P2] Support simple MLP layer for interventions

Description:
Currently, the library only works for transformer-based models. It does not work well for non-sequence models such as MLPs, or for other sequence models such as RNNs.

The first step moving forward to support other model types could be to showcase how this library will work for MLP models. The MLP model can be hand-crafted as well so that we know the counterfactual behaviors. We expect there will be hacks here and there to get things to work, but it will allow more model types.

[External] Counterfactual sampling in price-tagging tutorial utils code

Hello,

I was wondering what exactly the counterfactual sampling procedure in lower_bound_alignment_example_sampler does. Do the base and counterfactual labels have to be different, or can they be the same? For example, for a counterfactual label of "No", do we only want to sample base and source amounts such that the base label "Yes" changes to "No" after intervention, or can the base label also be "No"? The code seems to suggest the latter.

If so, I have added inline comments marking what looks like a potential bug. When base_source_regions is [2, 3], the base's left and right boundary values are (Yes, Yes), and the source's are (Yes, No). The base label is "Yes", and after intervening on the left boundary it is still "Yes". Any clarification is much appreciated!

[Screenshot: alignment_bug]
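To make the question concrete, here is a small sketch of the label arithmetic. It assumes (our reading, not the tutorial's code) that the price-tagging label is "Yes" exactly when lower <= amount <= upper, split into a left and a right boundary variable; the helper names are hypothetical and unrelated to lower_bound_alignment_example_sampler. Under that assumption, swapping in a source left boundary of "Yes" cannot change a base whose left boundary is already "Yes", so the [2, 3] case yields a counterfactual label equal to the base label.

```python
from typing import Tuple

Example = Tuple[float, float, float]  # (lower, upper, amount), an assumed layout


def boundaries(lower: float, upper: float, amount: float) -> Tuple[bool, bool]:
    """Split the task into two boolean boundary variables."""
    left = amount >= lower   # left boundary check
    right = amount <= upper  # right boundary check
    return left, right


def label(left: bool, right: bool) -> str:
    return "Yes" if left and right else "No"


def left_intervention_label(base: Example, source: Example) -> str:
    """Counterfactual label after swapping the source's left-boundary value into the base."""
    _, base_right = boundaries(*base)
    source_left, _ = boundaries(*source)
    return label(source_left, base_right)


base = (2.0, 5.0, 3.0)    # left/right = (Yes, Yes) -> "Yes"
source = (1.0, 2.0, 3.0)  # left/right = (Yes, No)  -> "No"
print(label(*boundaries(*base)), left_intervention_label(base, source))  # Yes Yes
```

A sampler that requires the label to flip would have to reject pairs like this one; a sampler that only needs a well-defined counterfactual label could keep them.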

[P2] Supports list of subspace indices instead of chunk partition

Description:
Currently, we are enabling interventions to take a new argument, subspaces, so that they intervene only on selected subspaces.

Right now, the representation is partitioned into chunks during configuration. This makes it hard to write interventions that work with discrete, neuron-level subspaces.

We need to change interventions to accept subspaces as lists of dimension indices, e.g., [[1,3,5]] means intervening on the neurons at indices 1, 3, and 5.
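A minimal sketch of the difference, in plain Python with hypothetical helper names (this is not pyvene's implementation). A chunk from a fixed partition is just a special case of an index list, so accepting per-example index lists subsumes the current behavior while also allowing scattered neurons.

```python
from typing import List, Sequence


def chunk_to_indices(chunk_id: int, chunk_size: int) -> List[int]:
    """Rewrite one contiguous chunk of a fixed partition as an explicit index list."""
    return list(range(chunk_id * chunk_size, (chunk_id + 1) * chunk_size))


def interchange_subspaces(
    base_batch: Sequence[Sequence[float]],
    source_batch: Sequence[Sequence[float]],
    subspaces: Sequence[Sequence[int]],
) -> List[List[float]]:
    """Copy only the listed dimensions from each source vector into its base vector."""
    out = []
    for base, source, dims in zip(base_batch, source_batch, subspaces):
        mixed = list(base)
        for d in dims:
            mixed[d] = source[d]
        out.append(mixed)
    return out


base = [[0.0] * 6]
source = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]
print(interchange_subspaces(base, source, [[1, 3, 5]]))  # [[0.0, 2.0, 0.0, 4.0, 0.0, 6.0]]
print(chunk_to_indices(1, 2))                            # [2, 3]
```

With tensors, the inner loop would become a single indexed assignment per example, but the interface question (chunk id versus index list) is the same.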

Error when running run_alignment.py: has no attribute when trying to save rotation matrix.

Hi, I am working with the newest version of your repo and got tutorial.ipynb to work. However, when I run run_alignment.py with the training script at the end of your README.md, I run into the following error when the trainer tries to save checkpoints of the rotation layer:

Traceback (most recent call last):
  File "/net/scratch/zhouy1/github/align-transformers-forked/run_alignment.py", line 183, in <module>
    aligner.train(
  File "/net/scratch/zhouy1/github/align-transformers-forked/trainer.py", line 222, in train
    self.save_model(output_dir, 'pytorch-rotate-best.bin')
  File "/net/scratch/zhouy1/github/align-transformers-forked/trainer.py", line 62, in save_model
    'rotate_layer': self.model.module.model.rotate_layer.state_dict(),
  File "/home/zhouy1/miniconda3/envs/BoundlessDAS/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1269, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'AlignableLlamaForCausalLM' object has no attribute 'module'

Thanks!
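The traceback suggests that save_model assumes the model is wrapped in something like torch.nn.DataParallel or DistributedDataParallel, which expose the wrapped model as `.module`; a single, unwrapped model has no such attribute. A common workaround is to unwrap defensively. The sketch below is an assumed fix, not tested against the repo, and uses stand-in classes so it runs without PyTorch.

```python
from typing import Any


def unwrap_model(model: Any) -> Any:
    """Return the wrapped model if `model` exposes `.module`, otherwise the model itself."""
    return getattr(model, "module", model)


# In trainer.py's save_model, the failing line could then read (assumed, untested):
#   'rotate_layer': unwrap_model(self.model).model.rotate_layer.state_dict(),


class FakeWrapper:  # stands in for torch.nn.DataParallel in this demo
    def __init__(self, module: Any) -> None:
        self.module = module


class FakeModel:  # stands in for AlignableLlamaForCausalLM in this demo
    pass


inner = FakeModel()
assert unwrap_model(inner) is inner
assert unwrap_model(FakeWrapper(inner)) is inner
```

`getattr` with a default also works on `torch.nn.Module`, whose `__getattr__` raises `AttributeError` for missing attributes, as seen in the traceback.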

[Bug]: pyvene.ai is taking too long to respond

Contact Details

[email protected]

What happened?

I receive an error message when visiting the pyvene.ai website. The message says the site can't be reached at the moment because it is taking too long to respond.

Code to produce this issue.

Visit https://pyvene.ai/ in a browser.

Chrome : Version 123.0.6312.58 (Official Build) (64-bit)
OS : Win10 Pro for Workstations

[P2] Quantized model support

Descriptions:
The library supports interventions on PyTorch models, and interventions can be attached to any subcomponent of the model. However, when a model is quantized with another wrapper, the intervention location can be more dynamic. See the quantized function here.

We want to support interventions on quantized models and to intervene on the correct components specified in the configuration.

[P0] Make interface compatible with HF trainer

Suggestion / Feature Request

We need to make sure the HF Trainer can easily hook onto pv.IntervenableModel. This means exposing the right parameters to the optimiser, making sure the model state is set correctly, etc.

[P1] Speed up training of multiple DAS interventions with caching

Commonly, we want to exhaustively train DAS on every layer and position (or, e.g., every attention head in a layer) to find which ones are causally relevant to the model's computation. With a fixed dataset, we could speed this process up by caching and reusing activations. It is unclear what the best implementation is, but a minimal example should already be possible with CollectIntervention and the activations_sources= argument at inference.
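The saving comes from running each source example once and reading every layer's activation from that single pass. The toy below illustrates only the caching pattern: the scalar "layers" and the `ActivationCache` class are hypothetical, and nothing here uses pyvene. In the library, the collection step would presumably go through CollectIntervention, with cached values fed back via activations_sources= as suggested above.

```python
from typing import Callable, Dict, List

# Hypothetical toy model: a stack of scalar "layers" applied in order.
LAYERS: List[Callable[[float], float]] = [
    lambda h: h + 1.0,
    lambda h: h * 2.0,
    lambda h: h - 3.0,
]


class ActivationCache:
    """Run each (fixed) source example once and keep every layer's activation."""

    def __init__(self, layers: List[Callable[[float], float]]) -> None:
        self.layers = layers
        self.forward_passes = 0
        self._store: Dict[int, List[float]] = {}

    def get(self, example_id: int, x: float) -> List[float]:
        if example_id not in self._store:
            self.forward_passes += 1
            acts, h = [], x
            for layer in self.layers:
                h = layer(h)
                acts.append(h)
            self._store[example_id] = acts
        return self._store[example_id]


sources = [1.0, 2.0, 5.0]
cache = ActivationCache(LAYERS)
# Sweep over every layer, as an exhaustive DAS search would; each sweep reads from the cache.
per_layer = {
    layer: [cache.get(i, x)[layer] for i, x in enumerate(sources)]
    for layer in range(len(LAYERS))
}
print(cache.forward_passes)  # 3 (one per example), not 9 (examples x layers)
print(per_layer[2])          # [1.0, 3.0, 9.0]
```

Keying the cache by example id is only safe because the dataset is fixed; with on-the-fly sampling, the key would need to identify the actual input.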

[P2] Multi-GPU model sharding with intervening evaluation and training

Descriptions:

The library is not tested with multi-GPU use cases. We assume the model being intervened on can be loaded onto a single GPU. This is not ideal for, e.g., interventions on 70B models. We want to be able to shard the model across multiple GPUs.

Static interventions need to be attached to the right component on the right device when the model is sharded. Trainable interventions likewise need to be placed on the device where the corresponding model component lives.

This could be a large task. The first step is clear: try out static interventions (e.g., vanilla interventions) at inference time with the model loaded across multiple GPUs.
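One piece of the mapping problem is finding which device holds the module an intervention hooks into. Hugging Face models loaded with a device_map record the resolved placement as a dict from module names to devices (e.g., the `hf_device_map` attribute), so a longest-prefix lookup is one plausible approach. The sketch assumes component strings look like "model.layers[15].mlp.output", with the last segment naming the hook type; the helper names are hypothetical.

```python
import re
from typing import Dict, Optional, Union

Device = Union[int, str]


def component_to_module_name(component: str) -> str:
    """'model.layers[15].mlp.output' -> 'model.layers.15.mlp' (assumes the last segment is input/output)."""
    dotted = re.sub(r"\[(\d+)\]", r".\1", component)
    return dotted.rsplit(".", 1)[0]


def device_for_component(component: str, device_map: Dict[str, Device]) -> Optional[Device]:
    """Longest-prefix lookup of the device that holds the module an intervention hooks into."""
    name = component_to_module_name(component)
    best_prefix, best_device = None, None
    for prefix, device in device_map.items():
        # "" maps the whole model; otherwise match whole dotted segments only.
        matches = prefix == "" or name == prefix or name.startswith(prefix + ".")
        if matches and (best_prefix is None or len(prefix) > len(best_prefix)):
            best_prefix, best_device = prefix, device
    return best_device


device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.15": 1, "model.norm": 1, "lm_head": 1}
print(device_for_component("model.layers[15].mlp.output", device_map))       # 1
print(device_for_component("model.layers[0].self_attn.output", device_map))  # 0
```

Matching on whole dotted segments keeps "model.layers.1" from accidentally claiming "model.layers.15".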

[P2] Enable `use_fast` for head related interventions

Descriptions:
The first PR to enable use_fast (https://github.com/frankaging/align-transformers/issues/33) does not cover head-related interventions. We want to enable it for head index + position index as well, when the two are used compositionally.

The following code needs to be changed,
https://github.com/frankaging/align-transformers/blob/main/models/modeling_utils.py#L464

        if "head" in alignable_representation_type:
            start_index = 0 if start_index is None else start_index
            end_index = 0 if end_index is None else end_index
            # head-based scattering
            if alignable_unit in {"h.pos"}:
                # we assume unit_locations is a tuple
                for head_batch_i, head_locations in enumerate(unit_locations[0]):
                    for head_loc_i, head_loc in enumerate(head_locations):
                        for pos_loc_i, pos_loc in enumerate(unit_locations[1][head_batch_i]):
                            h_start_index = start_index+head_loc*attn_head_size
                            h_end_index = start_index+(head_loc+1)*attn_head_size
                            tensor_input[
                                head_batch_i, pos_loc, h_start_index:h_end_index
                            ] = replacing_tensor_input[head_batch_i, head_loc_i, pos_loc_i] # [dh]
            else:
                for batch_i, locations in enumerate(unit_locations):
                    for loc_i, loc in enumerate(locations):
                        h_start_index = start_index+loc*attn_head_size
                        h_end_index = start_index+(loc+1)*attn_head_size
                        tensor_input[
                            batch_i, :, h_start_index:h_end_index
                        ] = replacing_tensor_input[batch_i, loc_i] # [s, dh]
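The index arithmetic in the snippet above can be separated from the per-example loops. If every example in the batch shares the same heads and positions (the assumption behind a use_fast path), the (position, column) pairs can be computed once and reused. A plain-Python sketch with hypothetical helper names; unlike the snippet, it copies from a source laid out like the base ([seq, hidden]) rather than from a [heads, positions, dh] replacement tensor.

```python
from typing import List, Sequence, Tuple


def head_pos_indices(
    heads: Sequence[int], positions: Sequence[int], head_size: int, start_index: int = 0
) -> List[Tuple[int, int]]:
    """(position, column) pairs covering each selected head at each selected position.

    Head h occupies columns [start_index + h*head_size, start_index + (h+1)*head_size),
    mirroring h_start_index / h_end_index in the snippet above.
    """
    return [
        (pos, col)
        for head in heads
        for pos in positions
        for col in range(start_index + head * head_size, start_index + (head + 1) * head_size)
    ]


def interchange_heads_shared(batch_base, batch_source, heads, positions, head_size):
    """Copy the selected head/position slices from source to base for every example (in place).

    Assumes all examples share the same heads and positions, so the index set is built once.
    """
    indices = head_pos_indices(heads, positions, head_size)
    for base, source in zip(batch_base, batch_source):
        for pos, col in indices:
            base[pos][col] = source[pos][col]
    return batch_base


# 3 positions, 2 heads of size 2 (hidden size 4).
base = [[[0] * 4 for _ in range(3)]]
source = [[[10 * pos + col for col in range(4)] for pos in range(3)]]
print(head_pos_indices([1], [2], head_size=2))  # [(2, 2), (2, 3)]
print(interchange_heads_shared(base, source, heads=[1], positions=[2], head_size=2)[0])
# [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 22, 23]]
```

With tensors, the precomputed pairs would presumably turn into one advanced-indexing assignment over the whole batch instead of nested Python loops.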

[P2] Support interchange intervention training by unfreezing model weights with vanilla intervention

Description:
Currently, we assume model weights are frozen when training interventions for alignment. We could also add support for tuning the model's weights together with the intervention.

This would help reproduce the interchange intervention training experiments in this paper. It could also be used to reproduce the causal proxy model experiments (i.e., using another explainer model to explain a black-box model).

[P2] Add a new huggingface collator for working with Pyvene models

Suggestion / Feature Request

Pyvene is a library built around interchange interventions. It frequently needs to process datasets that contain two sets of input_ids and (possibly) two sets of labels. When training with batched datasets, a collator problem arises: no existing collator pads both sets of input_ids, which have different lengths, at the same time.

The standard Hugging Face transformers collators only pad the standard model inputs such as "input_ids".

In addition, DataCollatorForSeq2Seq only pads "labels".

As a result, dataset entries like "source_input_ids" are not padded, which is a problem.

Adding a utility that supports this may help pyvene development in general.
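A minimal sketch of such a collator, assuming right-padding and hypothetical key names for the source side ("source_attention_mask", "source_labels"). Base and source inputs are padded independently, each to its own longest length. Labels are padded with -100, the default ignore_index of PyTorch's CrossEntropyLoss. It returns nested lists to stay dependency-free; a real collator would convert them to tensors.

```python
from typing import Any, Dict, List, Sequence


def pad_to_longest(seqs: Sequence[Sequence[int]], value: int) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one."""
    width = max(len(s) for s in seqs)
    return [list(s) + [value] * (width - len(s)) for s in seqs]


def dual_input_collator(
    features: List[Dict[str, Any]], pad_token_id: int = 0, label_pad_id: int = -100
) -> Dict[str, List[List[int]]]:
    """Pad base and source inputs independently; each keeps its own max length."""
    batch: Dict[str, List[List[int]]] = {}
    for ids_key, mask_key in (("input_ids", "attention_mask"), ("source_input_ids", "source_attention_mask")):
        seqs = [f[ids_key] for f in features]
        batch[ids_key] = pad_to_longest(seqs, pad_token_id)
        batch[mask_key] = pad_to_longest([[1] * len(s) for s in seqs], 0)
    # Labels are optional; pad them only if every feature has them.
    for label_key in ("labels", "source_labels"):
        if all(label_key in f for f in features):
            batch[label_key] = pad_to_longest([f[label_key] for f in features], label_pad_id)
    return batch


features = [
    {"input_ids": [5, 6, 7], "source_input_ids": [8], "labels": [1, 2, 3]},
    {"input_ids": [9], "source_input_ids": [4, 4], "labels": [7]},
]
batch = dual_input_collator(features)
print(batch["input_ids"])         # [[5, 6, 7], [9, 0, 0]]
print(batch["source_input_ids"])  # [[8, 0], [4, 4]]
print(batch["labels"])            # [[1, 2, 3], [7, -100, -100]]
```

Padding each side to its own length, rather than to a shared maximum, avoids wasted compute when base and source prompts differ a lot in length.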

[External] Save Interventions

Hi,

I have recently been exploring the repo and using Boundless DAS. Is there a way to save and load the interventions? Thanks!

[P1] Dynamic Intervention Scheduler

Descriptions:
Currently, we only support basic interventions during model generation, exactly as in the model's forward call. This is not ideal. During generation, we want to support more free-form interventions (e.g., intervening based on decoding steps or other decoding parameters, not just on unit locations as in an intervened forward pass).

The current infrastructure (and this applies to other existing intervention libraries as well) cannot support this. For instance, it cannot intervene at a specific decoding step without more invasive code changes. To support complex cases, we plan to introduce a new notion: the Intervention Scheduler.

At a high level, the scheduler is responsible for scheduling interventions dynamically at inference time, and it is customizable. For instance, we could intervene on (1) all decoded punctuation tokens, (2) all decoded verbs, or (3) the last token of every decoded entity from a specific entity set.

This opens up a wide spectrum of ways to steer model behavior with interventions. This ticket may require multiple changes.
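One way to picture a scheduler is as a predicate called at every decoding step with the step index and the token just decoded. The toy below is a hypothetical interface, not the planned pyvene design. The token stream is fixed, and `intervene` stands in for editing that step's representation, so it only shows how scheduling decisions separate from the intervention itself.

```python
import string
from typing import Callable, List, Tuple

# A schedule decides, per decoding step, whether to intervene (hypothetical interface).
Schedule = Callable[[int, str], bool]


def punctuation_schedule(step: int, token: str) -> bool:
    """Fire on tokens made only of punctuation, i.e. case (1) above."""
    return bool(token) and all(ch in string.punctuation for ch in token)


def step_schedule(start: int) -> Schedule:
    """Fire on every step from `start` onward."""
    return lambda step, token: step >= start


def run_scheduled(
    tokens: List[str], schedule: Schedule, intervene: Callable[[str], str]
) -> Tuple[List[str], List[int]]:
    """Toy decoding loop over a fixed token stream; returns edited tokens and the steps that fired."""
    out, fired = [], []
    for step, token in enumerate(tokens):
        if schedule(step, token):
            token = intervene(token)
            fired.append(step)
        out.append(token)
    return out, fired


tokens = ["Hello", ",", " world", "!"]
print(run_scheduled(tokens, punctuation_schedule, lambda t: "<X>"))
# (['Hello', '<X>', ' world', '<X>'], [1, 3])
print(run_scheduled(tokens, step_schedule(2), lambda t: t.upper())[1])  # [2, 3]
```

Cases (2) and (3) would need richer predicates (a tagger, or an entity matcher with lookahead), which is one reason to keep the schedule a pluggable callable.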

[P1] Adding support to intervene selected positions of selected heads

Description:
There is a clear use case for intervening on certain positions of selected heads (e.g., intervening on the x-th token in head y). For instance, we can see how different heads handle the information at each position, or find where an induction head operates on top of certain positions.

We probably need multiple changes to achieve this without bugs. This ticket marks the first step: supporting basic nested intervention locations. Specifically, we want the following capability,

        _, counterfactual_outputs = alignable(
            base,
            sources,
            {"sources->base": ([
                [[[target_head]], [[pos_i]]] # intervene w/ target_head's pos_i
            ], [
                [[[target_head]], [[pos_i]]] # intervene on target_head's pos_i
            ])}
        )

where target_head is a list of specific heads, and we want to intervene on a list of pos_i for each head in target_head. With this, we can, for example, intervene on the 3rd token's representation in the 4th head.

[P2] Intervening through time recurrently without teacher-forcing

Descriptions:
Currently, we support only a limited intervention use case for stateful models such as GRUs. After an intervention, the causal effect of the intervened site ripples through time, but we assume the inputs stay the same as before the intervention. This is fine if the task setup doesn't care about the inputs, if generation is input-agnostic, or if teacher forcing (i.e., discarding the model's own outputs as next-step inputs) is allowed during generation.

Here are some illustrations. Right now, we can support cross-time interventions as,

example 1:
(hiddens)   h1,  h2,  h3,  h4,  h5,  h6
(inputs)    x1,  x2,  x3,  x4,  x5,  x6
                 ^
                 |
                 +----+
                      |
                      v
example 2:
(hiddens)   h1', h2', h3', h4', h5', h6'
(inputs)    x1', x2', x3', x4', x5', x6'

Here we take h3' from the second example and use it to intervene on h2 in the first example, through time. We then also update h3 to h6 after the intervention. However, we assume x3 to x6 still come from example 1. This is acceptable if, during training, x3 to x6 do not matter for the model's generation (e.g., x2 is a trigger token that puts the model into generation mode).
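The current, teacher-forced behavior can be reproduced with a toy scalar RNN. The cell, its weights, and the helper names are hypothetical stand-ins for a GRU. Indices are 0-based, so h2 is index 1 and h3' is index 2. After the swap, the remaining steps are recomputed from the patched state but still read example 1's inputs x3 to x6.

```python
import math
from typing import List


def rnn_step(h: float, x: float, w: float = 0.5, u: float = 1.0) -> float:
    """Toy scalar RNN cell (hypothetical stand-in for a GRU): h_t = tanh(w*h_{t-1} + u*x_t)."""
    return math.tanh(w * h + u * x)


def unroll(inputs: List[float], h0: float = 0.0) -> List[float]:
    hs, h = [], h0
    for x in inputs:
        h = rnn_step(h, x)
        hs.append(h)
    return hs


def cross_time_intervention(
    base_inputs: List[float], source_inputs: List[float], base_t: int, source_t: int
) -> List[float]:
    """Overwrite the base hidden at base_t with the source hidden at source_t, then keep
    unrolling with the *base* inputs (the teacher-forced setup described above)."""
    source_hs = unroll(source_inputs)
    hs = unroll(base_inputs[: base_t + 1])
    hs[base_t] = source_hs[source_t]
    h = hs[base_t]
    for x in base_inputs[base_t + 1:]:
        h = rnn_step(h, x)
        hs.append(h)
    return hs


base = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
source = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
hs = cross_time_intervention(base, source, base_t=1, source_t=2)  # h2 <- h3' in 1-indexed terms
```

The limitation discussed next is visible in the last loop: it reads `base_inputs`, not anything derived from the intervened states.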

However, this is not ideal. If we are dealing with autoregressive LMs, we ideally want x3 to be the intervened model's output from the previous step. This requires the model to pass gradients through time. One simple solution is to update the model to use Gumbel-Softmax to softly select the token and pass it to the next time step as the input embedding.
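The soft token selection mentioned above can be sketched as follows: add Gumbel noise to the logits, apply a temperature-scaled softmax, and feed the probability-weighted mix of token embeddings to the next step instead of a hard lookup. This pure-Python version only shows the arithmetic. In a real model one would presumably use `torch.nn.functional.gumbel_softmax` so that gradients flow. The helper names and the toy vocabulary are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence


def gumbel_softmax(
    logits: Sequence[float], temperature: float = 1.0, rng: Optional[random.Random] = None
) -> List[float]:
    """Soft sample from a categorical: softmax((logits + Gumbel noise) / temperature)."""
    rng = rng or random.Random()
    scores = []
    for logit in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
        gumbel = -math.log(-math.log(u))
        scores.append((logit + gumbel) / temperature)
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def soft_embedding(weights: Sequence[float], embedding_table: Sequence[Sequence[float]]) -> List[float]:
    """Next-step input: probability-weighted mix of token embeddings instead of a hard lookup."""
    dim = len(embedding_table[0])
    return [sum(w * row[d] for w, row in zip(weights, embedding_table)) for d in range(dim)]


weights = gumbel_softmax([2.0, 0.5, -1.0], temperature=0.5, rng=random.Random(0))
embedding_table = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3-token toy vocabulary, 2-dim embeddings
next_input = soft_embedding(weights, embedding_table)
```

Lower temperatures push the weights toward a one-hot selection, so the next input approaches a real token embedding, at the cost of higher-variance gradients.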

The change may only be needed on the modeling side: we need to change the model to do soft token selection, which allows gradients to flow. This remains compatible with the library, since this input-based unrolling only makes sense in intervention mode.

[P0] Library Class Renaming Effort

Descriptions:
We will move away from the concept of "align" soon in light of future releases of this library. We thus change "alignable" to "intervenable" for all the occurrences.

Testing Done:

  • Unit Tests Log:
.Removing testing dir ./test_output_dir_prefix-185ec0
Removing testing dir ./test_output_dir_prefix-f2e887
Removing testing dir ./test_output_dir_prefix-c7b84c
Removing testing dir ./test_output_dir_prefix-9e7acd
Removing testing dir ./test_output_dir_prefix-762318

----------------------------------------------------------------------
Ran 25 tests in 8.546s

OK
  • String Match Find
zhengxuanwu@DNa8211a3 align-transformers % grep -Ri "alignable" .
Binary file ./.git/objects/pack/pack-96bfa18d843e36c6082333ae210e2c8c92ce2bbd.pack matches
  • Making sure all the tutorials pass (in progress)

[P0] Upgrade tutorials to new API

A lot of the old tutorials are using the old API for representation configs and interventions; some of them might even be broken because of this. We should upgrade all of them to the new format to showcase how the library is designed to be used. Thanks to @smejak for pointing this out.

Edit: Specifically, the intro DAS tutorial isn't working in Colab; the intervened forward pass fails. Will debug this later today.

[P1] Unit tests for main class objects

Descriptions:
Currently, only basic integration tests around modules are included. Ideally, individual functions would be tested in a unit test file kept separate from the integration tests.

Testing Done:

  • Local Test Log
.Removing testing dir ./test_output_dir_prefix-d9080f
Removing testing dir ./test_output_dir_prefix-dff621
Removing testing dir ./test_output_dir_prefix-9227e2
Removing testing dir ./test_output_dir_prefix-6cb8c4
Removing testing dir ./test_output_dir_prefix-67cd73

----------------------------------------------------------------------
Ran 25 tests in 4.280s

OK
