jbloomAus / DecisionTransformerInterpretability

Interpreting how transformers simulate agents performing RL tasks
https://jbloomaus-decisiontransformerinterpretability-app-4edcnc.streamlit.app/
MIT License

Decision Transformer Interpretability

Docs Here App Here

This project is designed to facilitate mechanistic interpretability of decision transformers as well as RL agents using transformer architectures.

This is achieved by:

Future work will include:

Write Up

You can find an initial technical report for this project here.

Package Overview

The package contains several important components:

Other notable files/folders:

Example Results

We've successfully trained a decision transformer on several games including DoorKey and Dynamic Obstacles.

[Figures: calibration plot; MiniGrid-Dynamic-Obstacles-8x8-v0 after 6000 batches, episode length 14, RTG 1.0, reward 0.955]

I highly recommend playing with the Streamlit app if you are interested in this project. It relies heavily on an understanding of A Mathematical Framework for Transformer Circuits.

Running the scripts

Example bash scripts are provided in the scripts folder. They make use of argparse interfaces in the package.

Training a PPO agent

If you set 'track' to true, a Weights & Biases dashboard will be generated. A trajectories pickle file will be generated in the trajectories folder. This file can be used to train a decision transformer.

python -m src.run_ppo --exp_name "Test" \
    --seed 1 \
    --cuda \
    --track \
    --wandb_project_name "PPO-MiniGrid" \
    --env_id "MiniGrid-DoorKey-8x8-v0" \
    --view_size 5 \
    --total_timesteps 350000 \
    --learning_rate 0.00025 \
    --num_envs 8 \
    --num_steps 128 \
    --num_minibatches 4 \
    --update_epochs 4 \
    --clip_coef 0.2 \
    --ent_coef 0.01 \
    --vf_coef 0.5 \
    --max_steps 1000 \
    --one_hot_obs
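
Before training on a trajectories file, it can help to check what it contains. The layout of the pickle is not documented in this README, so the following is only a minimal sketch that loads a pickle and reports its top-level structure. The helper name summarize_pickle is hypothetical and not part of the package, and unpickling the real file may require the same packages (e.g. numpy) that were installed when it was written.

```python
import pickle
from pathlib import Path


def summarize_pickle(path):
    """Load a pickle file and describe its top-level structure.

    Hypothetical helper for inspection only; it assumes nothing about
    the layout the training script expects.
    """
    with Path(path).open("rb") as f:
        obj = pickle.load(f)  # only unpickle files you trust
    if isinstance(obj, dict):
        # Map each top-level key to the type name of its value.
        return {key: type(value).__name__ for key, value in obj.items()}
    # Non-dict payloads: report the type of the root object.
    return {"<root>": type(obj).__name__}
```

Calling summarize_pickle on a file in the trajectories folder returns a small dictionary you can print to see which keys are present.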

Training a decision transformer

To train a decision transformer, point the training script at the trajectories file and set the model architecture and hyperparameters.

python -m src.run_decision_transformer \
    --exp_name MiniGrid-Dynamic-Obstacles-8x8-v0-Refactor \
    --trajectory_path trajectories/MiniGrid-Dynamic-Obstacles-8x8-v0bd60729d-dc0b-4294-9110-8d5f672aa82c.pkl \
    --d_model 128 \
    --n_heads 2 \
    --d_mlp 256 \
    --n_layers 1 \
    --learning_rate 0.0001 \
    --batch_size 128 \
    --train_epochs 5000 \
    --test_epochs 10 \
    --n_ctx 3 \
    --pct_traj 1 \
    --weight_decay 0.001 \
    --seed 1 \
    --wandb_project_name DecisionTransformerInterpretability-Dev \
    --test_frequency 1000 \
    --eval_frequency 1000 \
    --eval_episodes 10 \
    --initial_rtg -1 \
    --initial_rtg 0 \
    --initial_rtg 1 \
    --prob_go_from_end 0.1 \
    --eval_max_time_steps 1000 \
    --track True
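
The --initial_rtg values above refer to return-to-go (RTG), the quantity a decision transformer is conditioned on. As a rough illustration of the general technique (not necessarily how the code in src computes it), the RTG at step t is the sum of rewards from t to the end of the episode. The function name and the gamma parameter below are assumptions made for this sketch.

```python
def returns_to_go(rewards, gamma=1.0):
    """Compute the return-to-go for every step of one episode.

    rtg[t] = rewards[t] + gamma * rtg[t + 1], with rtg past the end = 0.
    gamma=1.0 gives the plain undiscounted sum of future rewards.
    """
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step reuses the sum already computed.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg
```

For a sparse-reward episode where the only reward arrives on the final step, every step has the same RTG as that final reward, which is why a single target value such as 1 or 0 is a natural thing to condition on.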

If you want the training data from the blog post, you can download it with gdown:

cd trajectories
gdown 1UBMuhRrM3aYDdHeJBFdTn1RzXDrCL_sr

Running the Streamlit app

To run the Streamlit app:

streamlit run app.py

To run the Streamlit app on Docker, see the Development section.

Setting up the environment

I haven't been too careful about this yet. The project uses Python 3.9.15 with the requirements.txt file, the V2 branch of TransformerLens, and Minigrid 2.1.0.

conda create --name decision_transformer_interpretability python=3.9.15
conda activate decision_transformer_interpretability
pip install -r requirements.txt

The Dockerfile should work. We can make more use of it once the project is further along, or if we are alternating developers frequently and see differing behavior between machines.

./scripts/build_docker.sh
./scripts/run_docker.sh

You can then ssh into the container; a good IDE will bring your credentials etc. along.

Development

Docker

If you're having trouble making the environment work, I recommend Docker. There's a Dockerfile in the main folder. The build takes a few minutes the first time, and 10-15 seconds for me when only the code has changed. If you add requirements it may take a bit longer. I (Jay) use Ubuntu through WSL and Docker Desktop, and it worked pretty easily for me.

To run it, first navigate to your project directory, then:

docker build -t IMAGE_NAME .
docker run -d -it -v $(pwd):/app --name CONTAINER_NAME IMAGE_NAME bash

To reset the container (e.g., you've changed the code and want to rerun your tests), use:

docker stop CONTAINER_NAME
docker rm CONTAINER_NAME
docker rmi IMAGE_NAME
docker build -t IMAGE_NAME .
docker run -p 8501:8501 -d -it -v $(pwd):/app --name CONTAINER_NAME IMAGE_NAME bash

I recommend setting this all up as a batch command so you can do it easily for a quick iteration time.
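
One way to bundle the reset sequence is a small Python wrapper around the same docker commands. This is only a sketch: the function names, the default port, and the project_dir parameter are assumptions made here, and the commands are built as argument lists so they can be inspected before anything is executed.

```python
import subprocess


def reset_commands(image, container, project_dir, port=8501):
    """Build the reset sequence above as a list of argv lists."""
    return [
        ["docker", "stop", container],
        ["docker", "rm", container],
        ["docker", "rmi", image],
        ["docker", "build", "-t", image, project_dir],
        [
            "docker", "run", "-p", f"{port}:{port}", "-d", "-it",
            "-v", f"{project_dir}:/app",
            "--name", container, image, "bash",
        ],
    ]


def run_all(commands):
    """Run each command in order and return the exit codes.

    check=False so a missing container or image (first run) does not
    abort the sequence before the build step.
    """
    return [subprocess.run(cmd, check=False).returncode for cmd in commands]
```

A typical call would be run_all(reset_commands("dt-image", "dt-dev", "/path/to/project")), where the image name, container name, and path are placeholders to replace with your own. Note that Docker image names must be lowercase.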

Finally, to run a command, use:

docker exec CONTAINER_NAME COMMAND

For instance, to run unit tests, you would use docker exec CONTAINER_NAME pytest tests/unit.

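To view the Streamlit app in your local browser, start it inside the container with the following command (the container must have been started with -p 8501:8501, as in the reset commands above):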

docker exec CONTAINER_NAME streamlit run app.py --server.port=8501

Tests:

Ensure that the run_tests.sh script is executable:

chmod a+x ./scripts/run_tests.sh

Run the tests. Note: the end-to-end tests are excluded from the run_tests.sh script since they take a while to run. They create wandb dashboards and are useful for debugging, but they are not necessary for development.

To run end-to-end tests, you can use the command 'pytest -v --cov=src/ --cov-report=term-missing'. If the trajectories file 'MiniGrid-Dynamic-Obstacles-8x8-v0bd60729d-dc0b-4294-9110-8d5f672aa82c.pkl' is not found in the tests, the 'gdown' command has failed to download it. In that case, download it manually or run 'conda install -c conda-forge gdown' and try again.

./scripts/run_tests.sh

You should see something like this after the tests run. This is the coverage report. Ideally this is 100% but we're not there yet. Furthermore, it will be 100% long before we have enough tests. But if it's 100% and we have performant code with agents training and stuff otherwise working, that's pretty good.


---------- coverage: platform darwin, python 3.9.15-final-0 ----------
Name                                Stmts   Miss  Cover   Missing
-----------------------------------------------------------------
src/__init__.py                         0      0   100%
src/decision_transformer.py           132      8    94%   39, 145, 151, 156-157, 221, 246, 249
src/ppo.py                             20     20     0%   2-28
src/ppo/__init__.py                     0      0   100%
src/ppo/agent.py                      109     10    91%   41, 45, 112, 151-157
src/ppo/compute_adv_vectorized.py      30     30     0%   1-65
src/ppo/memory.py                      88     11    88%   61-64, 119-123, 147-148
src/ppo/my_probe_envs.py               99      9    91%   38, 42-44, 74, 99, 108, 137, 168
src/ppo/train.py                       69      6    91%   58, 74, 94, 98, 109, 113
src/ppo/utils.py                      146     54    63%   41-42, 61-63, 69, 75, 92-96, 110-115, 177-206, 217-235
src/utils.py                           40     17    58%   33-38, 42-65, 73, 76-79
src/visualization.py                   25     25     0%   1-34
-----------------------------------------------------------------
TOTAL                                 758    190    75%
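
To read the report: Cover is the share of statements that were executed, i.e. (Stmts - Miss) / Stmts, shown as a whole percentage. The sketch below reproduces that arithmetic for a few rows of the table above. The function name is hypothetical, and the rounding used by the coverage tool itself may differ in edge cases.

```python
def cover_percent(stmts, miss):
    """Percentage of statements executed, rounded to a whole number."""
    if stmts == 0:
        # An empty file has nothing to miss, so it is reported as 100%.
        return 100
    return round(100 * (stmts - miss) / stmts)


# Rows taken from the report above: (name, Stmts, Miss)
rows = [
    ("src/decision_transformer.py", 132, 8),
    ("src/ppo/agent.py", 109, 10),
    ("TOTAL", 758, 190),
]
for name, stmts, miss in rows:
    print(f"{name}: {cover_percent(stmts, miss)}%")
```

Files at 0%, such as src/visualization.py, are never imported by the tests at all, so they are the cheapest places to start raising the total.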

Next Steps

Relevant Projects: