f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

[ADD] backpack: add option grad_exists_behaviour #218

Open schaefertim opened 3 years ago

schaefertim commented 3 years ago

Currently, BackPACK quantities already stored on the parameters are silently overwritten by the next backward pass. I suggest making this process explicit and adding alternative options.

This draft adds an option grad_exists_behaviour to backpack(). Possible values are overwrite, error, sum, and append. Default: overwrite (same as the current behaviour).

The error option is useful for making sure users don't accidentally overwrite a quantity where a different behaviour was expected. I would really like to make this the default option, but there are two issues.
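A minimal sketch of how the error option could catch an accidental overwrite. The keyword and its values are the ones proposed in this draft; whether an exception or a warning is raised, and its exact type, is an implementation detail. forward and x1 are placeholders for the user's model evaluation, as in the example further below.

from backpack import backpack
from backpack.extensions import BatchGrad

with backpack(BatchGrad(), grad_exists_behaviour='error'):
    loss = forward(x1)
    loss.backward()  # stores grad_batch on the parameters

with backpack(BatchGrad(), grad_exists_behaviour='error'):
    loss = forward(x1)
    loss.backward()  # grad_batch already exists -> fail loudly instead of silently overwriting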

The sum and append options are useful if more than one batch is passed through. They define how results from the additional batches are accumulated.

The append option is useful, for example, for extensions that compute per-sample results:

from backpack import backpack
from backpack.extensions import BatchGrad

# first batch: stores per-sample gradients (grad_batch) on the parameters
with backpack(BatchGrad()):
    loss = forward(x1)
    loss.backward()

# second batch: appends to the stored results instead of overwriting them
with backpack(BatchGrad(), grad_exists_behaviour='append'):
    loss = forward(x2)
    loss.backward()

The sum option is useful in similar cases where no per-sample quantities are kept; the stored value is simply accumulated across backward passes (this is how PyTorch's .grad behaves).
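A hedged sketch of what sum could look like for a quantity without a per-sample axis, here BackPACK's SumGradSquared extension. The accumulation semantics shown in the comments are the proposal's, not existing behaviour; forward, x1, and x2 are placeholders as above.

from backpack import backpack
from backpack.extensions import SumGradSquared

# first batch writes sum_grad_squared to the parameters
with backpack(SumGradSquared()):
    loss = forward(x1)
    loss.backward()

# second batch adds to the stored values instead of overwriting them,
# analogous to how PyTorch accumulates .grad across backward calls
with backpack(SumGradSquared(), grad_exists_behaviour='sum'):
    loss = forward(x2)
    loss.backward()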

Additionally, if BackPACK is to support loops over a single module in the future, it will require the sum option in some form.
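For illustration, a weight-sharing setup (the model here is hypothetical) where plain PyTorch already sums the gradient contributions from both uses of the module. A BackPACK extension hooked into such a module fires once per use, so its stored quantity would need the analogous sum behaviour.

import torch

linear = torch.nn.Linear(4, 4)
x = torch.randn(8, 4)

# the same module is applied twice in one forward pass
out = linear(linear(x))
out.sum().backward()

# .grad holds the sum of the contributions from both applications;
# a per-module BackPACK quantity would need grad_exists_behaviour='sum'
# to accumulate its results the same way
print(linear.weight.grad.shape)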