f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Support for Custom models? #269

Open anegi19 opened 1 year ago

anegi19 commented 1 year ago

Hi, I cannot find this in the package docs: how is extend(model) actually implemented so that it computes the higher-order extensions for the model's parameters? Basically, I want to calculate the Generalised Gauss-Newton (GGN) matrix with BackPACK for a model that doesn't use torch.nn layers but a custom forward operation instead. Is that possible?

For example:

import torch
from torch import nn
from backpack import backpack, extend, extensions


class MyModel(nn.Module):
    """
    Define the forward model of the experiment.

    Args
    ----
    X (Tensor), Y (Tensor): initial values of the trainable parameters

    Returns
    -------
    Z : Tensor
    """

    def __init__(self, X, Y):
        super().__init__()
        self.X = nn.Parameter(X)
        self.Y = nn.Parameter(Y)

    def forward(self, X_0):
        # Custom elementwise operation on X and Y (no torch.nn layers involved)
        return (self.X - X_0) * self.Y

and then:

x = torch.tensor([1., 2., 3.])
y = torch.tensor([11., 22., 33.])

x0 = torch.tensor([0.5])

inputs = torch.tensor([10., 18., 13.])  # regression targets for the MSE loss

model = MyModel(x, y)
model = extend(model)
cost_function = extend(torch.nn.MSELoss())

preds = model(x0)
cost = cost_function(preds, inputs)
with backpack(extensions.GGNMP()):
    cost.backward()

It doesn't work; I get the following error:

Extension saving to ggnmp does not have an extension for Module <class '__main__.MyModel'>

Should I modify extend(model) somehow? Another question: does BackPACK support complex tensors when computing these higher-order extensions?

f-dangel commented 1 year ago

Hi, thanks for your question!

In principle, you can add support for your custom layer to BackPACK. We have an example in the documentation that walks you through the process.
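Independently of BackPACK, the GGN of a small custom model like the one in the question can be computed directly as G = Jᵀ H J, where J is the Jacobian of the model output with respect to all parameters and H is the Hessian of the loss with respect to the model output. The sketch below uses only `torch.autograd.functional.jacobian` and `hessian`; it is not BackPACK's mechanism, it builds the full dense matrix (so it only scales to small parameter counts), and the helper `ggn_matrix` and the stand-alone `forward` function are hypothetical names introduced here. It can serve as a reference to check a custom BackPACK extension against.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import hessian, jacobian


def ggn_matrix(forward, loss, params, target):
    """Hypothetical helper: dense GGN = J^T H J for a small model.

    J is the Jacobian of the model output w.r.t. all parameters (flattened and
    concatenated); H is the Hessian of the loss w.r.t. the model output.
    """
    shapes = [p.shape for p in params]
    sizes = [p.numel() for p in params]
    theta = torch.cat([p.detach().reshape(-1) for p in params])

    def output_of_theta(flat):
        # Unflatten the parameter vector and run the custom forward pass
        chunks = torch.split(flat, sizes)
        return forward(*[c.reshape(s) for c, s in zip(chunks, shapes)]).reshape(-1)

    out = output_of_theta(theta).detach()
    J = jacobian(output_of_theta, theta)  # shape: (num_outputs, num_params)
    H = hessian(lambda o: loss(o, target.reshape(-1)), out)  # (num_outputs, num_outputs)
    return J.T @ H @ J


# Same setup as in the question, written as a plain function of (X, Y)
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([11.0, 22.0, 33.0])
x0 = torch.tensor([0.5])
targets = torch.tensor([10.0, 18.0, 13.0])


def forward(X, Y):
    return (X - x0) * Y


G = ggn_matrix(forward, F.mse_loss, [x, y], targets)
print(G.shape)
```

For this model, the MSE loss with mean reduction over N = 3 outputs has H = (2/3) I, and J = [diag(Y), diag(X − x0)], so G is the 6×6 matrix (2/3) Jᵀ J.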

Let me know if you run into issues.

Best, Felix