f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Better error messages for BatchNorm #240

Open fKunstner opened 2 years ago

fKunstner commented 2 years ago

Makes the BatchNorm error message more explicit to avoid confusion (https://github.com/f-dangel/backpack/issues/239) and adds an option to ignore the exception.

Summary of changes:

f-dangel commented 2 years ago

Hey Fred, I skimmed through your changes:

  1. The failing test checks whether the result of batch_grad for a BN layer in train mode sums to grad. Working with batch_grad is 'okay' in this case because it is not interpreted as per-sample gradients. We could either revert the default first-order fail mode, or adapt the test to use BatchGrad(fail_mode="WARNING"). I currently favor reverting the default (as this also does not trigger a version bump, and it fixes 2.).
  2. The RTD example with the custom ResNet fails for the same reason as 1.
  3. Can you pip install --upgrade && make black to update the formatting?
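To make point 1 concrete, here is a minimal, framework-free sketch in plain Python (no BackPACK or PyTorch). Assumptions, all hypothetical: a scalar toy layer z_n = tanh(w * x_n), followed by train-mode BatchNorm without affine parameters and a squared loss, with derivatives taken by central finite differences. `backprop_contributions` mimics how a per-sample backward pass splits the gradient at the layer, (dL/dz_n) * (dz_n/dw) with L the summed loss; `individual_gradients` computes the true d(loss_n)/dw. Both sum to the mini-batch gradient, but they differ sample by sample because the batch statistics couple the samples.

```python
import math

EPS = 1e-5  # BatchNorm epsilon (assumed; matches PyTorch's default)


def batchnorm(zs):
    """Train-mode BatchNorm without affine parameters: normalize with batch statistics."""
    n = len(zs)
    mean = sum(zs) / n
    var = sum((z - mean) ** 2 for z in zs) / n  # biased variance, as in train mode
    return [(z - mean) / math.sqrt(var + EPS) for z in zs]


def per_sample_losses(zs, targets):
    """Squared loss per sample, computed after normalizing the whole batch."""
    ys = batchnorm(zs)
    return [0.5 * (y - t) ** 2 for y, t in zip(ys, targets)]


def forward(w, xs):
    """Hypothetical toy layer in front of BatchNorm: z_n = tanh(w * x_n)."""
    return [math.tanh(w * x) for x in xs]


def backprop_contributions(w, xs, targets, h=1e-6):
    """Per-sample split produced by backprop: (dL/dz_n) * (dz_n/dw), L = summed loss."""
    zs = forward(w, xs)
    contribs = []
    for n, x in enumerate(xs):
        up, down = list(zs), list(zs)
        up[n] += h
        down[n] -= h
        dL_dz = (sum(per_sample_losses(up, targets))
                 - sum(per_sample_losses(down, targets))) / (2 * h)
        dz_dw = x * (1 - math.tanh(w * x) ** 2)  # analytic derivative of tanh(w * x)
        contribs.append(dL_dz * dz_dw)
    return contribs


def individual_gradients(w, xs, targets, h=1e-6):
    """True d(loss_n)/dw: perturbing w also changes the batch statistics."""
    up = per_sample_losses(forward(w + h, xs), targets)
    down = per_sample_losses(forward(w - h, xs), targets)
    return [(u - d) / (2 * h) for u, d in zip(up, down)]


def full_gradient(w, xs, targets, h=1e-6):
    """Mini-batch gradient dL/dw of the summed loss."""
    up = sum(per_sample_losses(forward(w + h, xs), targets))
    down = sum(per_sample_losses(forward(w - h, xs), targets))
    return (up - down) / (2 * h)


if __name__ == "__main__":
    w, xs, targets = 0.7, [0.5, -1.0, 2.0, 0.3], [1.0, 0.0, -1.0, 0.5]
    c = backprop_contributions(w, xs, targets)
    g = individual_gradients(w, xs, targets)
    print("sum of contributions:", sum(c))
    print("sum of individual gradients:", sum(g))
    print("mini-batch gradient:", full_gradient(w, xs, targets))
    print("per-sample contributions:", c)
    print("individual gradients:", g)
```

Both sums agree with the mini-batch gradient up to finite-difference error, while the per-sample entries do not match, which is why summing batch_grad to grad is a valid test even though the entries should not be read as individual gradients.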

Happy to review or discuss!

fKunstner commented 2 years ago

Thanks for the check!

I'd lean more towards crashing than warning, but to get to something we can both 👍, how about starting from this setup:


The failing test checks if the result in batch_grad for a BN layer in train mode sums to grad. Working with batch_grad is 'okay' in this case because it's not interpreted as per-sample gradients

I don't follow the "batch_grad is okay" part. Do you mean in the context of the tests? If so, I agree that BatchGrad should sum to Grad with or without BatchNorm. But I don't think this should be the default behavior of the user-facing API. Someone calling batch_grad expects individual gradients and should get an error (maybe a strong warning works as well).
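The behavior argued for here (an error by default, with an opt-in warning or silence) could look roughly like the check below. This is a hypothetical sketch, not BackPACK's implementation: `check_batchnorm`, `BatchNormTrainModeError`, and the message text are invented for illustration. Only the "WARNING" mode echoes the fail_mode value mentioned in this thread; "ERROR" and "SILENT" are assumed names.

```python
import warnings

# Hypothetical mode names; only "WARNING" appears in this discussion.
FAIL_ERROR, FAIL_WARN, FAIL_SILENT = "ERROR", "WARNING", "SILENT"

MESSAGE = (
    "Encountered BatchNorm in train mode: samples in a mini-batch are coupled "
    "through the batch statistics, so the per-sample quantities are not "
    "individual gradients (they still sum to the mini-batch gradient)."
)


class BatchNormTrainModeError(RuntimeError):
    """Hypothetical exception for ill-defined per-sample quantities."""


def check_batchnorm(training: bool, fail_mode: str = FAIL_ERROR) -> None:
    """Raise, warn, or stay silent when a BatchNorm layer is in train mode."""
    if fail_mode not in (FAIL_ERROR, FAIL_WARN, FAIL_SILENT):
        raise ValueError(f"Unknown fail_mode: {fail_mode!r}")
    if not training:
        return  # eval mode uses running statistics, so samples stay independent
    if fail_mode == FAIL_ERROR:
        raise BatchNormTrainModeError(MESSAGE)
    if fail_mode == FAIL_WARN:
        warnings.warn(MESSAGE, UserWarning)
    # FAIL_SILENT: the caller explicitly opted out of the check.
```

Defaulting to the error means a user asking for per-sample gradients cannot silently get mixed quantities, while tests that only check the sum can opt into "WARNING" or "SILENT".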

Can you pip install --upgrade && make black to update the formatting?

The files that black complains about are not part of this PR (?). I'll merge main into it again.