I'm trying to write an example for MPO (for a categorical action space). However, I'm confused.
Mainly I'm confused about the `kl_constraints` arg to `mpo_loss`.
```python
kl_constraints = [(rlax.categorical_kl_divergence(???, ???), lagrange_penalty)]
```
I don't understand what the two arguments to the KL divergence would be. (Also, I don't understand why it's a list. How can there be more than one KL divergence?)
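From the rlax signatures, my best guess is that the two arguments are the logits of the target (fixed) policy and of the online policy, and that it's a list because Gaussian policies use two decoupled constraints (mean and covariance) while a categorical policy only needs one. A minimal sketch of what I think is intended; the `target_logits`/`online_logits` arrays and the penalty coefficients below are placeholders I made up:

```python
import jax.numpy as jnp
import rlax

# Hypothetical logits for a batch of states: `target_logits` from the fixed
# target network (theta_i), `online_logits` from the network being optimised
# in the M-step (theta).
target_logits = jnp.array([[1.0, 0.5, -0.5]])
online_logits = jnp.array([[0.9, 0.4, -0.3]])

# KL(pi_target || pi_online): the first argument is the fixed distribution,
# the second is the one the gradient flows through.
kl = rlax.categorical_kl_divergence(target_logits, online_logits)

# alpha is the Lagrange multiplier (a learned parameter in practice, a
# constant here only for illustration); epsilon is the KL bound.
penalty = rlax.LagrangePenalty(alpha=jnp.array(1.0), epsilon=1e-2,
                               per_dimension=False)

# One (kl, penalty) pair per constraint; a categorical policy needs just one.
kl_constraints = [(kl, penalty)]
```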
Afaik, this KL constraint is to be used for the M-step, so it should be doing something like:
$$ J(\theta) = \dots + \mathrm{KL}\big(\pi(a \mid s, \theta_i) \,\|\, \pi(a \mid s, \theta)\big) $$
However, this equation also doesn't make sense to me. Aren't we evaluating the gradient of $J$ at $\theta_i$, so the KL term would be 0?
What am I missing...? (Something important, it seems.)
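For reference, my reading of the MPO paper (Abdolmaleki et al., 2018) is that the M-step is actually a constrained problem, with the KL bound enforced through a learned Lagrange multiplier, so $\theta$ moves away from $\theta_i$ over the course of the M-step and the KL term is only zero at the very first gradient step:

$$ \max_\theta \; \mathbb{E}_{\mu(s)}\Big[\mathbb{E}_{q(a \mid s)}\big[\log \pi(a \mid s, \theta)\big]\Big] \quad \text{s.t.} \quad \mathbb{E}_{\mu(s)}\Big[\mathrm{KL}\big(\pi(a \mid s, \theta_i) \,\|\, \pi(a \mid s, \theta)\big)\Big] < \epsilon $$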
nvm... I can just read your tests.
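For anyone landing here later, here is a minimal end-to-end sketch of how the pieces seem to fit together, assuming the signature `mpo_loss(sample_log_probs, sample_q_values, temperature_constraint, kl_constraints)`; all shapes, data, and Lagrange values below are placeholders (in a real agent the temperature and alpha would be learned parameters):

```python
import jax.numpy as jnp
import rlax

num_samples, batch_size, num_actions = 10, 4, 3

# Log-probs and Q-values of `num_samples` actions sampled per state
# (placeholder data, shape [num_samples, batch_size]).
sample_log_probs = jnp.zeros((num_samples, batch_size)) - jnp.log(num_actions)
sample_q_values = jnp.ones((num_samples, batch_size))

# E-step temperature constraint; the temperature plays the role of alpha here
# and is a learned Lagrange multiplier in practice, a constant for this sketch.
temperature_constraint = rlax.LagrangePenalty(
    alpha=jnp.array(1.0), epsilon=1e-1, per_dimension=False)

# M-step KL constraint between the fixed target policy (theta_i) and the
# online policy (theta), for a categorical action space.
target_logits = jnp.zeros((batch_size, num_actions))
online_logits = jnp.zeros((batch_size, num_actions))
kl = rlax.categorical_kl_divergence(target_logits, online_logits)
kl_constraints = [(kl, rlax.LagrangePenalty(
    alpha=jnp.array(1.0), epsilon=1e-2, per_dimension=False))]

loss, mpo_stats = rlax.mpo_loss(
    sample_log_probs, sample_q_values, temperature_constraint, kl_constraints)
```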