ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[rllib] PyTorch PPO / APPO #3365

Closed ericl closed 4 years ago

ericl commented 5 years ago

Describe the problem

Currently we only have a simple pytorch example for A3C. It should be possible to "port" a bunch of pytorch algorithms onto the TorchPolicyGraph class and get them basically for free.

We'll also have to improve pytorch support in some ways (e.g., dict space support, lstm).
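
As a rough sketch of the general shape such a port would take, assuming the algorithm mainly has to supply a model, a loss, and an optimizer (the class and method names below are illustrative stand-ins, not the actual TorchPolicyGraph API):

```python
# Illustrative skeleton only: "MyTorchPolicyGraph" and its hooks are
# stand-ins for what a ported algorithm would plug into TorchPolicyGraph.
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Small feed-forward policy network that outputs action logits."""

    def __init__(self, obs_dim: int, num_actions: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 32), nn.Tanh(), nn.Linear(32, num_actions)
        )

    def forward(self, obs):
        return self.net(obs)


class MyTorchPolicyGraph:
    """The algorithm supplies a model, a loss, and an optimizer; the
    surrounding framework handles sampling, batching, and weight sync."""

    def __init__(self, obs_dim: int, num_actions: int, lr: float = 1e-3):
        self.model = TinyModel(obs_dim, num_actions)
        self._optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def loss(self, batch):
        # Plain policy-gradient loss as a placeholder for the PPO/APPO loss.
        logits = self.model(batch["obs"])
        log_probs = torch.log_softmax(logits, dim=-1)
        chosen = log_probs.gather(1, batch["actions"].unsqueeze(1)).squeeze(1)
        return -(chosen * batch["advantages"]).mean()

    def learn_on_batch(self, batch):
        self._optimizer.zero_grad()
        loss = self.loss(batch)
        loss.backward()
        self._optimizer.step()
        return {"loss": loss.item()}


if __name__ == "__main__":
    # Smoke test with a fake batch.
    policy = MyTorchPolicyGraph(obs_dim=4, num_actions=2)
    batch = {
        "obs": torch.randn(8, 4),
        "actions": torch.randint(0, 2, (8,)),
        "advantages": torch.randn(8),
    }
    print(policy.learn_on_batch(batch))
```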

ericl commented 5 years ago

cc @michaeltu1 @megankawakami

nilsjohanbjorck commented 5 years ago

Dear Eric,

I might be interested in taking a stab at implementing either APE-X or IMPALA in PyTorch. For APE-X I've looked into rllib.dqn.dqn_policy_graph.py to see what interface it provides. Based on this class, it seems like I'd need to subclass TorchPolicyGraph and then implement the following functions, which provide the main parts of the algorithm:

gradients(self, optimizer)
postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None)
compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights)

The functions below are also provided in dqn_policy_graph.py, although their implementation seems straightforward:

optimizer(self)
extra_compute_action_feed_dict(self)
extra_compute_grad_fetches(self)
update_target(self)
set_epsilon(self, epsilon)
get_state(self)
set_state(self, state)

Is there anything else that would be needed?
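
For reference, a minimal PyTorch sketch of what the compute_td_error step above might look like, assuming a feed-forward Q-network plus a separate target network (argument names follow the list above; this is an assumption, not the actual RLlib implementation):

```python
# Hedged sketch of compute_td_error in PyTorch; q_net / target_net and the
# float 0/1 done_mask convention are assumptions for illustration.
import torch
import torch.nn as nn


def compute_td_error(q_net: nn.Module, target_net: nn.Module,
                     obs_t, act_t, rew_t, obs_tp1, done_mask,
                     importance_weights, gamma: float = 0.99):
    # Q(s_t, a_t) for the actions actually taken.
    q_t = q_net(obs_t).gather(1, act_t.unsqueeze(1)).squeeze(1)
    # Bootstrapped target r_t + gamma * max_a Q_target(s_{t+1}, a),
    # masked to zero at terminal states (done_mask is 1.0 when done).
    with torch.no_grad():
        q_tp1 = target_net(obs_tp1).max(dim=1).values
        target = rew_t + gamma * (1.0 - done_mask) * q_tp1
    td_error = q_t - target
    # Importance weights would come from the prioritized replay buffer.
    loss = (importance_weights * td_error.pow(2)).mean()
    return td_error, loss
```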

ericl commented 5 years ago

Hey @nilsjohanbjorck,

Looking at TorchPolicyGraph, it seems you'd also have to replace the compute_actions() method, which currently assumes the model returns logits and samples from a softmax, rather than doing the argmax / epsilon-greedy exploration needed for Q-learning: https://github.com/ray-project/ray/blob/9266909c9dbf2e9abc7ec1e23e2a6a26185128b0/python/ray/rllib/evaluation/torch_policy_graph.py#L74

Also, a bunch of features like returning td_error from compute_gradients() are just missing in TorchPolicyGraph (you can see here that the extra returns are hardcoded to {}: https://github.com/ray-project/ray/blob/9266909c9dbf2e9abc7ec1e23e2a6a26185128b0/python/ray/rllib/evaluation/torch_policy_graph.py#L89), so you'd need to get that to work too.
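
A rough sketch of what an epsilon-greedy compute_actions() override might look like, assuming the model returns Q-values (the simplified signature here is an assumption, not the actual TorchPolicyGraph method):

```python
# Simplified, illustrative epsilon-greedy action selection for Q-learning.
import torch


def compute_actions(model, obs_batch, epsilon: float = 0.05):
    with torch.no_grad():
        q_values = model(obs_batch)                  # shape [batch, num_actions]
    greedy = q_values.argmax(dim=1)
    random_actions = torch.randint(0, q_values.shape[1], greedy.shape)
    explore = torch.rand(greedy.shape[0]) < epsilon  # per-sample exploration flag
    return torch.where(explore, random_actions, greedy)
```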

nilsjohanbjorck commented 5 years ago

@ericl,

Sorry for the late reply; I'm at a conference. Overriding the compute_actions() method of TorchPolicyGraph seems straightforward. What is needed from the compute_gradients() method seems to be described here

https://github.com/ray-project/ray/blob/9266909c9dbf2e9abc7ec1e23e2a6a26185128b0/python/ray/rllib/optimizers/async_replay_optimizer.py#L39-L41

The class DQNPolicyGraph overrides the method extra_compute_grad_fetches(self), which provides the TD information. That function and extra_compute_action_feed_dict(self) seem to be unique to the class TFPolicyGraph, and I assume these wouldn't need to be implemented in a PyTorch version. Apart from providing the TD information in compute_gradients(self, postprocessed_batch), are there any additional hard-coded {}s that would need to be filled out? It doesn't seem to be the case to me, but I'm not that familiar with the code-base.
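
To make the hard-coded {} part concrete, here is a sketch of a compute_gradients() that also returns the per-sample TD error in its info dict, which is what the prioritized-replay optimizer wants; the class and batch field names are assumptions for illustration, not the actual RLlib code:

```python
# Illustrative only: a compute_gradients() that returns (grads, info) with
# the TD error included, instead of a hard-coded empty info dict.
import torch
import torch.nn as nn


class TorchDQNPolicySketch:
    def __init__(self, q_net: nn.Module, target_net: nn.Module,
                 gamma: float = 0.99):
        self.q_net = q_net
        self.target_net = target_net
        self.gamma = gamma

    def compute_gradients(self, batch):
        q_t = self.q_net(batch["obs"]).gather(
            1, batch["actions"].unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            q_tp1 = self.target_net(batch["new_obs"]).max(dim=1).values
            target = batch["rewards"] + self.gamma * (1.0 - batch["dones"]) * q_tp1
        td_error = q_t - target
        loss = (batch["weights"] * td_error.pow(2)).mean()

        self.q_net.zero_grad()
        loss.backward()
        grads = [p.grad.clone() for p in self.q_net.parameters()]
        # The info dict carries the TD error so the replay optimizer can
        # update per-sample priorities.
        return grads, {"td_error": td_error.detach().numpy()}
```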

ericl commented 5 years ago

Hm yeah, that might be it. I would just try to see what makes sense to implement DQN properly, and not worry too much about what the TF version is doing. Then making the change to get it working in Ape-X should be pretty straightforward.

nilsjohanbjorck commented 5 years ago

Sorry for my inactivity on this issue; I haven't been able to work on it lately. I hope to get back to it eventually, but I cannot really say when I'll be able to. Sorry about that!

sven1977 commented 4 years ago

Torch PPO is learning. Waiting for two other PR dependencies to be resolved before we can merge this one. APPO is next.

sven1977 commented 4 years ago

I'm closing this issue.