interpreting-rl-behavior / interpreting-rl-behavior.github.io

Code for the site https://interpreting-rl-behavior.github.io/
Creative Commons Attribution 4.0 International

Grads for hx in the panel seem to extend up to the 12th timestep, whereas we only calculate them from the 11th timestep #28

Closed (leesharkey closed this 3 years ago)

leesharkey commented 3 years ago

Posting this note from Slack by Nicholas Goldowsky-Dill (6:49 PM):

    import numpy as np
    a = np.load("train-procgen-pytorch/generative/recorded_informinit_gen_samples/sample_00000/grad_hx_action.npy")
    a[12, 0]

returns a non-zero value, so I think the problem is upstream of the code that generates that .npy file.
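A small helper makes this check repeatable across samples and gradient files. It is a hypothetical sketch (`last_nonzero_timestep` is not a function in the repo) and assumes, as the `a[12, 0]` indexing suggests, that the first axis of the saved gradient array is the timestep.

```python
import numpy as np


def last_nonzero_timestep(grad: np.ndarray, atol: float = 0.0) -> int:
    """Return the last index along axis 0 whose slice has any |value| > atol.

    Assumes axis 0 is time (hypothetical layout). Returns -1 if every
    entry is (near) zero.
    """
    # Collapse all non-time axes so there is one boolean per timestep.
    per_step = np.abs(grad).reshape(grad.shape[0], -1) > atol
    nonzero_steps = np.flatnonzero(per_step.any(axis=1))
    return int(nonzero_steps[-1]) if nonzero_steps.size else -1


# Usage on a recorded file (path from the note above):
# a = np.load("train-procgen-pytorch/generative/recorded_informinit_gen_samples/sample_00000/grad_hx_action.npy")
# print(last_nonzero_timestep(a))
```

If the gradients were only taken from the 11th timestep, anything above 11 here points to the bug discussed in this issue.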

leesharkey commented 3 years ago

In saliency_exps.py, where the gradients are calculated, the 12th timestep (0-indexed) for 'obs' has zero gradient but the 12th timestep for 'hx' has a non-zero gradient.

You can see this from the grad_dict that is returned by the function forward_backward_pass on line 291.

This indicates that the problem is somehow in the model itself. I'm not even sure how that's possible, which is worrying; it's a potential bug. @danbraunai, let's try to identify the cause.

leesharkey commented 3 years ago

I've noticed that my comment above applies when computing saliency for value but not for hx_direction. For hx_direction, the gradients for hx and obs are as you'd expect: the last non-zero gradient is on the 11th timestep (0-indexed) for both. It's as though, when we take the gradient of the value, we're actually taking it through the value from one timestep ahead.

_(Secondary note: I wonder if this is connected to something I noticed during interpretation: there seemed to be very few PCs that had large increases before jumps. There were many that had large increases immediately after jumps. This was unexpected because I'd have thought that the hx neurons would have large changes before jumps in order to have precise control of jump initiation. I'm not sure it's a bug. It may well be that most directions really do just change massively after jumps (e.g. in response to large accompanying visual changes), but it'd be a surprising find if that's the case. I suspect that if there's a bug here we should follow that up next.)_

leesharkey commented 3 years ago

All right, so I think I've found the issue. I haven't implemented a fix yet.

Action and value are calculated in train-procgen-pytorch/common/policy.py. The hidden state that is fed to the agent is not the hidden state used to produce the value and action: it is first passed through the GRU, together with the embedded observation, and the resulting next hidden state is what the policy and value heads read.

    def forward(self, x, hx, masks, retain_grads=False):
        hidden = self.embedder(x)
        if self.recurrent:
            # Fill in init hx to get grads right (it's a hacky solution to use
            #  trainable initial hidden states, but it's hard to get it to work
            #  with this Kostrikov repo since it uses numpy so much).
            if not retain_grads:
                inithx_mask = [torch.all(hx[i] == self.init_hx) for i in
                               range(hx.shape[0])]
                hx[inithx_mask] = self.init_hx
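            # The GRU consumes the embedded obs together with the incoming hx
            # and returns the updated state, so `hidden` (and the new hx)
            # already reflect the current observation.
            hidden, hx = self.gru(hidden, hx, masks)
        # The policy and value heads read the post-GRU state, not the hx that
        # was passed into forward(): a_t and v_t come from the next hidden state.
        logits = self.fc_policy(hidden)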
        log_probs = F.log_softmax(logits, dim=1)
        p = Categorical(logits=log_probs)
        v = self.fc_value(hidden).reshape(-1)
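A toy sketch of why reading the value off the post-GRU state shifts the saliency by one step. It swaps the GRU for a linear recurrence (an assumption, made only so the gradients can be written out by hand) and uses the indexing h[0] = initial state, h[t+1] = cell(obs[t], h[t]), value[t] = c · h[t+1], which mirrors `forward` above passing the GRU output `hidden` to `fc_value`. `value_saliency`, `last_nonzero` and all shapes are hypothetical, not code from the repo.

```python
import numpy as np


def value_saliency(A, B, c, T, num_steps):
    """Gradients of value[T] w.r.t. every obs[t] and every hidden state h[k].

    Toy linear RNN: h[k+1] = A @ h[k] + B @ obs[k], value[t] = c @ h[t + 1].
    The recurrence is linear, so the gradients do not depend on the actual
    obs/h values. h has one more entry than obs (it includes h[0]).
    """
    assert 0 <= T < num_steps
    hidden_dim, obs_dim = B.shape
    grad_obs = np.zeros((num_steps, obs_dim))
    grad_h = np.zeros((num_steps + 1, hidden_dim))
    # value[T] reads h[T + 1] directly, so d value / d h[T+1] = c.
    grad_h[T + 1] = c
    # Backpropagate through h[k+1] = A h[k] + B obs[k] for k = T, ..., 0.
    for k in range(T, -1, -1):
        grad_obs[k] = B.T @ grad_h[k + 1]
        grad_h[k] = A.T @ grad_h[k + 1]
    return grad_obs, grad_h


def last_nonzero(grad):
    """Last timestep (axis 0) with any non-zero gradient entry."""
    return int(np.flatnonzero(np.abs(grad).sum(axis=1))[-1])


rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
B = rng.normal(size=(4, 3))
c = rng.normal(size=4)

g_obs, g_h = value_saliency(A, B, c, T=11, num_steps=16)
print(last_nonzero(g_obs), last_nonzero(g_h))  # 11 12

# Shifting the target back one step, to value[T - 1]:
_, g_h_shifted = value_saliency(A, B, c, T=10, num_steps=16)
print(last_nonzero(g_h_shifted))  # 11
```

With this indexing, the gradient of value[T] stops at obs[T] but reaches h[T + 1], the same 11-vs-12 pattern reported in this issue; targeting value[T - 1] instead makes the h-gradient stop at h[T].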

This might be causing issues elsewhere. I need to check whether hx_t, value_t, action_t, obs_t, etc. are all aligned to the same t when training the generative model. As a short-term fix for value and action saliency, I think the solution is to take the gradient from timesteps - 1 instead of timesteps. Will do later today.
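One way to run the alignment check mentioned above is to ask which index offset reproduces the recorded values from the recorded hidden states. This is a hypothetical sketch: `infer_value_offset`, the linear head `(w, b)` standing in for `fc_value`, and the array layout (one row of hx per step, one value per step) are assumptions, not the repo's actual record format.

```python
import numpy as np


def infer_value_offset(hx, values, w, b, offsets=(0, 1), atol=1e-5):
    """Return the offset k such that values[t] == w @ hx[t + k] + b for every t.

    hx: (num_hx, hidden_dim) recorded hidden states (assumed layout).
    values: (num_values,) recorded values (assumed layout).
    (w, b): a linear value head standing in for fc_value (an assumption).
    Returns None if no candidate offset matches.
    """
    for k in offsets:
        # Compare only the steps for which both hx[t + k] and values[t] exist.
        n = min(len(values), len(hx) - k)
        if n <= 0:
            continue
        predicted = hx[k:k + n] @ w + b
        if np.allclose(predicted, values[:n], atol=atol):
            return k
    return None


rng = np.random.default_rng(0)
hx = rng.normal(size=(9, 4))
w, b = rng.normal(size=4), 0.1
values = hx[1:] @ w + b  # toy data recorded with a one-step shift
print(infer_value_offset(hx, values, w, b))  # 1
```

An offset of 1 would mean value[t] was produced from hx[t + 1], the same one-step shift seen in `forward` above.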

leesharkey commented 3 years ago

All right, I've checked record.py and generative_models.py, and this discrepancy shouldn't be harming generative model training, though I'm not 100% confident of that yet. I think it's harmless because the offset is consistent everywhere. It's still annoying to find this so late: why wouldn't every agent want to produce v_0 and a_0 from its h_0?

Anyway, the fix I've implemented in saliency_exps.py is the one I described in the comment above.