jannerm / mbpo

Code for the paper "When to Trust Your Model: Model-Based Policy Optimization"
https://jannerm.github.io/mbpo-www/
MIT License

Algorithm Architecture and Pytorch Implementation #25

Open RamiRibat opened 3 years ago

RamiRibat commented 3 years ago

Hi, this is really nice work.

I've run into some issues with TensorFlow and CUDA, and I'm not very comfortable with TensorFlow; I'm a PyTorch person.

So I've decided to write a PyTorch implementation of MBPO, and I'm trying to understand your code.

Here is my understanding, taking AntTruncatedObs-v2 as a working example:

PyTorch pseudocode:

Total epochs = 1000, Epoch steps = 1000, Exploration epochs = 10

01. Initialize networks [Model, SAC]
02. Initialize training w/ [10 Exploration epochs (random) = 10 x 1000 environment steps]
03. For n in [Total epochs - Exploration epochs = 990 Epochs]:
04.    For i in [ 1000 Epoch Steps]:
05.        If i % [250 Model training freq] == 0:
06.            For g in [How many Model Gradient Steps???]:
07.                Sample a [256 size batch] from Env_pool
08.                Train the Model network
09.            Sample a [100k size batch] from Env_pool
10.            Set rollout_length
11.            Reallocate Model_pool [???]
12.            Rollout Model for rollout_length, and Add rollouts to Model_pool
13.        Sample an [action a] from the policy, Take Env step, and Add to Env_pool
14.        For g in [20 SAC Gradient Steps]:
15.            Sample a [256 size batch] from [5% Env_pool, 95% Model_pool]
16.            Train the Actor-Critic networks
17.    Evaluate the policy

Is that right?
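Two of the steps above, setting rollout_length (line 10) and drawing the mixed SAC batch (line 15), can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the MBPO code: it assumes a linear rollout-length schedule given as (min_epoch, max_epoch, min_length, max_length), and both the default schedule (20, 100, 1, 25) and the function names are illustrative assumptions.

```python
import random


def rollout_length(epoch, schedule=(20, 100, 1, 25)):
    """Linearly interpolate the model rollout length over training epochs.

    schedule = (min_epoch, max_epoch, min_length, max_length).
    The default values are an illustrative assumption, not a quoted config.
    """
    min_epoch, max_epoch, min_length, max_length = schedule
    if epoch <= min_epoch:
        return min_length
    if epoch >= max_epoch:
        return max_length
    frac = (epoch - min_epoch) / (max_epoch - min_epoch)
    return int(min_length + frac * (max_length - min_length))


def sample_mixed_batch(env_pool, model_pool, batch_size=256, real_ratio=0.05):
    """Draw one SAC batch: about real_ratio from env_pool, the rest from model_pool.

    Sampling is uniform with replacement (hypothetical helper for illustration).
    """
    n_real = int(batch_size * real_ratio)   # real transitions in the batch
    n_model = batch_size - n_real           # model-generated transitions
    real = random.choices(env_pool, k=n_real)
    fake = random.choices(model_pool, k=n_model)
    return real + fake
```

With batch_size=256 and real_ratio=0.05, each SAC batch contains 12 real and 244 model transitions, since int(256 * 0.05) = 12.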

My questions are about lines 06 & 11:

06: You seem to train the model for a fixed amount of real (wall-clock) time. In terms of gradient steps, how many steps is that?

11: When you reallocate Model_pool, you set [Model_pool size] to the number of [model steps per epoch]. Isn't that a very large training set for the SAC updates? And are you discarding all model steps from previous epochs?
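To make the question about line 11 concrete, here is the pool-size arithmetic written out. It assumes one rollout phase every model_train_freq environment steps, each starting from rollout_batch_size states sampled from Env_pool and running rollout_length model steps, and a pool that keeps retain_epochs epochs of such data. The function name and the retain_epochs parameter are hypothetical, chosen for illustration.

```python
def model_pool_capacity(rollout_batch_size=100_000, epoch_steps=1000,
                        model_train_freq=250, rollout_length=1,
                        retain_epochs=1):
    """Number of model transitions a pool would hold under the assumptions above."""
    # Rollout phases per epoch times start states per phase.
    rollouts_per_epoch = rollout_batch_size * epoch_steps // model_train_freq
    # Each start state contributes rollout_length model transitions.
    model_steps_per_epoch = rollout_length * rollouts_per_epoch
    # Keep only the most recent retain_epochs epochs of model data.
    return retain_epochs * model_steps_per_epoch
```

Under these assumptions the pool holds 400,000 transitions at rollout length 1 and 10,000,000 at rollout length 25. That is large but bounded: data older than retain_epochs epochs would be overwritten rather than kept forever.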

Sorry for the long issue.

Best wishes and kind regards.


Rami Ahmed

RamiRibat commented 3 years ago

Also, regarding line 11: for how long should Model_pool keep growing? It occupies more and more GPU memory as it grows.
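One way to keep the pool from consuming GPU memory is a fixed-capacity ring buffer kept on the CPU, with only each sampled batch moved to the GPU. The sketch below is an assumption about how such a pool could work, not the MBPO implementation; `ModelPool` is a hypothetical class, and `resize` keeps the most recent transitions when the capacity changes (for example, when rollout_length grows).

```python
import random


class ModelPool:
    """Fixed-capacity FIFO ring buffer of model transitions (hypothetical class).

    Lives in CPU memory; a training loop would move each sampled batch to the
    GPU, so GPU memory does not grow with the pool.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []
        self.ptr = 0  # next slot to overwrite once the buffer is full

    def add(self, transition):
        if len(self.data) < self.capacity:
            self.data.append(transition)
        else:
            self.data[self.ptr] = transition  # overwrite the oldest entry
        self.ptr = (self.ptr + 1) % self.capacity

    def ordered(self):
        """Return stored transitions from oldest to newest."""
        if len(self.data) < self.capacity:
            return list(self.data)
        return self.data[self.ptr:] + self.data[:self.ptr]

    def resize(self, new_capacity):
        """Change capacity, keeping only the newest transitions that fit."""
        kept = self.ordered()[-new_capacity:]
        self.capacity = new_capacity
        self.data = kept
        self.ptr = len(kept) % new_capacity

    def sample(self, batch_size):
        """Uniform sampling with replacement; O(1) per item on a list."""
        return random.choices(self.data, k=batch_size)

    def __len__(self):
        return len(self.data)
```

A list-backed ring buffer keeps random access O(1); a real implementation would more likely preallocate arrays per field, but the eviction logic is the same.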

jiangsy commented 3 years ago

For line 06, I think it keeps training until the validation loss converges. I have implemented a PyTorch version myself (https://github.com/jiangsy/mbpo_pytorch/tree/master/mbpo_pytorch), and you may use it as a reference (there are still some gaps in performance, but it may help).
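A minimal sketch of training until the holdout loss converges, i.e. early stopping with patience. The function name, `patience=5`, and the 1% relative-improvement threshold are illustrative assumptions; the actual stopping rule should be checked in the MBPO source. Under a rule like this, the number of model gradient steps per training call (line 06) is not fixed in advance; it depends on when the validation loss plateaus.

```python
def train_until_converged(train_one_epoch, validation_loss, patience=5,
                          min_rel_improvement=0.01, max_epochs=None):
    """Run training epochs until the holdout loss stops improving.

    Stops after `patience` consecutive epochs that fail to improve the best
    validation loss by more than `min_rel_improvement` (relative), or after
    `max_epochs` epochs if given. Returns the number of epochs run.
    """
    best = float("inf")
    epochs_since_improvement = 0
    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        train_one_epoch()          # e.g. one pass over the training split
        epoch += 1
        loss = validation_loss()   # loss on the holdout split
        if best == float("inf") or (best - loss) > min_rel_improvement * abs(best):
            best = loss
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                break
    return epoch
```

A single scalar loss is used here for simplicity; an ensemble model could track one such counter per member.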

RamiRibat commented 3 years ago

Thank you very much, @jiangsy!

Xingyu-Lin commented 3 years ago

Here is a PyTorch implementation that achieves the same performance as the original on Walker and Hopper: https://github.com/Xingyu-Lin/mbpo_pytorch. Other tasks have not been tested.