f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Request: Extension for GroupNorm #328

Open ParthS007 opened 1 month ago

ParthS007 commented 1 month ago

I am re-implementing DP-SGD enhanced with random sparsification of gradients on my UNet model.

Here is the debug output of the model after calling `extend(model)`:

UNet(
  (encoder1): Sequential(
    (enc1conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
    (enc1relu1): ReLU()
    (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
    (enc1relu2): ReLU()
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  ....

  (bottleneck): Sequential(
    (bottleneckconv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bottlenecknorm1): GroupNorm(32, 512, eps=1e-05, affine=True)
    (bottleneckrelu1): ReLU()
    (bottleneckconv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bottlenecknorm2): GroupNorm(32, 512, eps=1e-05, affine=True)
    (bottleneckrelu2): ReLU()
  )
  (upconv4): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
  (decoder4): Sequential(
    (dec4conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec4norm1): GroupNorm(32, 256, eps=1e-05, affine=True)
    (dec4relu1): ReLU()
    (dec4conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (dec4norm2): GroupNorm(32, 256, eps=1e-05, affine=True)
    (dec4relu2): ReLU()
  )
  ...
)

The BackPACK library does not support some of the modules in the model, specifically **GroupNorm**.

Should I create custom extensions for the unsupported modules?

Logs when training the model:


env/lib/python3.11/site-packages/backpack/extensions/backprop_extension.py:106: UserWarning: Extension saving to grad_batch does not have an extension for Module <class 'networks.UNet'> although the module has parameters
  warnings.warn(
env/lib/python3.11/site-packages/backpack/extensions/backprop_extension.py:106: UserWarning: Extension saving to grad_batch does not have an extension for Module <class 'torch.nn.modules.normalization.GroupNorm'> although the module has parameters

Thanks for the help :)

f-dangel commented 1 month ago

Hi Parth,

I assume from your logs that you would like to support BatchGrad for nn.GroupNorm. You can follow the instructions here to achieve that. A PR adding this to BackPACK would be really cool, too :)

Best, Felix