Related post: https://github.com/deepmind/open_spiel/issues/980
@bartdevylder @perolat any follow-up on this?
Hi,
We use only 3 heads in the Stratego work and it depends on the phase of the game (deployment, piece selection and piece destination). See figure S3 of https://www.science.org/doi/abs/10.1126/science.add4679.
The policy is just the one used at the corresponding phase so not much change on the calculation of the v-trace return. Note that in this implementation we only use 1 policy head to make things simpler.
Thank you for your response. Would it even be possible to optimise multiple policies at the same time this way?
Sure, you can just assign zero loss to the policy heads that do not correspond to the current phase (i.e. no gradients for those heads because they don't change, and they can output whatever they want because they won't be used in the current phase).
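A minimal sketch of that idea, assuming hypothetical per-head policy losses and a `phase` index per timestep (illustrative names only, not the actual open_spiel code):

```python
import numpy as np

# Hypothetical shapes: [NUM_HEADS, T, B] per-head policy losses and a phase index
# in {0, 1, 2} indicating which head actually acted at each timestep.
T, B, NUM_HEADS = 8, 4, 3
per_head_policy_loss = np.random.rand(NUM_HEADS, T, B)  # e.g. -log pi_k(a_t|s_t) * advantage
phase = np.random.randint(0, NUM_HEADS, size=(T, B))    # which head acted at (t, b)

# One-hot mask over heads: 1 where the head matches the current phase, 0 elsewhere.
head_mask = np.stack([(phase == k).astype(np.float32) for k in range(NUM_HEADS)])

# Heads that did not act at a step contribute zero loss, hence zero gradient.
masked_loss = per_head_policy_loss * head_mask
total_policy_loss = masked_loss.sum() / np.maximum(head_mask.sum(), 1.0)
```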
@lanctot sorry to reopen this. What does this mean for the value targets? I understand masking the policy, but do I have a value head for each policy as well? Also, when calculating the v-trace advantage, I note the target values become zero on invalid states. Are states that do not correspond to the current policy considered invalid?
I am guessing they just used one value network, but I will ask them to reply.
Hi, we indeed have 3 policy heads but only one value head, so also for v-trace only one value head is taken into account. In the provided implementation all normal game states are considered 'valid', except terminal states. The concept of 'valid' is introduced to cope with batching trajectories: not all trajectories are of the same length, so they are padded to fit into a single-tensor batch, where the 'valid' mask indicates which part of this input is actual data and which part is just padding.
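A hedged sketch of how such a 'valid' mask can be built when padding variable-length trajectories into one batch tensor (variable names are illustrative, not the actual implementation):

```python
import numpy as np

# Two trajectories of different lengths (values are placeholder per-step data).
trajectories = [np.array([1.0, 0.0, -1.0]), np.array([0.5, 0.5, 0.0, 1.0, -0.5])]

max_len = max(len(t) for t in trajectories)
batch = np.zeros((max_len, len(trajectories)))  # [T, B] padded batch
valid = np.zeros((max_len, len(trajectories)))  # 1 = real data, 0 = padding

for b, traj in enumerate(trajectories):
    batch[: len(traj), b] = traj
    valid[: len(traj), b] = 1.0

# Any per-timestep loss is multiplied by `valid` so padded steps contribute nothing.
per_step_loss = batch ** 2  # stand-in for an actual loss
masked_mean_loss = (per_step_loss * valid).sum() / valid.sum()
```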
I guess I should be more specific about my use case. If I had multiple policy heads that all acted at the same time, how would I go about training the value head? These heads are autoregressive, meaning that sometimes they are not used for a particular step. I understand I'd just mask that step in the policy loss, but would I also mask it in the corresponding value head loss? Or would I just have one value head with additional rewards for each step based on which policies were used?
I'm not sure what you mean by multiple policy heads that all act at the same time. You only get to play one action, right? So somehow you have to pass them through a filter, let's say f(s, \pi_1, \pi_2, \pi_3). In the Stratego paper, that filter used the phase of s to choose which \pi, but you can think of it as some general function.
But notice: f is just (an aggregated) policy itself. So you'd learn a single value that estimates the expected returns under that policy, i.e. V^f(s).
Value heads are trained via regression, so if it's a V-network there's nothing to mask -- it's the one value. If it's a Q-network, then you set the loss on all the actions to 0 except the one you sampled, where you use the Bellman error w.r.t. the Q-value at the next state.
(I might be misunderstanding here.. maybe @bartdevylder can pipe in if so...)
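A minimal sketch of that distinction, under the assumption of illustrative tensor names (not the actual open_spiel code): a single V-network is simply regressed toward the v-trace target at every valid step, with no per-head masking, while a Q-network only receives an error signal on the sampled action.

```python
import numpy as np

T, B, NUM_ACTIONS = 8, 4, 6
valid = np.ones((T, B))  # 1 = real step, 0 = padding / terminal

# V-network case: one value estimate per state, regressed toward the v-trace target.
v_pred = np.random.rand(T, B)
v_trace_target = np.random.rand(T, B)  # computed once, with the single value head
v_loss = ((v_pred - v_trace_target) ** 2 * valid).sum() / valid.sum()

# Q-network case: only the sampled action's Q-value receives an error signal.
q_pred = np.random.rand(T, B, NUM_ACTIONS)
sampled_action = np.random.randint(0, NUM_ACTIONS, size=(T, B))
action_mask = np.eye(NUM_ACTIONS)[sampled_action]      # one-hot over actions, [T, B, A]
q_target = np.random.rand(T, B)                        # stand-in for a Bellman target
q_error = (q_pred - q_target[..., None]) * action_mask # zero on unsampled actions
q_loss = ((q_error ** 2).sum(-1) * valid).sum() / valid.sum()
```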
I'm more talking about use cases beyond what is mentioned in the paper. In my case, I have multiple policy heads to decompose the action space, since it is combinatorial. As a result, not every policy head is used at every step.
By masking I mean: since some policy heads are not used to generate steps along the trajectory, how should the v-trace calculation be changed to account for this, i.e. what should the value target be?
In the RNaD paper, there are 4 policy heads and one value head. I calculate the v-trace returns for each policy head, and each produces its own policy target and value target. Do I optimise the value head 4 times, once for each value target? Or do I sum / average the value targets to create a single value target?
Could I simply multiply the importance sampling ratios together before doing the v-trace calculation?
Could the authors provide an example?
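For concreteness, here is one way to read the multiplication question above (a hedged sketch under assumed tensor names, not an answer from the authors): if the factored heads jointly define a single action per step, then their joint target and behaviour probabilities are products of the per-head probabilities, with unused heads contributing a factor of 1, so a single importance-sampling ratio could be formed per step and fed into one v-trace pass for the single value head.

```python
import numpy as np

T, B, NUM_HEADS = 8, 4, 4

# Per-head probabilities of the sampled sub-actions under target and behaviour policies.
pi_target = np.random.uniform(0.1, 1.0, size=(NUM_HEADS, T, B))
pi_behaviour = np.random.uniform(0.1, 1.0, size=(NUM_HEADS, T, B))

# 1 where the head actually contributed to the action at (t, b), 0 where it was unused.
used = np.random.randint(0, 2, size=(NUM_HEADS, T, B)).astype(np.float64)

# Unused heads contribute a ratio of 1, so they drop out of the product.
per_head_ratio = np.where(used > 0, pi_target / pi_behaviour, 1.0)
joint_ratio = per_head_ratio.prod(axis=0)  # single [T, B] ratio

# `joint_ratio` would then be clipped and used as the rho/c weights in a single
# v-trace calculation, producing one value target for the single value head.
```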