Closed n30111 closed 2 years ago
Hey @n30111, thanks for raising this. I think the answer here is that if you do want to use an `action_sampler_fn` (in which you take charge of action computation entirely, without the help of the policy's built-in action-dist/sampling utilities), you have to make sure that your loss function handles the absence of an action-distribution class.

From looking at your `action_sampler_fn`, it seems that all you are trying to do is return a deterministic action (instead of one sampled from the distribution). You can also achieve that by setting `config.explore=False` in PPO.
However, if you are trying to do more complex things in your custom `action_sampler_fn`, you would also need to re-define your loss to handle the `dist_class=None` issue.
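To make the `dist_class=None` point concrete, here is an illustrative sketch (not RLlib's actual code; all names and the list-based math are simplified stand-ins) of how a PPO-style surrogate loss might branch when a custom `action_sampler_fn` has taken over action computation, so no distribution class exists:

```python
import math

def surrogate_loss(dist_class, model_out, actions, old_logp, advantages, clip=0.2):
    """Hypothetical PPO clipped surrogate loss that tolerates dist_class=None."""
    if dist_class is None:
        # With a custom action_sampler_fn, no distribution is built; we
        # assume the sampler already supplied new log-probs in model_out.
        new_logp = model_out
    else:
        # Normal path: build the distribution and query log-probs from it.
        dist = dist_class(model_out)
        new_logp = dist.logp(actions)
    # Probability ratios between new and old policies.
    ratio = [math.exp(n - o) for n, o in zip(new_logp, old_logp)]
    clipped = [max(min(r, 1.0 + clip), 1.0 - clip) for r in ratio]
    # Clipped surrogate objective (maximized), returned as a negative loss.
    return -sum(min(r * a, c * a)
                for r, a, c in zip(ratio, advantages, clipped)) / len(ratio)
```

The key design point is simply that whichever path runs, the loss ends up with per-sample log-probabilities; when RLlib cannot provide them via a distribution class, your sampler has to.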
To summarize RLlib's behavior:

- `action_sampler_fn` defined: RLlib will NOT create an action-dist class for you; RLlib will NOT create an `action_dist_inputs` placeholder for you; you are responsible for coming up with actions in this custom function.
- `action_distribution_fn` defined: return an action-dist input tensor, an action-dist class, and state-outs (or `[]`) from this custom function; RLlib will do the rest (sample from the given distribution class for action computation).

Thanks @sven1977, we were trying to do additional things (not only deterministic actions) in the custom `action_sampler_fn`. Since it worked for SAC, I was expecting it would work for PPO as well.
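The two contracts summarized above can be modeled with a small sketch (hypothetical names and structure; this is not RLlib's actual dispatch code):

```python
def compute_actions(policy, obs):
    """Dispatch between the two override hooks, as described above."""
    if policy.get("action_sampler_fn"):
        # Hook 1: the custom function is fully responsible for actions;
        # no distribution class or action_dist_inputs are created.
        return policy["action_sampler_fn"](obs)
    if policy.get("action_distribution_fn"):
        # Hook 2: the custom function returns dist inputs, a dist class,
        # and state-outs; the framework samples from the distribution.
        dist_inputs, dist_class, _state_outs = policy["action_distribution_fn"](obs)
        return dist_class(dist_inputs).sample()
    raise ValueError("no action computation hook defined")
```

This also makes the failure mode visible: on the first path nothing distribution-related exists, so any loss that unconditionally expects a `dist_class` will break.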
I was able to make it work with minor changes to `ppo_surrogate_loss` and `make_model`, plus `action_sampler_fn` output-signature changes in the policy class.
@sven1977 Can you please look at this commit https://github.com/minds-ai/ray/commit/eba38ecc8b7e4eeeacb95d1ca93a0c72a343b5d3 and let me know whether RLlib would accept this change?
Hi @sven1977, please let us know your thoughts on this.
Added some comments to your commit. Let's move the discussion there. Thanks.
Search before asking
Ray Component
RLlib
What happened + What you expected to happen
PPO does not work while using `action_sampler_fn` and `make_model`.

Versions / Dependencies
Python 3.8, ray==1.9.2
Reproduction script
Anything else
No response
Are you willing to submit a PR?