weipu-zhang / STORM

Implementation of STORM: Efficient Stochastic Transformer based World Models for Reinforcement Learning

Paper & OpenReview: you may find some useful discussion there.

This repo contains an implementation of STORM.

Follow the Training and Evaluating Instructions to reproduce the main results presented in our paper. The Additional Useful Information section may help when debugging and observing intermediate results. To reproduce the speed metrics mentioned in the paper, please see Reproducing Speed Metrics.

Training and Evaluating Instructions

  1. Install the necessary dependencies. Note that we conducted our experiments with Python 3.10.

    pip install -r requirements.txt

    Installing AutoROM.accept-rom-license may take several minutes.
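
    If you prefer an isolated environment, here is a minimal setup sketch (assuming conda is available; the environment name storm is arbitrary):

    conda create -n storm python=3.10 -y
    conda activate storm
    pip install -r requirements.txt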

  2. Train the agent.

    chmod +x train.sh
    ./train.sh

    The train.sh file controls the environment and the running name of a training process.

    env_name=MsPacman
    python -u train.py \
        -n "${env_name}-life_done-wm_2L512D8H-100k-seed1" \
        -seed 1 \
        -config_path "config_files/STORM.yaml" \
        -env_name "ALE/${env_name}-v5" \
        -trajectory_path "trajectory/${env_name}.pkl"
    • The env_name on the first line can be set to any Atari game supported by ALE (see the Gymnasium Atari documentation for the full list, and the example after these notes for switching games).

    • The -n option sets the name for the TensorBoard logger and the checkpoint folder. You can change it to your preference, but we recommend keeping the environment's name first. The TensorBoard logging folder is runs, and the checkpoint folder is ckpt.

    • The -seed parameter controls the random seed used during training. We evaluated our method with 5 seeds and report the mean return in Table 1.

    • The -config_path points to a YAML file that controls the model's hyperparameters. The configuration in config_files/STORM.yaml is the same as in our paper.

    • The -trajectory_path is only useful when the option UseDemonstration in the YAML file is set to True (by default it's False). This corresponds to the ablation studies in Section 5.3. We provide the pre-collected trajectories in the D_TRAJ.7z file, and you need to decompress it for use.
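
    For example, to train on a different game with a different seed, edit the variables in train.sh or call train.py directly. The sketch below uses Breakout and seed 2 as arbitrary choices; the last command decompresses the pre-collected trajectories and is only needed for the UseDemonstration ablation (assuming the 7z command from p7zip is available):

    env_name=Breakout
    python -u train.py \
        -n "${env_name}-life_done-wm_2L512D8H-100k-seed2" \
        -seed 2 \
        -config_path "config_files/STORM.yaml" \
        -env_name "ALE/${env_name}-v5" \
        -trajectory_path "trajectory/${env_name}.pkl"

    # Only for the UseDemonstration ablation: extract D_TRAJ.7z and make sure the
    # extracted files match the path passed via -trajectory_path.
    7z x D_TRAJ.7z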

  3. Evaluate the agent. The evaluation results will be presented in a CSV file located in the eval_result folder.

    chmod +x eval.sh
    ./eval.sh

    The eval.sh file controls the environment and the running name when testing an agent.

    env_name=MsPacman
    python -u eval.py \
        -env_name "ALE/${env_name}-v5" \
        -run_name "${env_name}-life_done-wm_2L512D8H-100k-seed1" \
        -config_path "config_files/STORM.yaml" 

    The -run_name option corresponds to the -n option in train.sh and must match the name used during training.
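
    If you have trained multiple seeds, a small loop can evaluate them in turn (an illustrative sketch; the run names must match the -n values used during training):

    env_name=MsPacman
    for seed in 1 2 3 4 5; do
        python -u eval.py \
            -env_name "ALE/${env_name}-v5" \
            -run_name "${env_name}-life_done-wm_2L512D8H-100k-seed${seed}" \
            -config_path "config_files/STORM.yaml"
    done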

Additional Useful Information

You can use TensorBoard to visualize the training curves and the imagination videos:

 chmod +x TensorBoard.sh
 ./TensorBoard.sh
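
Equivalently, you can launch TensorBoard by hand and point it at the runs logging folder (the port below is just TensorBoard's default and can be changed):

 tensorboard --logdir runs --port 6006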

Reproducing Speed Metrics

To reproduce the speed metrics mentioned in the paper, please consider the following:

Troubleshooting

Mixed precision on other devices

Windows and WSL

We've observed that if one clones the repo from PowerShell and then runs train.sh inside a WSL shell, it may throw an argument-parsing error. This is likely caused by Windows-style (CRLF) line endings introduced when cloning with git. The solution is to download the zip archive or clone the repo directly inside WSL.
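
If you already have an affected checkout, stripping the carriage returns from the shell scripts usually fixes it (a sketch assuming GNU sed under WSL; dos2unix works as well):

 # strip Windows-style line endings from the shell scripts
 sed -i 's/\r$//' train.sh eval.sh TensorBoard.sh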

Code references

We've referenced several other projects during the development of this code.

BibTeX

@inproceedings{zhang2023storm,
    title={{STORM}: Efficient Stochastic Transformer based World Models for Reinforcement Learning},
    author={Weipu Zhang and Gang Wang and Jian Sun and Yetian Yuan and Gao Huang},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023},
    url={https://openreview.net/forum?id=WxnrX42rnS}
}