The following assertion failed at some point during training:
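# Per-sample gradients (param.grad1), flattened so that each row is one sample
hack_grads = torch.cat([param.grad1.flatten(start_dim=1) for param in model.parameters()
                        if hasattr(param, 'grad1')], dim=1)
# Aggregated gradients (param.grad) for the same parameters, flattened to one vector
grads = torch.cat([param.grad.flatten() for param in model.parameters()
                   if hasattr(param, 'grad1')], dim=0)
# The mean of the per-sample gradients should match the aggregated gradient
assert torch.allclose(torch.mean(hack_grads, dim=0), grads)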
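To narrow down why the check fails, it may help to look at how large the mismatch actually is instead of only getting a pass/fail result. The sketch below is a hypothetical helper (`report_mismatch` is not part of PyTorch or of the code above) that applies the same elementwise criterion as `torch.allclose`, whose default tolerances are `rtol=1e-5` and `atol=1e-8`. Whether the failure here comes from float32 rounding, NaNs, or a real bug is not known from the post; this only helps tell those cases apart.

```python
import torch


def report_mismatch(a: torch.Tensor, b: torch.Tensor,
                    rtol: float = 1e-5, atol: float = 1e-8) -> dict:
    """Hypothetical diagnostic: summarize how far `a` is from `b`."""
    diff = (a - b).abs()
    # torch.allclose requires |a - b| <= atol + rtol * |b| for every element
    allowed = atol + rtol * b.abs()
    return {
        # Largest absolute deviation between the two tensors
        "max_abs_diff": diff.max().item(),
        # Number of elements that exceed the allclose tolerance
        "num_violations": int((diff > allowed).sum().item()),
        # NaNs make allclose fail by default but are not counted above,
        # because comparisons with NaN evaluate to False
        "has_nan": bool(torch.isnan(a).any() or torch.isnan(b).any()),
    }
```

With the tensors from the snippet above, this could be called as `report_mismatch(torch.mean(hack_grads, dim=0), grads)`. A tiny `max_abs_diff` with a handful of violations would point toward tolerance issues, while `has_nan` being true or a large difference would suggest something else is going wrong.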