f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Extend backpack to deal with weighted sums #321

Open rm-wu opened 4 months ago

rm-wu commented 4 months ago

Hi, I would like to apply DiagGGNExact to a custom module that does not have any parameters. I attached a minimal example that reproduces the scenario I'm after. The model produces a set of weights that are used to combine (via a weighted sum) the downstream predictions. However, when extending the nn.Module in charge of computing this weighted sum, I get an error because this module must also be extended for second-order operations. Can anyone help me with extending this simple module?

Here is a minimal code example of what I'm trying to achieve:

import torch
from backpack import backpack, extend, extensions

torch.manual_seed(0)

class SumModule(torch.nn.Module):
    def forward(self, x, w):
        return torch.sum(x * w, dim=-2)

B = 5    # batch size
S = 10  # number of intermediate elements
I = 16   # input size
O = 4   # output size

x = torch.randn((B, S, I)) # some simple inputs

y = torch.ones((B, O)) * torch.arange(B)[..., None] 
# some outputs, note the shape is not dependent on the intermediate outputs S

# base_model produces both the weight and an intermediate embedding for each input
base_model = torch.nn.Sequential(
    torch.nn.Linear(I, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 16)
)

# head_model produces the final prediction from the intermediate embedding
head_model = torch.nn.Sequential(
    torch.nn.Linear(15, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, O)
)

base_model = extend(base_model)
head_model = extend(head_model)
sum_module = extend(SumModule())

loss_fun = torch.nn.MSELoss()

with backpack(extensions.DiagGGNExact()):
    x = x.reshape(-1, I)
    intermediate_output = base_model(x)
    # split the base_model output into weights w and intermediate embedding h
    w, h = torch.split(
            intermediate_output, 
            [1, 15], dim=-1
        )
    w = w.reshape(B, S, 1) # weights
    pred = head_model(h).reshape(B, S, O) # predictions
    pred_y = sum_module(pred, w)
    loss = loss_fun(y, pred_y)
    loss.backward()

Here is the error I get when running this script:

NotImplementedError: Extension saving to diag_ggn_exact does not have an extension for Module <class '__main__.SumModule'>
f-dangel commented 4 months ago

Hi, thanks for your question!

Could you let me know w.r.t. which parameters you would like to compute the GGN diagonal? Is it only the head_model's parameters, or also the base_model's?

rm-wu commented 4 months ago

Both head_model and base_model parameters

f-dangel commented 4 months ago

In that case you will have to write a module extension of DiagGGNExact for your SumModule, as well as a module that performs the torch.split. This is because the GGN diagonal is a second-order extension, meaning that your computation graph must consist entirely of nn.Modules for which BackPACK knows how to backpropagate the information needed for the GGN diagonal.

There is work on documenting how to write module extensions for new layers in #320, but I haven't had time to review and merge it yet. You could take a look at it and start from there. The steps would be:

  1. Write a SplitModule layer such that x, w = split_module(intermediate_output)
  2. Write a DiagGGNExactSplitModule extension which specifies how to backpropagate information through a SplitModule when computing the GGN diagonal (see #320 and the skeleton sketch below)
  3. Repeat step 2, but for your SumModule.
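
To make the mechanics of steps 2 and 3 concrete, below is a minimal, runnable sketch for a much simpler parameter-free layer. The names ScaleByTwo and DiagGGNExactScaleByTwo are hypothetical, and the real SplitModule/SumModule extensions would have to apply their actual transposed Jacobians to the backpropagated quantity instead of a constant scaling:

import torch
from torch import nn
from backpack import backpack, extend, extensions
from backpack.extensions.module_extension import ModuleExtension

class ScaleByTwo(nn.Module):
    """Hypothetical parameter-free layer: y = 2 * x."""
    def forward(self, x):
        return 2.0 * x

class DiagGGNExactScaleByTwo(ModuleExtension):
    def backpropagate(self, extension, module, g_inp, g_out, bpQuantities):
        # bpQuantities is the symmetric factorization ("square root") of the GGN
        # w.r.t. the module output; applying the transposed Jacobian of y = 2 * x
        # simply scales it by 2.
        return 2.0 * bpQuantities

model = extend(nn.Sequential(nn.Linear(4, 3), ScaleByTwo(), nn.Linear(3, 2)))
loss_fun = extend(nn.MSELoss())

ext = extensions.DiagGGNExact()
ext.set_module_extension(ScaleByTwo, DiagGGNExactScaleByTwo())

X, y = torch.rand(8, 4), torch.rand(8, 2)
with backpack(ext):
    loss = loss_fun(model(X), y)
    loss.backward()

for p in model.parameters():
    print(p.diag_ggn_exact.shape)  # GGN diagonal, same shape as the parameter

For a SumModule with two inputs, backpropagate would receive the quantity for the module output and have to return one quantity per input (which is what the multiple-inputs branch from #306 is about), each obtained by applying the transposed Jacobian with respect to that input.
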
rm-wu commented 4 months ago

Given that my SumModule has to deal with two inputs rather than a single one, should I also fork from the multiple-inputs branch referred to in issue #306?

f-dangel commented 4 months ago

Yes, that sounds right.

f-dangel commented 4 months ago

After some more thinking, I believe the split can be done using BackPACK's custom Slicing module, which already supports DiagGGNExact. This should fix one problem and leave you with step 3.
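
A rough sketch of how Slicing could replace the torch.split (assuming the flattened output with 16 feature columns; the slice tuple contains one entry per axis of the input):

import torch
from backpack import extend
from backpack.custom_module.slicing import Slicing

out = torch.rand(50, 16)  # stand-in for the flattened base_model output

slice_w = extend(Slicing((slice(None), slice(0, 1))))   # columns 0:1  -> weights w
slice_h = extend(Slicing((slice(None), slice(1, 16))))  # columns 1:16 -> embedding h

w, h = slice_w(out), slice_h(out)  # shapes [50, 1] and [50, 15]
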

rm-wu commented 4 months ago

I have updated the previous example with my implementation of the weighted SumModule. I'm not entirely sure it is correct, in particular because I found it a bit confusing how MSELossDerivatives expands the decomposition of the Hessian of the loss. Is there any particular reason it is done this way instead of backpropagating a [B, C, C] matrix?

Also, as you suggested, I tried to use the Slicing module to perform the torch.split operation; however, I got this error:

ValueError: Slicing the batch axis is not supported.

I think I may be using it the wrong way, but I haven't found anything in the documentation, and from looking into the code it seems I should specify the whole shape (batch axis included) as slice_info for the module. Do you have any suggestions on how to fix this?

Here is the updated version of the previous example, which is failing:

from typing import Any, List, Tuple
import torch
from torch import nn
from backpack import BackpropExtension, backpack, extend, extensions
from backpack.extensions.module_extension import ModuleExtension
from torch.nn import Module
from torch import Tensor
import einops
from backpack.custom_module.slicing import Slicing

torch.manual_seed(0)

class SumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w):
        return torch.sum(x * w, dim=-2)

class DiagGGNSumModule(ModuleExtension):
    def backpropagate(
            self, 
            extension: BackpropExtension, 
            module: nn.Module, 
            g_inp: Tuple[Tensor], 
            g_out: Tuple[Tensor], 
            bpQuantities: Any) -> Any:
        inputs = self.get_inputs(module)
        x = inputs[0]
        w = inputs[1]
        sqrt_ggn = bpQuantities

        # J_w  = x.T 
        JwTS = torch.einsum("bic, cbk -> kbi", x, sqrt_ggn)
        JwTS = einops.rearrange(JwTS, "k b s -> k (b s) 1")

        # J_c = [w .. w]
        JxT = einops.repeat(w, "b i 1 -> b i c1 c2", c1=x.shape[-1], c2=x.shape[-1])
        JxTS = torch.einsum("bsjc, cbk -> kbsj", JxT, sqrt_ggn)
        JxTS = einops.rearrange(JxTS, "k b s j -> k (b s) j")

        return tuple([JxTS, JwTS])

    @staticmethod
    def get_inputs(module: nn.Module) -> List[Tensor]:
        """Get all inputs of ``MultiplyModule``'s forward pass."""
        layer_inputs = []

        i = 0
        while hasattr(module, f"input{i}"):
            layer_inputs.append(getattr(module, f"input{i}"))
            i += 1

        return layer_inputs

B = 5    # batch size
S = 10  # number of intermediate elements
I = 16   # input size
O = 4   # output size

sum_module = extend(SumModule())

x = torch.randn((B, S, I)) # some simple inputs

y = torch.ones((B, O)) * torch.arange(B)[..., None] 
# some outputs, note the shape is not dependent on the intermediate outputs S

# base_model produces both the weight and an intermediate embedding for each input
base_model = torch.nn.Sequential(
    torch.nn.Linear(I, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 16)
)

# head_model produces the final prediction from the intermediate embedding
head_model = torch.nn.Sequential(
    torch.nn.Linear(15, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, O)
)

base_model = extend(base_model)
head_model = extend(head_model)
sum_module = extend(SumModule())
slice_0 = extend(Slicing((slice(0, B*S), 1)))
slice_1_16 = extend(Slicing((slice(0, B*S), slice(1, 16))))

ext = extensions.DiagGGNExact()
ext.set_module_extension(SumModule,  DiagGGNSumModule())

loss_fun = extend(torch.nn.MSELoss())

with backpack(ext, debug=True):
    x = x.reshape(-1, I)
    intermediate_output = base_model(x)
    # split the base_model output into weights w and intermediate embedding h
    # w, h = torch.split(
    #         intermediate_output, 
    #         [1, 15], dim=-1
    #     )
    w = slice_0(intermediate_output)
    h = slice_1_16(intermediate_output)
    pred = head_model(h) # predictions

    w = w.reshape(B, S, 1) # weights
    pred = pred.reshape(B, S, O)

    pred_y = sum_module(pred, w)
    loss = loss_fun(y, pred_y)
    loss.backward()
f-dangel commented 4 months ago

Hi, thanks for the update.

rm-wu commented 4 months ago

Thanks for the answer! I used the slicing_info_* you suggested; however, it only works (does not crash) when I do the slicing in this order:

h = slice_1_16(intermediate_output)
w = slice_0(intermediate_output)

while if I swap these two lines I get a shape-mismatch error. I believe the issue comes from how the multiple-inputs branch from #306 handles the saved bpQuantities. Do you know if this could be the reason, and how to fix it in that case?

f-dangel commented 4 months ago

Hi, I don't really understand why the order of slicing should matter. Do you get the error in the forward pass?

rm-wu commented 4 months ago

No, I get the error during the backpropagation step

f-dangel commented 4 months ago

Could you post a minimal example that reproduces your error and append the traceback?

rm-wu commented 4 months ago

Of course!

from typing import Any, List, Tuple
import torch
from torch import nn
from backpack import BackpropExtension, backpack, extend, extensions
from backpack.extensions.module_extension import ModuleExtension
from torch.nn import Module
from torch import Tensor
import einops
from backpack.custom_module.slicing import Slicing

torch.manual_seed(0)

class SumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w):
        return torch.sum(x * w, dim=-2)

class DiagGGNSumModule(ModuleExtension):
    def backpropagate(
            self, 
            extension: BackpropExtension, 
            module: nn.Module, 
            g_inp: Tuple[Tensor], 
            g_out: Tuple[Tensor], 
            bpQuantities: Any) -> Any:
        inputs = self.get_inputs(module)
        x = inputs[0]
        w = inputs[1]
        sqrt_ggn = bpQuantities

        # J_w  = x.T 
        JwTS = torch.einsum("bic, cbk -> kbi", x, sqrt_ggn)
        JwTS = einops.rearrange(JwTS, "k b s -> k (b s) 1")

        # J_c = [w .. w]
        JxT = einops.repeat(w, "b i 1 -> b i c1 c2", c1=x.shape[-1], c2=x.shape[-1])
        JxTS = torch.einsum("bsjc, cbk -> kbsj", JxT, sqrt_ggn)
        JxTS = einops.rearrange(JxTS, "k b s j -> k (b s) j")

        return tuple([JxTS, JwTS])

    @staticmethod
    def get_inputs(module: nn.Module) -> List[Tensor]:
        """Get all inputs of ``MultiplyModule``'s forward pass."""
        layer_inputs = []

        i = 0
        while hasattr(module, f"input{i}"):
            layer_inputs.append(getattr(module, f"input{i}"))
            i += 1

        return layer_inputs

B = 5    # batch size
S = 10  # number of intermediate elements
I = 2   # input size
O = 2   # output size

sum_module = extend(SumModule())

x = torch.randn((B, S, I)) # some simple inputs

y = torch.ones((B, O)) * torch.arange(B)[..., None] 
# some outputs, note the shape is not dependent on the intermediate outputs S

# base_model produces both the weight and an intermediate embedding for each input
base_model = torch.nn.Sequential(
    torch.nn.Linear(I, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 16)
)

# head_model produces the final prediction from the intermediate embedding
head_model = torch.nn.Sequential(
    torch.nn.Linear(15, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, O)
)

base_model = extend(base_model)
head_model = extend(head_model)
sum_module = extend(SumModule())
slice_0 = extend(Slicing((slice(None), slice(0, 1))))
slice_1_16 = extend(Slicing((slice(None), slice(1, 16))))

ext = extensions.DiagGGNExact()
ext.set_module_extension(SumModule,  DiagGGNSumModule())

loss_fun = extend(torch.nn.MSELoss())

with backpack(ext, debug=True):
    x = x.reshape(-1, I)
    intermediate_output = base_model(x)
    # split the base_model output into weights w and intermediate embedding h

    w = slice_0(intermediate_output)
    h = slice_1_16(intermediate_output)

    pred = head_model(h) # predictions

    w = w.reshape(B, S, 1) # weights
    pred = pred.reshape(B, S, O)

    pred_y = sum_module(pred, w)
    loss = loss_fun(y, pred_y)
    loss.backward()

And here is the error message:

[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on MSELoss()
[DEBUG] Running extension hook on MSELoss()
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on SumModule()
[DEBUG] Running extension hook on SumModule()
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on Linear(in_features=64, out_features=2, bias=True)
[DEBUG] Running extension hook on Linear(in_features=64, out_features=2, bias=True)
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on ReLU()
[DEBUG] Running extension hook on ReLU()
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on Linear(in_features=15, out_features=64, bias=True)
[DEBUG] Running extension hook on Linear(in_features=15, out_features=64, bias=True)
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on Slicing()
[DEBUG] Running extension hook on Slicing()
[DEBUG] Running extension <backpack.extensions.secondorder.diag_ggn.DiagGGNExact object at 0x7f9f74d6d730> on Slicing()
Traceback (most recent call last):
  File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/projects/.vscode-server/extensions/ms-python.debugpy-2024.2.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/projects/scratch/minimal_example.py", line 111, in <module>
    loss.backward()
  File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/projects/miniconda3/envs/nerfbackpack/lib/python3.8/site-packages/torch/utils/hooks.py", line 137, in hook
    out = hook(self.module, res, self.grad_outputs)
  File "/home/projects/scratch/backpack/backpack/__init__.py", line 209, in hook_run_extensions
    backpack_extension(module, g_inp, g_out)
  File "/home/projects/scratch/backpack/backpack/extensions/backprop_extension.py", line 131, in __call__
    module_extension(self, module, g_inp, g_out)
  File "/home/projects/scratch/backpack/backpack/extensions/module_extension.py", line 125, in __call__
    bp_quantity = self.backpropagate(
  File "/home/projects/scratch/backpack/backpack/extensions/mat_to_mat_jac_base.py", line 55, in backpropagate
    return self.derivatives.jac_t_mat_prod(
  File "/home/projects/scratch/backpack/backpack/core/derivatives/shape_check.py", line 133, in _wrapped_mat_prod_accept_vectors
    mat_out = mat_prod(self, module, g_inp, g_out, mat_in, *args, **kwargs)
  File "/home/projects/scratch/backpack/backpack/core/derivatives/shape_check.py", line 189, in wrapped_mat_prod_check_shapes
    in_check(mat, module, *args, **kwargs)
  File "/home/projects/scratch/backpack/backpack/core/derivatives/shape_check.py", line 80, in _check_like
    return check_shape(mat, compare, diff=diff)
  File "/home/projects/scratch/backpack/backpack/core/derivatives/shape_check.py", line 49, in check_shape
    raise RuntimeError(
RuntimeError: ('Compared shapes [50, 16] and [50, 1] do not match. ', 'Got [2, 50, 16] and [50, 1]')
f-dangel commented 4 months ago

One thing you definitely have to fix is the current reshapes in the forward pass. These reshapes are functionals, but BackPACK's second-order extensions require all operations of the forward pass to be performed through nn.Modules. Otherwise, the backpropagation mechanism will break. I believe what you're currently seeing is that one of BackPACK's internal shape checks fails because you are modifying tensors with functionals, which are 'invisible' to BackPACK.
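
As an illustration of what 'performing the operation through an nn.Module' means, the reshapes could be wrapped in a module like the following sketch (ReshapeModule is a hypothetical wrapper, not an existing BackPACK module, and it would additionally need its own DiagGGNExact module extension, whose transposed Jacobian is simply the inverse reshape of the incoming quantity):

import torch
from torch import nn

class ReshapeModule(nn.Module):
    """Hypothetical wrapper that makes a reshape visible to BackPACK's hooks."""

    def __init__(self, *shape: int):
        super().__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.reshape(*self.shape)

# usage (hypothetical): reshape_pred = extend(ReshapeModule(B, S, O))
#                       pred = reshape_pred(head_model(h))
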

I propose simplifying your current example, because you are trying to solve two different problems:

  1. Adding support for your custom SumModule, which has multiple inputs
  2. Slicing the same tensor twice (accumulating backpropagated quantities)

Maybe you can start with the following simpler scenario, which does not suffer from aspect 2:

linear1 = Linear(...)
linear2 = Linear(...)
linear3 = Linear(...)
X1, X2 = rand(...), rand(...)

# (...) extend both

# both should already have correct shapes (no reshape)
w = linear1(X1)
h = linear2(X2)

pred = linear3(h)
pred_y = sum_module(pred, w)

loss = loss_fun(y, pred_y)
loss.backward()

Then check whether you get the correct GGN diagonals for the parameters of linear1, linear2, and linear3.
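
For that check, a brute-force reference for the exact GGN diagonal can be computed with plain autograd. This is a sketch under assumed shapes; the weighted sum is written functionally here because no BackPACK machinery is involved:

import torch
from torch import nn
from torch.autograd import grad

torch.manual_seed(0)

B, S, I1, I2, O = 5, 10, 3, 15, 4  # assumed sizes

linear1 = nn.Linear(I1, 1)   # produces the weights w
linear2 = nn.Linear(I2, 64)  # produces the embedding h
linear3 = nn.Linear(64, O)   # produces the per-element predictions

X1, X2 = torch.rand(B, S, I1), torch.rand(B, S, I2)
params = [p for m in (linear1, linear2, linear3) for p in m.parameters()]

def forward():
    w = linear1(X1)                     # [B, S, 1]
    pred = linear3(linear2(X2))         # [B, S, O]
    return torch.sum(pred * w, dim=-2)  # [B, O], the weighted sum

# For MSELoss(reduction="mean"), the Hessian of the loss w.r.t. the prediction is
# 2 / (B * O) times the identity, so the exact GGN diagonal is a scaled sum of the
# squared rows of the output Jacobian.
f = forward().flatten()
ggn_diag = [torch.zeros_like(p) for p in params]
for i in range(f.numel()):
    grads = grad(f[i], params, retain_graph=True)
    for d, g in zip(ggn_diag, grads):
        d += 2.0 / f.numel() * g ** 2

# ggn_diag[k] can then be compared elementwise with params[k].diag_ggn_exact once
# the custom SumModule extension is registered and BackPACK runs through.
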