Transformers are Sample-Efficient World Models
Vincent Micheli*, Eloi Alonso*, François Fleuret
* Denotes equal contribution
If you find this code or paper useful, please use the following reference:
```
@inproceedings{
iris2023,
title={Transformers are Sample-Efficient World Models},
author={Vincent Micheli and Eloi Alonso and Fran{\c{c}}ois Fleuret},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=vhFu1Acb0xb}
}
```
To get started, install the dependencies and launch a training run:

```bash
pip install -r requirements.txt
python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0 wandb.mode=online
```

By default, the logs are synced to Weights & Biases; set `wandb.mode=disabled` to turn it off.
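For example, to train on another game without online logging (a sketch; `PongNoFrameskip-v4` is just an illustrative game id, any Atari game used in the paper works the same way):

```bash
# Train on Pong with wandb logging disabled; the overrides reuse the keys shown above.
python src/main.py env.train.id=PongNoFrameskip-v4 common.device=cuda:0 wandb.mode=disabled
```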
All the configuration files are located in `config/`; the main configuration file is `config/trainer.yaml`. Each new run is located at `outputs/YYYY-MM-DD/hh-mm-ss/`. This folder is structured as:
```
outputs/YYYY-MM-DD/hh-mm-ss/
│
└─── checkpoints
│    │   last.pt
│    │   optimizer.pt
│    │   ...
│    │
│    └─── dataset
│         │   0.pt
│         │   1.pt
│         │   ...
│
└─── config
│    │   trainer.yaml
│
└─── media
│    │
│    └─── episodes
│    │    │   ...
│    │
│    └─── reconstructions
│         │   ...
│
└─── scripts
│    │   eval.py
│    │   play.sh
│    │   resume.sh
│    │   ...
│
└─── src
│    │   ...
│
└─── wandb
     │   ...
```
- `checkpoints`: contains the last checkpoint of the model, its optimizer and the dataset.
- `media`:
  - `episodes`: contains train / test / imagination episodes for visualization purposes.
  - `reconstructions`: contains original frames alongside their reconstructions with the autoencoder.
- `scripts`: from the run folder, you can use the following three scripts (see the consolidated example after this list).
  - `eval.py`: launch `python ./scripts/eval.py` to evaluate the run.
  - `resume.sh`: launch `./scripts/resume.sh` to resume a training that crashed.
  - `play.sh`: tool to visualize some interesting aspects of the run.
    - `./scripts/play.sh` to watch the agent play live in the environment. If you add the flag `-r`, the left panel displays the original frame, the center panel displays the same frame downscaled to the input resolution of the discrete autoencoder, and the right panel shows the output of the autoencoder (what the agent actually sees).
    - `./scripts/play.sh -w` to unroll live trajectories with your keyboard inputs (i.e. to play in the world model). Note that for faster interaction, the memory of the Transformer is flushed every 20 frames.
    - `./scripts/play.sh -a` to watch the agent play live in the world model. Note that for faster interaction, the memory of the Transformer is flushed every 20 frames.
    - `./scripts/play.sh -e` to visualize the episodes contained in `media/episodes`.
    - Add the flag `-h` to display a header with additional information.
    - Press `,` to start and stop recording. The corresponding segment is saved in `media/recordings` in mp4 and numpy formats.
    - Add the flag `-s` to enter 'save mode', where the user is prompted to save trajectories upon completion.
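For convenience, the same commands gathered in one place, to be run from the run folder:

```bash
# From outputs/YYYY-MM-DD/hh-mm-ss/
python ./scripts/eval.py   # evaluate the run
./scripts/resume.sh        # resume a training that crashed
./scripts/play.sh          # watch the agent play live in the environment
./scripts/play.sh -a       # watch the agent play live in the world model
./scripts/play.sh -w       # play in the world model with keyboard inputs
./scripts/play.sh -e       # visualize the episodes in media/episodes
```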
The folder `results/data/` contains raw scores (for each game, and for each training run) for IRIS and the baselines.
Use the notebook `results/results_iris.ipynb` to reproduce the figures from the paper.
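For instance, assuming Jupyter is installed in your environment:

```bash
# Open the results notebook to regenerate the figures (assumes Jupyter is available).
jupyter notebook results/results_iris.ipynb
```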
Pretrained models are available here.
To start a training run from one of these checkpoints, in the section `initialization` of `config/trainer.yaml`, set `path_to_checkpoint` to the corresponding path, and `load_tokenizer`, `load_world_model`, and `load_actor_critic` to `True`.
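Alternatively, a sketch of the same configuration passed as command-line overrides, assuming the standard Hydra dotted syntax for the `initialization` section (the checkpoint path is a placeholder):

```bash
# Start training from a downloaded checkpoint; /path/to/last.pt is a placeholder.
# The initialization.* overrides assume Hydra's dotted syntax for the
# initialization section of config/trainer.yaml; editing the file directly works as well.
python src/main.py env.train.id=BreakoutNoFrameskip-v4 \
  initialization.path_to_checkpoint=/path/to/last.pt \
  initialization.load_tokenizer=True \
  initialization.load_world_model=True \
  initialization.load_actor_critic=True
```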
To visualize one of these checkpoints, set `train.id` to the corresponding game in `config/env/default.yaml`, create a `checkpoints` directory, and copy the checkpoint to `checkpoints/last.pt`. You can then visualize the agent with `./scripts/play.sh` as described above.
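A minimal sketch of that setup, with an illustrative path for the downloaded checkpoint:

```bash
# After setting train.id to the corresponding game in config/env/default.yaml:
mkdir -p checkpoints
cp /path/to/downloaded/checkpoint.pt checkpoints/last.pt   # illustrative source path
./scripts/play.sh   # flags such as -a or -h work as described above
```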