We normalize the gradients when calculating the Cs, but don't yet do it for the Edges.
This is the current code:
# Calculate the square and sum over the pos dimension if it exists.
f_out_hat_norm: Float[Tensor, "... out_hidden_combined_trunc"] = f_out_hat**2
if has_pos:
    # f_out_hat is shape (batch, pos, hidden)
    assert f_out_hat.dim() == 3, f"f_out_hat should have 3 dims, got {f_out_hat.dim()}"
    f_out_hat_norm = f_out_hat_norm.sum(dim=1)
# Sum over the batch dimension
f_out_hat_norm = f_out_hat_norm.sum(dim=0)
Instead, I think we should take the mean over these dimensions, so that the result does not scale with the batch size or the number of positions.
First, investigate the sizes before and after this adjustment and see whether it has a stabilising effect.
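A rough sketch of how the two reductions could be compared side by side. The helper `squared_norm_reduce` and its `use_mean` flag are hypothetical names made up for this illustration, not existing code, and the input is random data rather than real gradients:

```python
import torch
from torch import Tensor


def squared_norm_reduce(f_out_hat: Tensor, has_pos: bool, use_mean: bool) -> Tensor:
    """Square f_out_hat, then reduce over the pos (if present) and batch dimensions.

    Hypothetical helper: use_mean=False mirrors the current summing code,
    use_mean=True is the proposed mean-based variant.
    """
    f_out_hat_norm = f_out_hat**2
    if has_pos:
        # f_out_hat is shape (batch, pos, hidden)
        assert f_out_hat.dim() == 3, f"f_out_hat should have 3 dims, got {f_out_hat.dim()}"
        if use_mean:
            f_out_hat_norm = f_out_hat_norm.mean(dim=1)
        else:
            f_out_hat_norm = f_out_hat_norm.sum(dim=1)
    # Reduce over the batch dimension
    if use_mean:
        return f_out_hat_norm.mean(dim=0)
    return f_out_hat_norm.sum(dim=0)


# Random stand-in for f_out_hat with shape (batch, pos, hidden)
torch.manual_seed(0)
batch, pos, hidden = 8, 16, 4
f_out_hat = torch.randn(batch, pos, hidden)

summed = squared_norm_reduce(f_out_hat, has_pos=True, use_mean=False)
averaged = squared_norm_reduce(f_out_hat, has_pos=True, use_mean=True)

print("sum-reduced: ", summed)
print("mean-reduced:", averaged)
```

For a fixed input shape the two results differ only by the constant factor `batch * pos`, so any stabilising effect would show up when the batch size or sequence length varies between runs.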