muzero-pytorch

Pytorch implementation of MuZero: "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model", based on the pseudo-code provided by the authors.

Note: This implementation has only been tested on CartPole-v1 and would require modifications (in the config folder) for other environments.

Installation
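
A minimal setup sketch; it assumes the repository ships a `requirements.txt` at its root (not confirmed by this section):

```bash
# Clone the repository and install the Python dependencies.
# Assumes a requirements.txt file at the repository root.
git clone https://github.com/koulanurag/muzero-pytorch.git
cd muzero-pytorch
pip install -r requirements.txt
```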

Usage:

| Required Arguments | Description |
|---|---|
| `--env` | Name of the environment |
| `--case {atari,classic_control,box2d}` | Selects the domain and the corresponding config (default: None) |
| `--opr {train,test}` | Operation to be performed |
| Optional Arguments | Description |
|---|---|
| `--value_loss_coeff` | Scale for the value loss (default: None) |
| `--revisit_policy_search_rate` | Rate at which the target policy is re-estimated (default: None) (only valid if `--use_target_model` is enabled) |
| `--use_priority` | Uses prioritized sampling from the replay buffer; priority for new data is computed from the loss (default: False) |
| `--use_max_priority` | Forces max-priority assignment for new data entering the replay buffer (only valid if `--use_priority` is enabled) (default: False) |
| `--use_target_model` | Uses a target model for bootstrapped value estimation (default: False) |
| `--result_dir` | Directory path to store results (default: current working directory) |
| `--no_cuda` | Disables CUDA usage (default: False) |
| `--no_mps` | Disables MPS (Metal Performance Shaders) usage (default: False) |
| `--debug` | If enabled, logs additional values (default: False) |
| `--render` | Renders the environment (default: False) |
| `--force` | Overrides past results (default: False) |
| `--seed` | Random seed (default: 0) |
| `--num_actors` | Number of actors running concurrently (default: 32) |
| `--test_episodes` | Number of evaluation episodes (default: 10) |
| `--use_wandb` | Logs console and TensorBoard data to wandb (default: False) |

Note: default: None means the value is loaded from the corresponding config.
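
For reference, a typical invocation could look like the sketch below; the entry-point script name (`main.py`) is an assumption and may differ in the repository:

```bash
# Evaluate a trained CartPole-v1 agent for 10 episodes.
# main.py is assumed to be the entry-point script.
python main.py --env CartPole-v1 --case classic_control --opr test --test_episodes 10
```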

Training

CartPole-v1
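
A sketch of how a CartPole-v1 training run could be launched, using only the flags documented above (again assuming a `main.py` entry point; `--force` overwrites any previous results in `--result_dir`):

```bash
# Train MuZero on CartPole-v1 with the classic_control config,
# using prioritized replay and a target model for value bootstrapping.
# main.py is assumed to be the entry-point script.
python main.py --env CartPole-v1 --case classic_control --opr train \
    --use_priority --use_target_model --force
```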