jbloomaus / decisiontransformerinterpretability

Interpreting how transformers simulate agents performing RL tasks

Home Page: https://jbloomaus-decisiontransformerinterpretability-app-4edcnc.streamlit.app/

License: MIT License

Shell 0.10% Python 6.70% HTML 5.19% Jupyter Notebook 88.01%
reinforcement-learning mechanistic-interpretability

decisiontransformerinterpretability's Introduction

Decision Transformer Interpretability


This project is designed to facilitate mechanistic interpretability of decision transformers as well as RL agents using transformer architectures.

This is achieved by:

  • Training scripts for online RL agents using the PPO algorithm. This training script can be used to generate trajectories for training a decision transformer.
  • A decision transformer implementation and training script, based on a standard transformer backbone and the decision transformer architecture.
  • A streamlit app. This app enables researchers to play minigrid games whilst observing the decision transformer's predictions/activations.

Future work will include:

  • creating an interpretability portfolio, expanding various exploratory techniques already present in the streamlit app.
  • solving tasks which require memory or language instruction. Many MiniGrid tasks require agents to have memory, and currently our PPO agent only responds to the last timestep.
  • validating hypotheses about model circuits using causal scrubbing.

Write Up

You can find an initial technical report for this project here.

Package Overview

The package contains several important components:

  • The environments package which provides utilities for generating environments (mainly focussed on MiniGrid).
  • The decision_transformer package which provides utilities for training and evaluating decision transformers (via calibration curves).
  • The ppo package which provides utilities for training and evaluating PPO agents.
  • The streamlit app which provides a user interface for playing games and observing the decision transformer's predictions/activations.
  • The models package which provides a common trajectory-transformer class so as to keep architectures homogeneous across the project.

Other notable files/folders:

  • The scripts folder contains bash scripts which show how to use various interfaces in the project.
  • The test folder which contains extensive tests for the project.

Example Results

We've successfully trained a decision transformer on several games including DoorKey and Dynamic Obstacles.

Calibration plot for MiniGrid-Dynamic-Obstacles-8x8-v0 after 6000 batches: episode length 14, RTG 1.0, reward 0.955.

I highly recommend playing with the streamlit app if you are interested in this project. It relies heavily on an understanding of the Mathematical Framework for Transformer Circuits.

Running the scripts

Example bash scripts are provided in the scripts folder. They make use of argparse interfaces in the package.

Training a PPO agent

If you set 'track' to true, a Weights & Biases dashboard will be generated. A trajectories pickle file will be generated in the trajectories folder. This file can be used to train a decision transformer.

python -m src.run_ppo --exp_name "Test" \
    --seed 1 \
    --cuda \
    --track \
    --wandb_project_name "PPO-MiniGrid" \
    --env_id "MiniGrid-DoorKey-8x8-v0" \
    --view_size 5 \
    --total_timesteps 350000 \
    --learning_rate 0.00025 \
    --num_envs 8 \
    --num_steps 128 \
    --num_minibatches 4 \
    --update_epochs 4 \
    --clip_coef 0.2 \
    --ent_coef 0.01 \
    --vf_coef 0.5 \
    --max_steps 1000 \
    --one_hot_obs

Training a decision transformer

Targeting the trajectories file and setting the model architecture details and hyperparameters, you can run the decision transformer training script.

python -m src.run_decision_transformer \
    --exp_name MiniGrid-Dynamic-Obstacles-8x8-v0-Refactor \
    --trajectory_path trajectories/MiniGrid-Dynamic-Obstacles-8x8-v0bd60729d-dc0b-4294-9110-8d5f672aa82c.pkl \
    --d_model 128 \
    --n_heads 2 \
    --d_mlp 256 \
    --n_layers 1 \
    --learning_rate 0.0001 \
    --batch_size 128 \
    --train_epochs 5000 \
    --test_epochs 10 \
    --n_ctx 3 \
    --pct_traj 1 \
    --weight_decay 0.001 \
    --seed 1 \
    --wandb_project_name DecisionTransformerInterpretability-Dev \
    --test_frequency 1000 \
    --eval_frequency 1000 \
    --eval_episodes 10 \
    --initial_rtg -1 \
    --initial_rtg 0 \
    --initial_rtg 1 \
    --prob_go_from_end 0.1 \
    --eval_max_time_steps 1000 \
    --track True

Note: if you want the training data from the blog post, you can download it like so:

cd trajectories
gdown 1UBMuhRrM3aYDdHeJBFdTn1RzXDrCL_sr

Running the Streamlit app

To run the Streamlit app:

streamlit run app.py

To run the Streamlit app on Docker, see the Development section.

Setting up the environment

I haven't been too careful about this yet. Use Python 3.9.15 with the requirements.txt file. We're using the V2 branch of TransformerLens and MiniGrid 2.1.0.

conda create --name decision_transformer_interpretability python=3.9.15
conda activate decision_transformer_interpretability
pip install -r requirements.txt

The Dockerfile should work, and we can make more use of it once the project is further ahead, or if developers are rotating frequently and we see behaviour differ between machines.

./scripts/build_docker.sh
./scripts/run_docker.sh

Then you can ssh into the container, and a good IDE will handle credentials etc.

Development

Docker

If you're having trouble making the environment work, I recommend Docker. There's a dockerfile in the main folder - it takes a few minutes the first time, and 10-15 seconds for me when only changing code. If adding requirements it may take a bit longer. I (Jay) use Ubuntu through WSL and Docker Desktop, and it worked pretty easily for me.

To run it, first navigate to your project directory, then:

docker build -t IMAGE_NAME .
docker run -d -it -v $(pwd):/app --name CONTAINER_NAME IMAGE_NAME bash

To reset the container (e.g. you've changed the code and want to rerun your tests), use:

docker stop CONTAINER_NAME
docker rm CONTAINER_NAME
docker rmi IMAGE_NAME
docker build -t IMAGE_NAME .
docker run -p 8501:8501 -d -it -v $(pwd):/app --name CONTAINER_NAME IMAGE_NAME bash

I recommend setting this all up as a batch command so you can do it easily for a quick iteration time.

Finally, to run a command, use:

docker exec CONTAINER_NAME COMMAND

For instance, to run unit tests, you would use docker exec CONTAINER_NAME pytest tests/unit.

To run Streamlit on your local browser, you can use the following command:

docker exec CONTAINER_NAME streamlit run app.py --server.port=8501

Tests:

Ensure that the run_tests.sh script is executable:

chmod a+x ./scripts/run_tests.sh

Run the tests. Note: the end-to-end tests are excluded from the run_tests.sh script since they take a while to run. They create wandb dashboards and are useful for debugging, but they are not necessary for development.

To run end-to-end tests, you can use the command 'pytest -v --cov=src/ --cov-report=term-missing'. If the trajectories file 'MiniGrid-Dynamic-Obstacles-8x8-v0bd60729d-dc0b-4294-9110-8d5f672aa82c.pkl' is not found in the tests, the 'gdown' command has failed to download it. In that case, download it manually or run 'conda install -c conda-forge gdown' and try again.

./scripts/run_tests.sh

You should see something like this after the tests run. This is the coverage report. Ideally this is 100% but we're not there yet. Furthermore, it will be 100% long before we have enough tests. But if it's 100% and we have performant code with agents training and stuff otherwise working, that's pretty good.

---------- coverage: platform darwin, python 3.9.15-final-0 ----------
Name                                Stmts   Miss  Cover   Missing
-----------------------------------------------------------------
src/__init__.py                         0      0   100%
src/decision_transformer.py           132      8    94%   39, 145, 151, 156-157, 221, 246, 249
src/ppo.py                             20     20     0%   2-28
src/ppo/__init__.py                     0      0   100%
src/ppo/agent.py                      109     10    91%   41, 45, 112, 151-157
src/ppo/compute_adv_vectorized.py      30     30     0%   1-65
src/ppo/memory.py                      88     11    88%   61-64, 119-123, 147-148
src/ppo/my_probe_envs.py               99      9    91%   38, 42-44, 74, 99, 108, 137, 168
src/ppo/train.py                       69      6    91%   58, 74, 94, 98, 109, 113
src/ppo/utils.py                      146     54    63%   41-42, 61-63, 69, 75, 92-96, 110-115, 177-206, 217-235
src/utils.py                           40     17    58%   33-38, 42-65, 73, 76-79
src/visualization.py                   25     25     0%   1-34
-----------------------------------------------------------------
TOTAL                                 758    190    75%

Next Steps

  • Getting PPO to work with a transformer architecture.
  • Analyse this model/the decision transformer/a behavioural clone and publish the results.
  • Get a version of causal-scrubbing working
  • Study BabyAI (adapt all models to take an instruction token that is prepended to the context window)

Relevant Projects:

decisiontransformerinterpretability's People

Contributors

dalasnoin, echowne, felhof, jay-bailey, jaybaileycs, jbloomaus, mjahaha, ruphail, vlfernandez


decisiontransformerinterpretability's Issues

Load and sample from model checkpoints

Thanks to #27 we have model checkpoints from training. I'd like to do the following:

  • Load these state dictionaries into a blank model
  • Evaluate them
  • Sample from them and write trajectory files, which are much smaller than the full training runs for these bigger/messier tasks.

Implement another architecture such as GRU or LSTM for the PPO agent to generate trajectories for DT training

My initial plan was to look at how decision transformers solved various tasks and compare this to trajectory transformers trained with online RL (via PPO); however, this may be difficult due to the instability of standard transformers.

On Transforming Reinforcement Learning with Transformers: The Development Trajectory - https://arxiv.org/pdf/2212.14164.pdf
STABILIZING TRANSFORMERS FOR REINFORCEMENT LEARNING - https://arxiv.org/pdf/1910.06764.pdf

In light of this, and to make it easier to generate training data for the decision transformers, I want to try another architecture such as an LSTM.

Emulating the architecture from the BabyAI paper would probably be a good way to go (handling mission statements would be nice)- https://github.com/mila-iqia/babyai/tree/master

Tasks:

  • Implement an architecture such as LSTM in an agent class (such as LSTMAgent)
  • Train this class with PPO
  • Pass test cases as well as showing it is performant on tasks like the Memory task and others
  • Ensure that the resulting trajectories can be used to train decision transformers.

Implement Wandb sweeps for decision transformer training.

At some point, I implemented weights and biases sweeps to find good hyperparameters when doing PPO training. I suspect this will be very useful in the future and would like it to exist for the decision transformer script as well.

Examples from PPO:

Tasks:

  • make sure the PPO sweep works (I think there might be some funkiness with passing in the exp-name, so we might want to fix that first)
  • copy the pattern and implement it for scripts/run_decision_transformer.sh

Minigrid: Submit ViewSizeWrapper and RenderResizeWrapper PRs to Minigrid

I wrote a few wrappers for MiniGrid environments which should probably live in the Minigrid GitHub repo.

If someone has the time to submit that as a PR and write whatever tests their maintainers require (probably not a huge amount), then this would be greatly appreciated.

Checklist:
[ ] make a PR to Minigrid with the wrappers (double check the view size wrapper doesn't exist, it could be that it just had a bug that I didn't want to wait for them to fix)
[ ] add tests to the PR
[ ] have the PR accepted
[ ] when the PR is accepted, update this repo to use the new minigrid version and remove the wrapper code from this repo.

Reimplement BabyAI Recurrent AC Model for Trajectory Generation

After failing to generate a successful TransformerAC model, I'm pivoting to generating a recurrent AC model based on the BabyAI PyTorch implementation (https://github.com/mila-iqia/babyai/blob/master/babyai/model.py)

The tasks involved are:

  • Port over the model and get it running using our PPO methods, rewriting parts if needed
  • Write basic tests to ensure it works
  • Ensure we can generate trajectories from it that we can train our decision transformers on

Minigrid: Update DictObservationSpaceWrapper to only modify the observation space "mission" entry or write another wrapper to do so

It really bothers me that the minigrid wrapper for tokenising the observations also modifies the observation space for images. These should be untied.

class DictObservationSpaceWrapper(ObservationWrapper):
    """
    Transforms the observation space (that has a textual component) to a fully numerical observation space,
    where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import miniworld
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import DictObservationSpaceWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['mission']
        'avoid the lava and get to the green goal square'
        >>> env_obs = DictObservationSpaceWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['mission'][:10]
        [19, 31, 17, 36, 20, 38, 31, 2, 15, 35]
    """

    def __init__(self, env, max_words_in_mission=50, word_dict=None):
        """
        max_words_in_mission is the length of the array to represent a mission, value 0 for missing words
        word_dict is a dictionary of words to use (keys=words, values=indices from 1 to < max_words_in_mission),
                  if None, use the Minigrid language
        """
        super().__init__(env)

        if word_dict is None:
            word_dict = self.get_minigrid_words()

        self.max_words_in_mission = max_words_in_mission
        self.word_dict = word_dict

        image_observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.agent_view_size, self.agent_view_size, 3),
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {
                "image": image_observation_space,
                "direction": spaces.Discrete(4),
                "mission": spaces.MultiDiscrete(
                    [len(self.word_dict.keys())] * max_words_in_mission
                ),
            }
        )

    @staticmethod
    def get_minigrid_words():
        colors = ["red", "green", "blue", "yellow", "purple", "grey"]
        objects = [
            "unseen",
            "empty",
            "wall",
            "floor",
            "box",
            "key",
            "ball",
            "door",
            "goal",
            "agent",
            "lava",
        ]

        verbs = [
            "pick",
            "avoid",
            "get",
            "find",
            "put",
            "use",
            "open",
            "go",
            "fetch",
            "reach",
            "unlock",
            "traverse",
        ]

        extra_words = [
            "up",
            "the",
            "a",
            "at",
            ",",
            "square",
            "and",
            "then",
            "to",
            "of",
            "rooms",
            "near",
            "opening",
            "must",
            "you",
            "matching",
            "end",
            "hallway",
            "object",
            "from",
            "room",
        ]

        all_words = colors + objects + verbs + extra_words
        assert len(all_words) == len(set(all_words))
        return {word: i for i, word in enumerate(all_words)}

    def string_to_indices(self, string, offset=1):
        """
        Convert a string to a list of indices.
        """
        indices = []
        # adding space before and after commas
        string = string.replace(",", " , ")
        for word in string.split():
            if word in self.word_dict.keys():
                indices.append(self.word_dict[word] + offset)
            else:
                raise ValueError(f"Unknown word: {word}")
        return indices

    def observation(self, obs):
        obs["mission"] = self.string_to_indices(obs["mission"])
        assert len(obs["mission"]) < self.max_words_in_mission
        obs["mission"] += [0] * (self.max_words_in_mission - len(obs["mission"]))

        return obs

Better encode/embed MiniGrid State to speed up training in DT's.

I have two main ideas for this:

Implement a variation of the BOW encoding that is used by BabyAI but add position to avoid building a convnet.

Although it is not discussed in the BabyAI paper itself, it seems they actually used a much better tokenization scheme than I am using, and this could plausibly be causing many of my problems.

class ImageBOWEmbedding(nn.Module):
   def __init__(self, max_value, embedding_dim):
       super().__init__()
       self.max_value = max_value
       self.embedding_dim = embedding_dim
       self.embedding = nn.Embedding(3 * max_value, embedding_dim)
       self.apply(initialize_parameters)

   def forward(self, inputs):
       offsets = torch.Tensor([0, self.max_value, 2 * self.max_value]).to(inputs.device)
       inputs = (inputs + offsets[None, :, None, None]).long()
       return self.embedding(inputs).sum(1).permute(0, 3, 1, 2)

This takes each cell and represents it as a unique embedding that is independent of position. For example, key + yellow + closed (object, colour, state) -> 13 or something -> maps to a specific vector. They then pass this into a convolutional network (which I was hoping to avoid).

I can do better than this, though, in the context of my model. I can have, bear with me, 5 separate embedding matrices (the current model effectively uses 3):

  • 1 state
  • 1 object
  • 1 color
  • 1 row position
  • 1 column position.

Then any given object/colour/state at any given position starts off with a unique representation. We can then look at how weight regularization acts as a feature selector over these and how the embeddings of each evolve in response to the circuits which use them. A minimal sketch of this scheme is given below.

One complication here is that for the positional embeddings, it seems like I should use something like sinusoidal embeddings but I want to ensure they are orthogonal to the previous sinusoidal embeddings. I have some ideas about how to do this but I will google it anyway.

E.g. this from the Gato paper (figure not reproduced here).
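
A minimal sketch of the five-embedding-matrix idea, assuming a (view_size, view_size, 3) grid of (object, colour, state) indices as in MiniGrid partial observations; the class name, vocabulary sizes and the final sum-pooling are illustrative assumptions, not the repo's implementation:

import torch
import torch.nn as nn


class FactoredGridEmbedding(nn.Module):
    """Sum of five embeddings: object, colour, state, row position, column position."""

    def __init__(self, view_size: int = 7, d_model: int = 128,
                 n_objects: int = 11, n_colours: int = 6, n_states: int = 3):
        super().__init__()
        self.object_emb = nn.Embedding(n_objects, d_model)
        self.colour_emb = nn.Embedding(n_colours, d_model)
        self.state_emb = nn.Embedding(n_states, d_model)
        self.row_emb = nn.Embedding(view_size, d_model)
        self.col_emb = nn.Embedding(view_size, d_model)
        self.view_size = view_size

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # obs: (batch, view_size, view_size, 3) of integer indices
        obj, colour, state = obs[..., 0].long(), obs[..., 1].long(), obs[..., 2].long()
        rows = torch.arange(self.view_size, device=obs.device)
        cols = torch.arange(self.view_size, device=obs.device)
        cell = (self.object_emb(obj) + self.colour_emb(colour) + self.state_emb(state)
                + self.row_emb(rows)[None, :, None, :] + self.col_emb(cols)[None, None, :, :])
        # pool over the grid to get one state token per observation
        return cell.flatten(1, 2).sum(dim=1)  # (batch, d_model)


emb = FactoredGridEmbedding()
print(emb(torch.zeros(2, 7, 7, 3, dtype=torch.long)).shape)  # torch.Size([2, 128])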

Write a Rollout Sampling Utility for PPO Agents and add features that affect the generated distribution

#43 was insufficient to create an optimal behaviour agent with the current rollout data.

To falsify the hypothesis that our failure to generate good trajectories (ones that both solve the task and are at least somewhat RTG-modulated; calibration is unimportant for the time being) is a function of the data-generating distribution produced by the current rollout methodology and PPO agent, I would like to upgrade our ability to sample from PPO agents.

This will include:

  1. Adding an option to sample deterministically (always taking the best action), rather than randomly sampling from the distribution indicated by the logits.
  2. Adding an option to sample with temperature. This will enable us to calibrate the right amount of randomness.
  3. Adding an option to sample from anything except the maximal logit (unlike temperature, this will reliably produce anti-good trajectories rather than consistently good trajectories).

Furthermore, the utility for this should also provide some feedback on the sampled trajectories (for example, average RTG achieved, RTG distribution, trajectory length distribution etc). A minimal sketch of the three sampling modes is included at the end of this issue.

Tasks:

  • Write the end-to-end utility (i.e. wrap what we have currently in a runner)
  • Add each of the options above with tests
    • Deterministic Sampling
    • Temperature Sampling
    • Bottom-k sampling
  • Make it possible to select more than one of these and specify the number of rollouts to generate with each.
  • Make it possible to view some example rollouts when doing so/also have some performance metrics.
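
A minimal sketch of the three sampling modes, assuming we already have an actor logits tensor of shape (n_envs, n_actions); the function name sample_actions and its arguments are illustrative, not the repo's API:

import torch
from torch.distributions import Categorical


def sample_actions(logits: torch.Tensor, mode: str = "basic", temperature: float = 1.0) -> torch.Tensor:
    """Sample actions from actor logits of shape (n_envs, n_actions).

    mode:
        "basic"         - sample proportionally to softmax(logits)
        "deterministic" - always take the argmax action
        "temperature"   - sample from softmax(logits / temperature)
        "bottom_k"      - mask out the maximal logit and sample from the rest
                          (reliably produces anti-good trajectories)
    """
    if mode == "deterministic":
        return logits.argmax(dim=-1)
    if mode == "temperature":
        return Categorical(logits=logits / temperature).sample()
    if mode == "bottom_k":
        masked = logits.clone()
        best = logits.argmax(dim=-1, keepdim=True)
        masked.scatter_(-1, best, float("-inf"))  # exclude the best action
        return Categorical(logits=masked).sample()
    return Categorical(logits=logits).sample()


# Example: 8 envs, 7 MiniGrid actions
logits = torch.randn(8, 7)
print(sample_actions(logits, mode="temperature", temperature=0.5))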

Remove git lfs

While cloning or otherwise working with the repo I saw errors like these ones:

error: external filter 'git-lfs filter-process' failed

I assume it is related to this file:

fatal: trajectories/MiniGrid-Dynamic-Obstacles-8x8-v0bd60729d-dc0b-4294-9110-8d5f672aa82c.pkl:
Error downloading object: trajectories/MiniGrid-Dynamic-Obstacles-8x8-v0bd60729d-dc0b-4294-9110-8d5f672aa82c.pkl (8c06bc4): Smudge error: Error downloading trajectories/MiniGrid-Dynamic-Obstacles-8x8-v0bd60729d-dc0b-4294-9110-8d5f672aa82c.pkl

This error message is seemingly related to Git Large File Storage (LFS). The error message indicates that the git-lfs filter-process command failed. This error can occur if Git LFS is not installed or if the Git LFS filters are not configured correctly.

This solution appears to work for me:

Check if git-lfs is installed and install it if it is not.

git lfs install --skip-smudge
git clone <your-repo-url>
git lfs pull
git lfs install --force

The --skip-smudge option tells Git LFS to skip the smudge filter, which is responsible for downloading the large files. The git clone command clones the repository. The git lfs pull command downloads the large files. Finally, the --force option reinstates the smudge filter.

Maybe it would be helpful to add something on git lfs in the Readme or somewhere else.

Investigate the effect of Dropout / Stochastic Depth on Model training/interpretability

From Gato paper: "Regularization: We train with an AdamW weight decay parameter of 0.1. Additionally, we use stochastic depth (Huang et al., 2016) during pretraining, where each of the transformer sub-layers (i.e. each Multi-Head Attention and Dense Feedforward layer) is skipped with a probability of 0.1."

Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra, and Kilian Weinberger. Deep networks with stochastic depth. Preprint arXiv:1603.09382, 2016.

Stochastic depth seems plausibly super valuable to me via intuitions. I should read that paper at some point - Joseph

Rewrite evaluate_dt_agent to be parallelized (make use of gym sync vector env)

I used SyncVectorEnv to speed up game simulations elsewhere but rushed through writing the eval code and therefore didn't use it there.

Reimplementing evaluate_dt_agent with parallelization would speed up both the decision transformer training script and the calibration evaluation script.

Reasonable steps might look something like:

  • Write a test which ensures that evaluate_dt_agent works
  • Rewrite the function with parallelization (get statistics for each episode correctly, flatten them)
  • See that the test passes
  • Regenerate a calibration curve for models/MiniGrid-Dynamic-Obstacles-8x8-v0/demo_model_one_hot_overnight.pt, which should look (more or less) like the one in the post. A minimal parallel-rollout sketch follows.
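
A minimal sketch of the parallel rollout loop using gymnasium's SyncVectorEnv; the random action sampling stands in for the decision transformer's action selection, and the episode bookkeeping shows the "flatten per-episode statistics" step:

import gymnasium as gym
import minigrid  # noqa: F401  (registers the MiniGrid environments)
import numpy as np

n_envs = 8
envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("MiniGrid-Dynamic-Obstacles-8x8-v0") for _ in range(n_envs)]
)

obs, info = envs.reset(seed=1)
episode_returns, current_returns = [], np.zeros(n_envs)

for _ in range(1000):
    # placeholder: replace with the DT's action sampling given its context window
    actions = envs.action_space.sample()
    obs, rewards, terminated, truncated, info = envs.step(actions)
    current_returns += rewards
    done = terminated | truncated
    for i in np.where(done)[0]:
        episode_returns.append(current_returns[i])  # collect per-episode statistics
        current_returns[i] = 0.0

print(f"episodes finished: {len(episode_returns)}, mean return: {np.mean(episode_returns):.3f}")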

Write a utility for merging sampled rollouts into a single file

We can sample rollouts using sample_from_agents (below). However, it would be good to be able to merge trajectory datasets. In order to do this, it might be worth cleaning up the offline_dataset code and working out the right way to do it. A fast way would be to use torch utils to concat datasets, but this breaks the visualizer we have for the dataset. Possibly not worth doing anything else. A minimal file-merging sketch follows the code below.

def sample_from_agents(agents, rollout_length=2000, trajectory_path=None, num_envs=1):
    all_episode_lengths = []
    all_episode_returns = []
    # Sample rollouts from each agent
    for i, agent in enumerate(agents):
        memory = Memory(agent.envs, OnlineTrainConfig(
            num_envs=num_envs), device=agent.device)
        if trajectory_path:
            trajectory_writer = TrajectoryWriter(
                path=os.path.join(trajectory_path, f"rollouts_agent_{i}.gz"),
                run_config=RunConfig(track=False),
                environment_config=agent.environment_config,
                online_config=OnlineTrainConfig(num_envs=num_envs),
                model_config=agent.model_config
            )
        else:
            trajectory_writer = None
        agent.rollout(memory, rollout_length, agent.envs, trajectory_writer)
        if trajectory_writer:
            trajectory_writer.tag_terminated_trajectories()
            trajectory_writer.write(upload_to_wandb=False)
        # Process the episode lengths and returns
        df = process_memory_vars_to_log(memory.vars_to_log)
        all_episode_lengths.append(df['episode_length'])
        all_episode_returns.append(df['episode_return'])
    return all_episode_lengths, all_episode_returns
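
A minimal sketch of merging trajectory files, assuming the pickle layout produced by TrajectoryWriter.write (a dict with 'data' and 'metadata' keys, where 'data' holds observations/actions/rewards/dones/truncated/infos arrays); the file paths are placeholders:

import gzip
import pickle
import numpy as np


def load_trajectories(path):
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        return pickle.load(f)


def merge_trajectory_files(paths, out_path):
    merged = None
    for path in paths:
        data = load_trajectories(path)["data"]
        if merged is None:
            merged = {k: [v] for k, v in data.items()}
        else:
            for k, v in data.items():
                merged[k].append(v)
    # concatenate each field along the time axis
    data = {k: np.concatenate(v, axis=0) for k, v in merged.items()}
    payload = {"data": data, "metadata": {"merged_from": list(paths)}}
    with gzip.open(out_path, "wb") as f:
        pickle.dump(payload, f)


# usage (paths are placeholders):
# merge_trajectory_files(["rollouts_agent_0.gz", "rollouts_agent_1.gz"], "merged_rollouts.gz")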

Investigate/Possibly Add Gated MLP Units to Transformer Models.

Based on my understanding of GLUs from reading these three papers:

It seems likely that such a modification to our transformers will be valuable. Moreover, Neel told me that they are used in models including T5 (see here: https://huggingface.co/docs/transformers/model_doc/t5, they use gated gelu).

I think this shouldn't be too hard to add as a PR, and I could theoretically accept it myself but could talk to Neel first. At a high level we have:

  • 1. make sure you can specify this in the config reasonably
  • 2. make sure you can implement it without any problems (check it works as expected); have tests.
  • 3. make sure you get hooks for the relevant layers (we will need to add an activation)
  • 4. check if the default d_mlp should change if you do this.

And then a nice optional thing to do is to make a tutorial talking about how to interpret it. I think this would be best done once we have loaded T5 and are doing interp on it. Or we can train a very small model that uses it and use that as a demo. I also want to know what other people think about this. A minimal sketch of a gated MLP block follows.
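
A minimal sketch of what a gated (GEGLU-style) MLP block could look like; this is an assumption about the shape of the change, not TransformerLens's implementation:

import torch
import torch.nn as nn


class GatedMLP(nn.Module):
    """GEGLU-style MLP: out = W_out(gelu(W_gate x) * (W_in x))."""

    def __init__(self, d_model: int = 128, d_mlp: int = 256):
        super().__init__()
        self.W_gate = nn.Linear(d_model, d_mlp, bias=False)
        self.W_in = nn.Linear(d_model, d_mlp, bias=False)
        self.W_out = nn.Linear(d_mlp, d_model, bias=False)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the gate path goes through the nonlinearity and multiplies the linear path
        return self.W_out(self.act(self.W_gate(x)) * self.W_in(x))


mlp = GatedMLP()
print(mlp(torch.randn(2, 3, 128)).shape)  # torch.Size([2, 3, 128])

Note that the GLU-variants paper shrinks the hidden width (to roughly two thirds of the usual value) to keep parameter counts comparable, which is the question about the default d_mlp in point 4.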

Error while sampling from new trajectories generated by LSTM model

https://wandb.ai/arena-ldn/PPO-MiniGrid/artifacts/trajectory/4eb3c096-8836-4d0f-973a-67685b89d0f0.gz/12e1daaebfd9c806051f

python -m src.run_decision_transformer \
    --exp_name MiniGrid-MemoryS7FixedStart-v0 \
    --trajectory_path trajectories/4eb3c096-8836-4d0f-973a-67685b89d0f0.gz \
    --d_model 128 \
    --n_heads 2 \
    --d_mlp 256 \
    --n_layers 1 \
    --learning_rate 0.0001 \
    --batch_size 128 \
    --train_epochs 5 \
    --test_epochs 1 \
    --n_ctx 23 \
    --pct_traj 1 \
    --weight_decay 0.001 \
    --seed 1 \
    --wandb_project_name DecisionTransformerInterpretability \
    --test_frequency 1000 \
    --eval_frequency 1000 \
    --eval_episodes 10 \
    --initial_rtg -1 \
    --initial_rtg 0 \
    --initial_rtg 1 \
    --prob_go_from_end 0.1 \
    --eval_max_time_steps 1000 \
    --track True
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/josephbloom/GithubRepositories/DecisionTransformerInterpretability/src/run_decision_transformer.py", line 55, in <module>
    run_decision_transformer(
  File "/Users/josephbloom/GithubRepositories/DecisionTransformerInterpretability/src/decision_transformer/runner.py", line 114, in run_decision_transformer
    model = train(
  File "/Users/josephbloom/GithubRepositories/DecisionTransformerInterpretability/src/decision_transformer/train.py", line 142, in train
    evaluate_dt_agent(
  File "/Users/josephbloom/GithubRepositories/DecisionTransformerInterpretability/src/decision_transformer/train.py", line 342, in evaluate_dt_agent
    new_obs, new_reward, terminated, truncated, info = env.step(action)
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/site-packages/gymnasium/vector/vector_env.py", line 203, in step
    return self.step_wait()
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/site-packages/gymnasium/vector/sync_vector_env.py", line 149, in step_wait
    ) = env.step(action)
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/site-packages/gymnasium/wrappers/record_video.py", line 155, in step
    ) = self.env.step(action)
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/site-packages/gymnasium/core.py", line 408, in step
    return self.env.step(action)
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/site-packages/gymnasium/wrappers/record_episode_statistics.py", line 89, in step
    ) = self.env.step(action)
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/site-packages/gymnasium/wrappers/order_enforcing.py", line 56, in step
    return self.env.step(action)
  File "/Users/josephbloom/miniforge3/envs/decision_transformer_interpretability/lib/python3.10/site-packages/gymnasium/wrappers/env_checker.py", line 49, in step
    return self.env.step(action)
  File "/Users/josephbloom/GithubRepositories/DecisionTransformerInterpretability/src/environments/memory.py", line 167, in step
    if action == Actions.pickup:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous

Upgrade Collect Demonstrations Workflow

The collect demonstrations utility is responsible for collecting example trajectories from a trained agent (one of the 3 PPO agent architectures supported; only two work). It provides a few different sampling procedures for doing this, such as basic (proportional to softmax), temperature sampling, bottom-k and top-k. These enable us to train decision transformers on a broader distribution of actions/observations, encouraging more robust features to be learned and a better calibrated RTG/behaviour relationship (if you do behavioural cloning, you are restricted to training on only the good trajectories of the offline agent, which doesn't lead to good features).

It's not super clear yet how useful this is, but there are very obvious next steps to improve this utility that seem like good engineering practice if anyone wants to help. The major goals are:

  • Hook up wandb tracking (add an arg for track, then log metrics to the dashboard)
  • Media to log: videos of each of the rollouts (much like the PPO rollout code). We want to see qualitatively what kinds of trajectories we're sampling.
  • Metrics to log: reward/time to finish for different rollout configs; we want to know which config settings sample what kinds of trajectories/outcomes.

This is possibly also an opportunity to improve the code around uploading videos, which is very janky.

Write much better tutorials for the repo.

I've explained how to use the repo in the docs, but this could be way better.

Some examples of pages that might be useful to add to the Sphinx docs:

  1. A page explaining what Minigrid is, and where to look at each MiniGrid environment (maybe include a table with environments as rows and columns for the different models (DT, BC, PPO-traj, standard PPO)).
  2. A page explaining the difference between the online method (PPO) and offline methods, explaining why we need the trajectories for the latter.
  3. A page/diagram explaining how trajectories are generated with the ppo method used to train offline agents.
  4. MOST IMPORTANT: pages for stuff you think I won't have thought to explain since I wrote the repo.

MiniGrid: Finish the Maze Environments for MiniGrid

A couple of weeks ago I worked with Navpreet to start writing a Maze Environment for Minigrid which is mostly working. Finishing this PR and possibly adding a version that uses Kruskal's algorithm and not Prim's could be a really valuable contribution.

The exciting thing about these environments being part of Minigrid is that it would offer opportunities to study algorithmic distillation and precise, memory-based reasoning in transformers. I think doing this could be the start of a lot of cool work.

Find a third party service for hosting the streamlit app and making it way faster to use

I currently host a public version of the interpretability app with the Streamlit Community Cloud, which is very slow.

I would be willing to pay for a service to host the app, but I need to know that deployment will be very easy.

Checklist:

Bug: Trajectory Dataset contains pre-emptively truncated trajectories from where PPO gets cut off

This is a bug in the TrajectoryWriter/offline dataset where we end up truncating some trajectories when we finish online training, and this leads to having "short" truncated trajectories, which are bad for our data. It would be good to remove them. They are visible in the visualization of reward over trajectory length as spots on the x-axis that are not at max length.

A link to the method I use to ensure that these get labelled as truncated to avoid bugs:

def tag_terminated_trajectories(self):


Explore Improvements to DT Training Procedure

Just wanted to have a meta card to track progress on these things, with links:

  • LayerNorm (I'll probably only try layernorm pre) (#52)
  • AdamW Optimizer
  • Adding a warmup stage with LambdaLR scheduler or cosine annealing
  • Implement gated MLP's (https://arxiv.org/pdf/2002.05202.pdf). Might need to be done in TransformerLens.
  • Make it possible to use GeLU not ReLU (try that out as well).
  • Better encode state. #61
  • Look into current init ranges for all the model components and consider proper init ranges
  • Look into where all the parameters are and consider how we can make a sparser model
  • Implement wandb sweeps for DT training (likely already exists a card for this so I should find it)
  • Implement masking rather than just having different tokens during padding. Might be important?

If we've implemented all of those and still no success with the memory env training, possibly try either much longer training runs, more variable sampling methods, or ask for advice (or go bug hunting).

Vectorize get trajectory minibatches method of memory class (useful for TrajPPO model)

I recently wrote a version of get_minibatches in the memory class of the ppo subpackage.

def get_trajectory_minibatches(self, timesteps: int, prob_go_from_end: float = 0.1) -> List[TrajectoryMinibatch]:
    '''Return a list of trajectory minibatches, where each minibatch contains
    experiences from a single trajectory.

    Args:
        - timesteps (int): the number of timesteps to include in each minibatch.

    Returns:
        - List[TrajectoryMinibatch]: a list of minibatches.
    '''
    obs, dones, actions, logprobs, values, rewards = [
        t.stack(arr) for arr in zip(*self.experiences)]

    next_values = t.cat([values[1:], self.next_value.unsqueeze(0)])
    next_dones = t.cat([dones[1:], self.next_done.unsqueeze(0)])
    # px.imshow(obs[:,1,:,:,0].transpose(-1,-2), animation_frame = 0, range_color = [0,10]).show()

    # set last value of dones to 1
    dones[-1] = t.ones(dones.shape[-1])

    # hack for now.
    # will cause problems if you only have one environment
    if logprobs.shape[-1] == 1:
        logprobs = logprobs.squeeze(-1)

    # rearrange to flatten out the env dimension (2nd dimension)
    obs = rearrange(obs, "T E ... -> (E T) ...")
    dones = rearrange(dones, "T E -> (E T)")
    next_dones = rearrange(next_dones, "T E -> (E T)")
    actions = rearrange(actions, "T E ... -> (E T) ...")
    logprobs = rearrange(logprobs, "T E -> (E T)")
    values = rearrange(values, "T E -> (E T)")
    next_values = rearrange(next_values, "T E -> (E T)")
    rewards = rearrange(rewards, "T E -> (E T)")

    # find the indices of the end of each trajectory
    traj_end_idxs = (t.where(dones)[0] + 1).tolist()

    # split these trajectories on the dones
    traj_obs = t.tensor_split(obs, traj_end_idxs)
    traj_actions = t.tensor_split(actions, traj_end_idxs)
    traj_logprobs = t.tensor_split(logprobs, traj_end_idxs)
    traj_values = t.tensor_split(values, traj_end_idxs)
    traj_rewards = t.tensor_split(rewards, traj_end_idxs)
    traj_dones = t.tensor_split(dones, traj_end_idxs)
    traj_next_values = t.tensor_split(next_values, traj_end_idxs)
    traj_next_dones = t.tensor_split(next_dones, traj_end_idxs)
    # px.imshow(traj_obs[0][:,:,:,0].transpose(-1,-2), animation_frame = 0, range_color = [0,10]).show()

    # so now we have lists of trajectories, what we want is to split each trajectory
    # so for each trajectory, sample an index and go n_steps back from that.
    # since we're encoding states and actions, we want to go context_length//2 back
    # if that happens to go off the end, then we
    minibatches = []

    # remove trajectories of length 0
    traj_obs = [traj for traj in traj_obs if len(traj) > 0]
    n_trajectories = len(traj_obs)
    trajectory_lengths = [len(traj) for traj in traj_obs]

    for _ in range(self.args.num_minibatches):
        minibatch_obs = []
        minibatch_actions = []
        minibatch_logprobs = []
        minibatch_advantages = []
        minibatch_values = []
        minibatch_returns = []
        minibatch_timesteps = []
        minibatch_rewards = []

        for _ in range(self.args.minibatch_size):
            # randomly select a trajectory
            traj_idx = np.random.randint(n_trajectories)
            # randomly select an end index from the trajectory
            # TODO later add a hyperparameter to oversample last step
            traj_len = trajectory_lengths[traj_idx]
            if traj_len <= timesteps:
                end_idx = traj_len
                start_idx = 0
            else:
                if prob_go_from_end is not None:
                    if random.random() < prob_go_from_end:
                        end_idx = traj_len
                        start_idx = end_idx - timesteps
                    else:
                        end_idx = np.random.randint(timesteps, traj_len)
                        start_idx = end_idx - timesteps
                else:
                    end_idx = np.random.randint(timesteps, traj_len)
                    start_idx = end_idx - timesteps

            # get the trajectory
            current_traj_obs = traj_obs[traj_idx][start_idx:end_idx]
            current_traj_actions = traj_actions[traj_idx][start_idx:end_idx]
            current_traj_logprobs = traj_logprobs[traj_idx][start_idx:end_idx]
            current_traj_values = traj_values[traj_idx][start_idx:end_idx]
            current_traj_dones = traj_dones[traj_idx][start_idx:end_idx]
            current_traj_rewards = traj_rewards[traj_idx][start_idx:end_idx]
            current_traj_next_value = traj_next_values[traj_idx][end_idx - 1]
            current_traj_next_done = traj_next_dones[traj_idx][end_idx - 1]

            # make timesteps
            current_traj_timesteps = t.arange(start_idx, end_idx)

            # Compute the advantages and returns for this trajectory.
            current_traj_advantages = self.compute_advantages(
                current_traj_next_value,
                current_traj_next_done,
                current_traj_rewards,
                current_traj_values,
                current_traj_dones,
                self.device,
                self.args.gamma,
                self.args.gae_lambda
            )
            current_traj_returns = current_traj_advantages + current_traj_values

            # we need to pad current_traj_obs and current_traj_actions
            current_traj_obs = pad_tensor(
                current_traj_obs,
                timesteps,
                ignore_first_dim=False,
                pad_token=0,
                pad_left=True
            )
            current_traj_actions = pad_tensor(
                current_traj_actions,
                timesteps,
                ignore_first_dim=False,
                pad_token=0,
                pad_left=True
            )
            current_traj_timesteps = pad_tensor(
                current_traj_timesteps,
                timesteps,
                ignore_first_dim=False,
                pad_token=0,
                pad_left=True
            )

            # add to minibatch
            minibatch_obs.append(current_traj_obs)
            minibatch_actions.append(current_traj_actions)
            minibatch_logprobs.append(current_traj_logprobs[-1])
            minibatch_advantages.append(current_traj_advantages[-1])
            minibatch_values.append(current_traj_values[-1])
            minibatch_returns.append(current_traj_returns[-1])
            minibatch_rewards.append(current_traj_rewards[-1])
            minibatch_timesteps.append(current_traj_timesteps)

        # stack the minibatch
        minibatch_obs = t.stack(minibatch_obs)
        minibatch_actions = t.stack(minibatch_actions)
        # only take the last values of the logprob, advantage,
        # value and return (relevant to the last step of each trajectory)
        minibatch_logprobs = t.stack(minibatch_logprobs)
        minibatch_advantages = t.stack(minibatch_advantages)
        minibatch_values = t.stack(minibatch_values)
        minibatch_returns = t.stack(minibatch_returns)
        minibatch_timesteps = t.stack(minibatch_timesteps)
        minibatch_rewards = t.stack(minibatch_rewards)

        minibatches.append(TrajectoryMinibatch(
            obs=minibatch_obs,
            actions=minibatch_actions,
            logprobs=minibatch_logprobs,
            advantages=minibatch_advantages,
            values=minibatch_values,
            returns=minibatch_returns,
            timesteps=minibatch_timesteps,
            rewards=minibatch_rewards
        ))

    return minibatches

TL;DR: This is important for sampling sections of trajectories, which is necessary for online training of trajectory models as opposed to models which only respond to the latest observation. I have a few ideas for what to do here:

Keep the logic more or less the same, but vectorize it. It's way too serialized and it doesn't have to be. Obviously write lots of tests.

Create an Object-Vector Calculator in the Streamlit App

Similar to the Maze experiments done by the shard theory team, calculate a vector corresponding to the key/ball in the memory env (or make a more general tool) that can be used in the app to analyse object configured behaviors.

PM me if you are interested in working on this.

Set padded RTG in training data to be true RTG until masking is implemented correctly.

We aren't currently masking padding tokens, and I'm not sure how big a deal that is. I'll reach out to ask someone for feedback, but in the meantime, I'm going to set the RTG in the padded training data to the true RTG, because the padding RTG is 0, which I think might have confused the model.

This is hopefully going to be the difference between a memory-task-solving DT completely solving the task and its current status of solving it but not using RTG to modulate behavior for the balls (strangely, it works if the key is the goal).

  • update offline training data to pad with the true RTG (a minimal illustration follows this list)
  • message someone about how essential padding is (I'm not sure why it isn't already implemented on the HookedTransformer class; maybe it is and I missed it)
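
A minimal illustration of the proposed change for a left-padded RTG sequence; the tensors are toy values, not the repo's dataset code:

import torch

rtg = torch.tensor([0.955, 0.80, 0.55])   # true RTG for the observed timesteps
context_len = 6

# current behaviour: pad on the left with 0, which the model may read as "RTG = 0"
padded_zero = torch.cat([torch.zeros(context_len - len(rtg)), rtg])

# proposed behaviour: pad with the true (initial) RTG instead
padded_true = torch.cat([torch.full((context_len - len(rtg),), rtg[0].item()), rtg])

print(padded_zero)  # tensor([0.0000, 0.0000, 0.0000, 0.9550, 0.8000, 0.5500])
print(padded_true)  # tensor([0.9550, 0.9550, 0.9550, 0.9550, 0.8000, 0.5500])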

Add lr scheduling options to the DT training code

As part of #53, I'm going to explore some techniques which may prove useful for improving the performance of DTs / BCs.

  • Add a wandb graph showing the lr of the scheduler over time (this will ensure that we can validate the methods we use)
  • Then add an LR scheduler utility (I think we'll just make use of the huggingface transformers utilities?); a minimal warmup sketch follows this list
  • Validate that each works and see if any affect convergence speed on Dynamic Obstacles.
  • Can try out on the memory env if desired.
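
A minimal sketch of a linear-warmup LambdaLR schedule with the lr logged so the wandb graph can validate it; the model, optimizer settings and warmup length are placeholders, not the repo's training loop:

import torch
import wandb  # only needed for the logging line below

model = torch.nn.Linear(10, 10)  # placeholder for the decision transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

warmup_steps = 500


def lr_lambda(step: int) -> float:
    # linear warmup to the base lr, then constant
    return min(1.0, (step + 1) / warmup_steps)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# inside the training loop:
for step in range(2000):
    optimizer.step()   # after loss.backward() in the real loop
    scheduler.step()
    # wandb.log({"learning_rate": scheduler.get_last_lr()[0]}, step=step)

Cosine annealing could be swapped in via torch.optim.lr_scheduler.CosineAnnealingLR with the same logging pattern.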

Add checkpoints during Offline Training

Our PPO models are now stored with checkpoints but our offline trained models aren't. Creating some parity here would be good.

Please ensure:

  • It remains easy to load a model into the app
  • All of the required parameters to instantiate a model are saved with that model
  • Nice utilities facilitate saving/loading of checkpoints as is the case with the PPO models
  • Don't duplicate anything you don't have to with the PPO checkpoint set up.

Minigrid: Submit MultiEnv Environment PR so we don't need to maintain this code in DTI.

I wrote a class that mixes environments/hyperparameters, called a MultiEnv, which is kinda useful for training; however, it should probably live in the Minigrid GitHub repo.

If someone has the time to submit that as a PR and write whatever tests their maintainers require (probably not a huge amount), then this would be greatly appreciated.

Checklist:
[ ] make a PR to Minigrid with the Multienv
[ ] add tests to the PR
[ ] have the PR accepted
[ ] when the PR is accepted, update this repo to use the new minigrid version and remove the multienv code from this repo.

Enhancement: Beautify the Streamlit App

It's now possible to allow columns (up to one level of nesting) inside columns in streamlit apps thanks to streamlit/streamlit#5941. I think this creates an opportunity to make the streamlit interpretability app way more visually appealing. I don't have strong opinions now but basically suspect someone could have a lot of fun with this and it'll be really enjoyable later.

For reference, the current streamlit app uses a portrait view and a sidebar for selecting which analyses appear where. I think it'd be interesting to redesign this. Happy to do a "user interview" to describe how I use stuff and what I don't like about how I've currently implemented it.

Debug TransformerPPO model

I failed to get the TransformerPPO model working. I suspect this is because of bugs, but it could also be that transformers are inherently unstable.

If someone is interested in attempting to fix my implementation then I would greatly appreciate it. For the time being, I'm going to pivot to a recurrent architecture to get progress on the project going again.

Start here:

class TrajPPOAgent(PPOAgent):
    def __init__(self,
                 envs: gym.vector.SyncVectorEnv,
                 environment_config: EnvironmentConfig,
                 transformer_model_config: TransformerModelConfig,
                 device: t.device = t.device("cpu")
                 ):
        '''
        An agent for a Proximal Policy Optimization (PPO) algorithm.

        Args:
            - envs (gym.vector.SyncVectorEnv): the environment(s) to interact with.
            - device (t.device): the device on which to run the agent.
            - environment_config (EnvironmentConfig): the configuration for the environment.
            - transformer_model_config (TransformerModelConfig): the configuration for the transformer model.
            - device (t.device): the device on which to run the agent.
        '''
        super().__init__(envs=envs, device=device)
        self.environment_config = environment_config
        self.transformer_model_config = transformer_model_config

        self.obs_shape = get_obs_shape(envs.single_observation_space)
        self.num_obs = np.array(self.obs_shape).prod()
        self.num_actions = envs.single_action_space.n
        self.hidden_dim = transformer_model_config.d_model

        self.critic = CriticTransfomer(
            transformer_config=transformer_model_config,
            environment_config=environment_config,
        )
        self.layer_init(self.critic.value_predictor, std=0.01)

        self.actor = ActorTransformer(
            transformer_config=transformer_model_config,
            environment_config=environment_config,
        )
        self.layer_init(self.actor.action_predictor, std=0.01)

        self.device = device
        self.to(device)

    def rollout(self,
                memory: Memory,
                num_steps: int,
                envs: gym.vector.SyncVectorEnv,
                trajectory_writer=None) -> None:
        """Performs the rollout phase of the PPO algorithm, collecting experience by interacting with the environment.

        Args:
            memory (Memory): The replay buffer to store the experiences.
            num_steps (int): The number of steps to collect.
            envs (gym.vector.SyncVectorEnv): The vectorized environment to interact with.
            trajectory_writer (TrajectoryWriter, optional): The writer to
                log the collected trajectories. Defaults to None.
        """
        device = memory.device
        obs = memory.next_obs
        action = None  # will be set before used
        done = memory.next_done
        truncated = memory.next_done  # mem done represents done | truncated

        context_window_size = self.actor.transformer_config.n_ctx
        obs_timesteps = (context_window_size - 1) // 2 + 1  # (the current obs)
        actions_timesteps = obs_timesteps - 1
        action_pad_token = self.actor.environment_config.action_space.n
        n_envs = envs.num_envs

        if isinstance(device, str):
            device = t.device(device)
        cuda = device.type == "cuda"

        obss = t.zeros((n_envs, obs_timesteps, *obs.shape[1:]), device=device)
        acts = t.ones((n_envs, actions_timesteps, 1),
                      device=device).to(t.long) * action_pad_token
        timesteps = t.zeros((n_envs, obs_timesteps, 1),
                            device=device).to(t.long)
        obss[:, -1] = obs

        for step in range(num_steps):
            if len(memory.experiences) == 0:
                with t.inference_mode():
                    logits = self.actor(obss[:, -1:], None, timesteps[:, -1:])
                    values = self.critic(obss[:, -1:], None, timesteps[:, -1:])
                    value = values[:, -1].squeeze(-1)  # value is scalar
            else:
                # temporarily making this code worse, refactor soon.
                if obs_timesteps - 1 == 0:
                    obss = obs.unsqueeze(1)  # just add the current obs
                    acts = None
                else:
                    # obss
                    obss = t.cat((obss, obs.unsqueeze(1)),
                                 dim=1)  # add current obs
                    obss = obss[:, -obs_timesteps:]  # truncate
                    # acts
                    # add current action
                    acts = t.cat(
                        (acts, action.unsqueeze(1).unsqueeze(-1)), dim=1)
                    acts = acts[:, -actions_timesteps:]  # truncate
                    # timesteps
                    # add current timestep
                    timesteps = t.cat(
                        (timesteps, timesteps[:, -1:] + 1), dim=1)
                    if timesteps.max() > self.environment_config.max_steps:
                        assert False
                    timesteps = timesteps[:, -obs_timesteps:]  # truncate

                # Generate the next set of new experiences (one for each env)
                with t.inference_mode():
                    # Our actor generates logits over actions which we can then sample from
                    logits = self.actor(obss, acts, timesteps)
                    # Our critic generates a value function (which we use in the value loss, and to estimate advantages)
                    values = self.critic(obss, acts, timesteps)
                    values = values[:, -1].squeeze(-1)  # value is scalar

            # get the last state action prediction
            probs = Categorical(logits=logits[:, -1])
            action = probs.sample()
            logprob = probs.log_prob(action)

            next_obs, reward, next_done, next_truncated, info = envs.step(
                action.cpu().numpy())
            next_obs = memory.obs_preprocessor(next_obs)
            reward = t.from_numpy(reward).to(device)

            # in each case where an episode is done, we need to reset the context window
            # this is done by setting the last obs to the current obs and the rest to 0
            # all the actions are set to zero
            # timesteps are also reset
            next_done_or_truncated = next_done | next_truncated
            for i, d in enumerate(next_done_or_truncated):
                if d:
                    obss[i, -1] = obs[i]
                    obss[i, :-1] = 0
                    if acts is not None:
                        acts[i] = action_pad_token
                    timesteps[i] = 0

            if trajectory_writer is not None:
                obs_np = obs.detach().cpu().numpy() if cuda else obs.detach().numpy()
                reward_np = reward.detach().cpu().numpy() if cuda else reward.detach().numpy()
                action_np = action.detach().cpu().numpy() if cuda else action.detach().numpy()
                trajectory_writer.accumulate_trajectory(
                    next_obs=obs_np,
                    reward=reward_np,
                    action=action_np,
                    done=next_done,
                    truncated=next_truncated,
                    info=info
                )

            # Store (s_t, d_t, a_t, logpi(a_t|s_t), v(s_t), r_t+1)
            mem_done = (done.to(bool) | truncated.to(bool)).to(float)
            memory.add(info, obs, mem_done, action, logprob, value, reward)

            obs = t.from_numpy(next_obs).to(device)
            done = t.from_numpy(next_done).to(device, dtype=t.float)
            truncated = t.from_numpy(next_truncated).to(device, dtype=t.float)

        # Store last (obs, done, value) tuple, since we need it to compute advantages
        memory.next_obs = obs
        memory.next_done = done
        with t.inference_mode():
            obss = t.cat((obss, obs.unsqueeze(1)), dim=1)
            acts = t.cat((acts, action.unsqueeze(1).unsqueeze(-1)),
                         dim=1) if acts is not None else None
            obss = obss[:, -obs_timesteps:]
            actions = acts[:, -actions_timesteps:] if acts is not None else None
            timesteps = timesteps[:, -obs_timesteps:]
            values = self.critic(obss, actions, timesteps)
            memory.next_value = values[:, -1].squeeze(-1)

    def learn(self,
              memory: Memory,
              args: OnlineTrainConfig,
              optimizer: optim.Optimizer,
              scheduler: PPOScheduler,
              track: bool) -> None:
        """Performs the learning phase of the PPO algorithm, updating the agent's parameters
        using the collected experience.

        Args:
            memory (Memory): The replay buffer containing the collected experiences.
            args (OnlineTrainConfig): The configuration for the training.
            optimizer (optim.Optimizer): The optimizer to update the agent's parameters.
            scheduler (PPOScheduler): The scheduler attached to the optimizer.
            track (bool): Whether to track the training progress.
        """
        for _ in range(args.update_epochs):
            n_timesteps = (self.actor.transformer_config.n_ctx - 1) // 2 + 1
            minibatches = memory.get_trajectory_minibatches(
                n_timesteps, args.prob_go_from_end)

            # Compute loss on each minibatch, and step the optimizer
            for mb in minibatches:
                obs = mb.obs
                actions = mb.actions[:, :-1].unsqueeze(-1).to(
                    int) if mb.obs.shape[1] > 1 else None
                timesteps = mb.timesteps.unsqueeze(-1).to(int)

                logits = self.actor(obs, actions, timesteps)
                values = self.critic(obs, actions, timesteps)
                values = values[:, -1].squeeze(-1)
                probs = Categorical(logits=logits[:, -1])

                clipped_surrogate_objective = calc_clipped_surrogate_objective(
                    probs=probs,
                    mb_action=mb.actions[:, -1].squeeze(-1),
                    mb_advantages=mb.advantages,
                    mb_logprobs=mb.logprobs,
                    clip_coef=args.clip_coef)
                value_loss = calc_value_function_loss(
                    values, mb.returns, args.vf_coef)
                entropy_bonus = calc_entropy_bonus(probs, args.ent_coef)

                total_objective_function = clipped_surrogate_objective - value_loss + entropy_bonus

                optimizer.zero_grad()
                total_objective_function.backward()
                nn.utils.clip_grad_norm_(self.parameters(), args.max_grad_norm)
                optimizer.step()

        # Step the scheduler
        scheduler.step()

        # Get debug variables, for just the most recent minibatch (otherwise there's too much logging!)
        if track:
            with t.inference_mode():
                newlogprob = probs.log_prob(mb.actions.unsqueeze(-1))
                logratio = newlogprob - mb.logprobs
                ratio = logratio.exp()
                approx_kl = (ratio - 1 - logratio).mean().item()
                clipfracs = [
                    ((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

            memory.add_vars_to_log(
                learning_rate=optimizer.param_groups[0]["lr"],
                avg_value=values.mean().item(),
                value_loss=value_loss.item(),
                clipped_surrogate_objective=clipped_surrogate_objective.item(),
                entropy=entropy_bonus.item(),
                approx_kl=approx_kl,
                clipfrac=np.mean(clipfracs)
            )

Big tasks are:

  • Get the Traj PPO agent to pass the memory test.
  • passing some of the skipped acceptance tests might be useful feedback on the way.

Store Model Checkpoints during PPO Training

In order to have enough training data to give decision transformers a good distribution over good/bad trajectories, it would be good to be able to sample from agents of arbitrary quality (which means we need to store these agents).

As an initial step, storing model checkpoints during training of PPO agents would be useful. Storing these on wandb seems like a good first step (there are examples in the code of uploading artifacts, e.g. the trajectory writer below; a minimal checkpoint-saving sketch follows it).

if trajectory_writer is not None:
    trajectory_writer.tag_terminated_trajectories()
    trajectory_writer.write(upload_to_wandb=run_config.track)

def write(self, upload_to_wandb: bool = False):
    data = {
        'observations': np.array(self.observations, dtype=np.float64),
        'actions': np.array(self.actions, dtype=np.int64),
        'rewards': np.array(self.rewards, dtype=np.float64),
        'dones': np.array(self.dones, dtype=bool),
        'truncated': np.array(self.truncated, dtype=bool),
        'infos': np.array(self.infos, dtype=object)
    }

    if dataclasses.is_dataclass(self.args):
        metadata = {
            "args": asdict(self.args),  # Args such as ppo args
            "time": time.time()  # Time of writing
        }
    else:
        metadata = {
            "args": self.args,  # Args such as ppo args
            "time": time.time()  # Time of writing
        }

    if not os.path.exists(os.path.dirname(self.path)):
        os.makedirs(os.path.dirname(self.path))

    # use lzma to compress the file
    if self.path.endswith(".xz"):
        print(f"Writing to {self.path}, using lzma compression")
        with lzma.open(self.path, 'wb') as f:
            pickle.dump({
                'data': data,
                'metadata': metadata
            }, f)
    elif self.path.endswith(".gz"):
        print(f"Writing to {self.path}, using gzip compression")
        with gzip.open(self.path, 'wb') as f:
            pickle.dump({
                'data': data,
                'metadata': metadata
            }, f)
    else:
        print(f"Writing to {self.path}")
        with open(self.path, 'wb') as f:
            pickle.dump({
                'data': data,
                'metadata': metadata
            }, f)

    if upload_to_wandb:
        artifact = wandb.Artifact(
            self.path.split("/")[-1], type="trajectory")
        artifact.add_file(self.path)
        wandb.log_artifact(artifact)

    print(f"Trajectory written to {self.path}")

Mega Card: Improve Analysis App in various ways to facilitate better interpretability analysis of the new models

Analysis features

Static

Composition

  • Make composition maps
  • Replace composition scores with strip plots?
  • Create a meta-composition score. Something that measures total influence?
  • How do we check for composition between MLP_in and W_out? (seems expensive?, maybe tie to very specific hypotheses)

Dynamic

Logit Lens

  • By Layer
  • By Layer accumulated
  • By Head

Attention Maps:

  • Make it easier to export a nice visualization of the attention map (cv is actually not great for that).
  • Make it possible to calculate the rank-k approximation to the attention map (a minimal SVD sketch follows this list).
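
A minimal sketch of a rank-k approximation of an attention pattern via SVD; the attention matrix here is random, standing in for a cached pattern:

import torch

attn = torch.softmax(torch.randn(12, 12), dim=-1)  # placeholder attention pattern (query x key)

k = 2
U, S, Vh = torch.linalg.svd(attn)
attn_rank_k = U[:, :k] @ torch.diag(S[:k]) @ Vh[:k, :]

# how much of the pattern the first k singular directions capture
explained = (S[:k] ** 2).sum() / (S ** 2).sum()
print(f"rank-{k} reconstruction error: {(attn - attn_rank_k).norm():.4f}, variance explained: {explained:.2%}")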

Causal

Activation Patching (features)

  • Set up component
  • Set up RTG Metric
  • Residual stream patching.
  • Patching via Attn and MLP
  • Head All Pos Patching
  • Head Specific Pos Patching (do later)
  • Head All Pos by Component
  • MLP at different Positions
  • Show counterfactual attention map (ie: show difference in attention given intervention)
  • Show what the logit diff is for each metric score.

Activation Patching (token variations):
  • Action (fairly easy)
  • Key/Ball (important!)
  • Timestep (also fairly easy)

RTG Scan

  • Switch to using t-lens for decomp
  • Provide more than one level of decomp
  • Add a clustergram to show heads which mediate a similar relationship between RTG and logits/logit diff

Congruence -> If features aren't in superposition, what effect do they have on the predictions?

  • Pos
  • Time
  • W_in
  • W_Out
  • MLP Out

Renew old features:

  • QK circuit visualizations for action and RTG embeddings

SVD Decomp / Explore ways to use dimensionality reduction to quickly understand what heads are doing.

Cache Characterization?

  • Plot L2 norm of residual streams (along with mean and std)

Advanced

Implement Path Patching

  • Understand Callum's code.

Implement AVEC

  • Reread the post to see if we can find it.

Several things I feel are missing which are required for exploratory analysis to be more complete:

  • visualise dot product of time embeddings with each other
  • visualise dot product of positional embeddings with each other (a minimal sketch for these follows this list)
  • Use Jay's head type analysis but write specific patterns for attending to RTG, attending to positive RTG, attending to states, and attending to actions.
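
A minimal sketch of the embedding dot-product visualisation; the random matrix stands in for the model's learned time (or positional) embeddings:

import torch
import plotly.express as px

time_embedding = torch.randn(100, 128)  # placeholder for the learned time embeddings

# pairwise cosine similarity (drop the normalisation to get raw dot products instead)
normed = torch.nn.functional.normalize(time_embedding, dim=-1)
cosine = normed @ normed.T

px.imshow(cosine.numpy(), title="Cosine similarity between time embeddings").show()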

Several things I feel will be required for falsifying predictions of how the model is working:

  • implement a variant of path patching for DTs either in a notebook or as part of the app.
  • CaSc, not sure how feasible this is but it has always been the goal.

Improvement: Rewrite make_env to take an environment config file

I think the make_env function should take in the environment_config class we're using everywhere.

A link to the make_env function:

def make_env(
    env_id: str,
    seed: int,
    idx: int,
    capture_video: bool,
    run_name: str,
    render_mode="rgb_array",
    max_steps=100,
    fully_observed=False,
    flat_one_hot=False,
    agent_view_size=7,
    video_frequency=50
):

A link to the config class it should take:

@dataclass
class EnvironmentConfig():
    '''
    Configuration class for the environment.
    '''
    env_id: str = 'MiniGrid-Dynamic-Obstacles-8x8-v0'
    one_hot_obs: bool = False
    img_obs: bool = False
    fully_observed: bool = False
    max_steps: int = 1000
    seed: int = 1
    view_size: int = 7
    capture_video: bool = False
    video_dir: str = 'videos'
    render_mode: str = 'rgb_array'
    action_space: None = None
    observation_space: None = None
    device: str = 'cpu'

    def __post_init__(self):
        env = gym.make(self.env_id)

        if self.env_id.startswith('MiniGrid'):
            if self.fully_observed:
                env = FullyObsWrapper(env)
            elif self.one_hot_obs:
                env = OneHotPartialObsWrapper(env)
            elif self.img_obs:
                env = RGBImgPartialObsWrapper(env)

        if self.view_size != 7:
            env = ViewSizeWrapper(env, self.view_size)

        self.action_space = self.action_space or env.action_space
        self.observation_space = self.observation_space or env.observation_space

Be sure that the tests all pass! I think this will involve touching a lot of different pieces of code, but it's still conceptually very simple.

Implement IBAC/SNI and measure the effect on model interpretability

Selective Noise Injection (SNI) and Information Bottleneck Actor-Critic (IBAC) make models better at generalising (including in at least one MiniGrid environment). It seems like a fun hack-dayish kind of effort to test this out.

This would currently be bottlenecked by the TrajectoryPPO class possibly not working perfectly yet, but I'm putting this here in case someone is ambitious and wants to give it a shot.

References:

I consider this somewhat similar spiritually to Engineering Monosemanticity in Toy Models although I have no idea if this is exactly true.

Train a model using layer norm pre to see if this helps formation of calibrated, performant memory env agents.

Reading papers suggests this should help. I'm going to try it. It's already implemented except maybe in the app so I will:

  • Set layer-norm to LNPre
  • Train a few different models with this config (the memory env and maybe one of our original models).
  • Observe any differences in training metrics.
  • If the model turns out to be more performant, ensure app compatibility now; otherwise make a card for it.
