google-deepmind / rlax

https://rlax.readthedocs.io
Apache License 2.0

Writing an MPO example (help, I'm confused) #118

Closed act65 closed 1 year ago

act65 commented 1 year ago

I'm trying to write an example for MPO (for a categorical action space). However, I'm confused.

Mainly I'm confused about the kl_constraints arg to mpo_loss.

kl_constraints = [(rlax.categorical_kl_divergence(???, ???), lagrange_penalty)]

I don't understand what the two args to the KL divergence would be. (Also, I don't understand why it's a list. How can there be more than one KL divergence?)
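For concreteness, here's my best guess so far. This is only a sketch: the dummy logits, the shapes, and my reading of LagrangePenalty's fields (alpha, epsilon) are assumptions on my part, not something I've verified against the library.

```python
import jax.numpy as jnp
import rlax

# Dummy logits standing in for network outputs (assumed shapes, not from the docs):
# target_logits from the fixed target/old policy, online_logits from the
# policy being optimised in the M-step.
target_logits = jnp.array([0.1, 0.5, -0.2, 0.3])
online_logits = jnp.array([0.2, 0.4, -0.1, 0.3])

# Guess: the KL is between the (fixed) old policy and the online policy.
kl = rlax.categorical_kl_divergence(target_logits, online_logits)

# Guess: LagrangePenalty bundles the Lagrange multiplier `alpha` with the
# KL bound `epsilon` from the MPO paper.
penalty = rlax.LagrangePenalty(alpha=jnp.array(1.0), epsilon=1e-2)

# A single constraint here; presumably it's a list so that e.g. a Gaussian
# policy can carry separate constraints for mean and covariance?
kl_constraints = [(kl, penalty)]
```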


Afaik, this KL constraint is to be used for the M-step, so it should be doing something like:

$$ J(\theta) = \ldots + \mathrm{KL}\big(\pi(a \mid s, \theta_i) \,\|\, \pi(a \mid s, \theta)\big) $$

However, this equation also doesn't make sense to me. Aren't we evaluating the gradient of $J$ at $\theta_i$, so the KL term would be 0?

What am I missing...? (Something important, it seems.)
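To make the confusion concrete, here's a toy check in plain JAX (my own sketch; theta_i is just a stand-in for the old policy's logits). At theta = theta_i the KL term is exactly 0, and so is its gradient, which is what makes me think I'm misreading the role of the constraint:

```python
import jax
import jax.numpy as jnp
import rlax

theta_i = jnp.array([0.1, 0.5, -0.2, 0.3])  # stand-in for the old policy's logits

def kl_term(theta):
    # KL between the fixed old policy pi(.|s, theta_i) and pi(.|s, theta).
    return rlax.categorical_kl_divergence(theta_i, theta)

print(kl_term(theta_i))            # 0.0: the KL vanishes at theta = theta_i
print(jax.grad(kl_term)(theta_i))  # all zeros: so does its gradient
```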

act65 commented 1 year ago

nvm... I can just read your tests.