f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Container modules with advanced control flow & modules with multiple inputs #306

Open m-lyon opened 1 year ago

m-lyon commented 1 year ago

I have a somewhat complicated torch.nn.Module; let's say, for argument's sake, its structure is a bit like this:

import torch

class CustomModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.layer1 = OtherCustomModule()
        self.layer2 = AnotherCustomModule()
        self.layer3 = OtherCustomModule()

    def forward(self, inputs):
        out = self.layer1(inputs)
        out = self.layer2(out)
        out = self.layer3(out)
        return out

Whilst OtherCustomModule and AnotherCustomModule are themselves composed of some custom functionality, there are some standard layers within them, like nn.Linear, but there's other stuff going on too.

I've read that as long as the direct children are standard torch modules like nn.Linear, BackPACK can detect and handle them; however, that isn't the case here.

Looking at the example custom module docs with ScaleModuleBatchGrad, I'm not sure how I can implement my own class here, since self.layer1 etc. are nn.Modules, not nn.Parameters.

f-dangel commented 1 year ago

Hi @m-lyon,

thanks for your question. May I ask which BackPACK extensions you are planning to use?

m-lyon commented 1 year ago

Of course, my apologies for not including that info.

I'm attempting to use the Laplace framework for a model; my current implementation uses the DiagGGNExact extension.

f-dangel commented 1 year ago

The DiagGGNExact is what we call a 'second-order' extension. This means that your model must entirely consist of layers known to BackPACK and it must be provided as a collection of modules without overwriting forward. A first step to get there would be to re-write your code as

custom_module = torch.nn.Sequential(
    torch.nn.Sequential(
        # layers of 'OtherCustomModule'
    ),
    torch.nn.Sequential(
        # layers of 'AnotherCustomModule'
    ),
    torch.nn.Sequential(
        # layers of 'OtherCustomModule'
    ),
)

Now you can fill in the layers that are already supported by BackPACK, as well as the ones you wish to implement yourself. You can take a look at this issue where I've outlined how to add support for a second-order extension.
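
For orientation, once every layer is either natively supported or covered by your own extension, the end-to-end usage follows the basic example from the docs. Here is a minimal sketch with made-up layer sizes:

from torch import rand
from torch.nn import Linear, MSELoss, ReLU, Sequential

from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

# hypothetical stand-in for the real architecture
model = extend(Sequential(Linear(10, 8), ReLU(), Linear(8, 1)))
loss_func = extend(MSELoss())

X, y = rand(20, 10), rand(20, 1)
loss = loss_func(model(X), y)

with backpack(DiagGGNExact()):
    loss.backward()

for p in model.parameters():
    print(p.diag_ggn_exact.shape)  # GGN diagonal, same shape as p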

As an alternative, you could also consider looking into extending the alternative backend of the Laplace library (ASDL).

Let me know if that helps.

m-lyon commented 1 year ago

Thank you for pointing me in the right direction. I'll take a look at the second-order extension example and see if I can't implement something myself. Unfortunately, the Laplace framework currently does not support ASDL for regression, only classification, so I'd assume that would be an equal amount of work.

Thanks

m-lyon commented 1 year ago

This means that your model must entirely consist of layers known to BackPACK and it must be provided as a collection of modules without overwriting forward

It doesn't seem completely obvious how I would do this, because my network (and consequently the custom layers within it) has several inputs. Therefore, without modifying forward, I'm not sure how I would do that.

Conceptually my model looks something like this


class Network(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.layer1 = CustomLayer1()
    self.layer2 = CustomLayer2()
    self.layer3 = CustomLayer3()

  def forward(self, tensor1, tensor2, tensor3):
    out = self.layer1(tensor1, tensor2, tensor2)
    out = self.layer2(out, tensor2, tensor3)
    out = self.layer3(out, tensor3, tensor3)
    return out

f-dangel commented 1 year ago

Hey, thanks for the update!

You are right that this is indeed challenging for your model. The problem with the above forward pass is that tensor2, tensor3 are used by multiple layers. How complex are layer1, layer2, layer3? Would it be easy to fuse them into a single layer, or are they also deep layers?

m-lyon commented 1 year ago

They are somewhat complex. I think being able to break the problem down into smaller parts would be a much more feasible solution rather than fusing them into one, especially considering that the number of layers in reality is much greater than in this example.

I think something that is quite important to solving this problem, which I'm currently unsure how to do, is writing extensions for nn.Modules that have submodules instead of nn.Parameters. For example, my CustomLayers for the most part do some operations on the incoming data, then pass the resulting tensor on to an nn.Linear layer. Obviously, if there were only one input, I could use the nn.Sequential container to achieve this; however, as we discussed, that isn't the case here.

I've taken a look at the SumModule extension with DiagGGNSumModule, and its derivatives class SumModuleDerivatives, to give me some clues as to how I would implement this (since it takes multiple inputs and does not have any nn.Parameters itself). Looking at this, it seems the only involved thing is implementing the _jac_t_mat_prod method? Am I on the right track here or not really?
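
For concreteness, my (possibly wrong) reading of that code is: because the Jacobian of a sum with respect to each of its inputs is the identity, the transpose-Jacobian product just passes the backpropagated matrix through unchanged to every input. As a sketch (not the actual SumModuleDerivatives code):

from typing import Tuple

from torch import Tensor

def sum_jac_t_mat_prod(mat: Tensor, num_inputs: int) -> Tuple[Tensor, ...]:
    """Transpose-Jacobian product of z = x_1 + ... + x_k w.r.t. each input.

    Since dz/dx_i is the identity, the backpropagated matrix is passed
    through unchanged to every input.
    """
    return tuple(mat for _ in range(num_inputs))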

f-dangel commented 1 year ago

I think something that is quite important to solving this problem, which I'm currently unsure how to do, is writing extensions for nn.Modules that have submodules instead of nn.Parameters.

BackPACK's design somewhat conflicts with this feature, the problem being that a module which does not act like a container (such as Sequential) can build up complicated graphs internally through which it is hard to backpropagate extra information via Module hooks.

In your case, the container Network is still 'well-behaved' in that it does not use any nn.functionals inside the forward pass which would build up boundaries for backpropagation via module hooks.

The short answer is that in such a case it is still possible to get the backpropagation working (extensive answer below).

f-dangel commented 1 year ago

I put together a self-contained example below.

It demonstrates how to add support for a layer MultiplyModule which takes multiple tensors as input. It multiplies all its inputs, then multiplies by its weight and returns the result (this leads to a simple Jacobian implementation). Then, I use a MySimpleContainer module which contains multiple MultiplyModules and coordinates their forward pass without calling any nn.functionals (similar to your example code above). I double-check the GGN diagonal with autograd and it matches BackPACK's result.

I had to perform a slight adjustment to BackPACK's backpropagation mechanism. You will have to install from the multiple-inputs branch.
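
(For reference, that branch can be installed directly from GitHub, e.g. via pip install git+https://github.com/f-dangel/backpack.git@multiple-inputs.)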

I think this should get you started. There might be additional pitfalls on the way, though. Let me know if you run into issues.

Here is the code:

"""Second-order extensions for modules with multiple inputs and slightly advanced control flow containers."""
from typing import List, Tuple

from torch import Tensor, allclose, einsum, manual_seed, rand, zeros
from torch.nn import Module, MSELoss, Parameter
from torch.nn.utils.convert_parameters import parameters_to_vector

from backpack import backpack, extend
from backpack.extensions import DiagGGNExact
from backpack.extensions.module_extension import ModuleExtension
from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.utils.convert_parameters import vector_to_parameter_list

class MultiplyModule(Module):
    """Module that multiplies all its inputs with itself and a weight."""

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = Parameter(Tensor([weight]))

    def forward(self, *inputs: Tensor):
        """Multiply all inputs, then multiply by the weight and return result."""
        # accept batched scalars only for simplicity
        assert len({i.shape[0] for i in inputs}) == 1
        assert all(i.dim() == 2 and i.shape[1] == 1 for i in inputs)

        result = self.weight
        for i in inputs:
            result = result * i

        return result

class DiagGGNMultiplyModule(ModuleExtension):
    """Describes how to compute the GGN diagonal for a ``MultiplyModule``."""

    def __init__(self):
        super().__init__(params=["weight"])

    def backpropagate(
        self,
        ext: DiagGGNExact,
        module: MultiplyModule,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        bpQuantities: Tensor,
    ) -> Tuple[Tensor]:
        """Backprop GGN matrix square root from output to inputs of ``MultiplyModule``.

        This multiplies the backpropagated object with output-input Jacobian for each
        input.

        Returns a tuple with the backpropagated GGN matrix square root for each input.
        """
        sqrt_ggn = bpQuantities
        inputs = self.get_inputs(module)  # stored by BackPACK in the forward pass
        backpropagate_to_inputs = []

        # apply the output-input Jacobian for all inputs
        for i in range(len(inputs)):
            other = [inp for j, inp in enumerate(inputs) if j != i]
            jac_inp_i = module.weight
            for inp_other in other:
                jac_inp_i = jac_inp_i * inp_other
            backpropagate_to_inputs.append(sqrt_ggn * jac_inp_i)

        # returning a tuple signals that each entry is backpropagated to the corresponding input
        return tuple(backpropagate_to_inputs)

    def weight(
        self,
        ext: DiagGGNExact,
        module: MultiplyModule,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        bpQuantities: Tensor,
    ) -> Tensor:
        """Compute the GGN diagonal for the weight of ``MultiplyModule``."""
        sqrt_ggn = bpQuantities

        inputs = self.get_inputs(module)
        jac = inputs.pop()
        while inputs:
            jac = jac * inputs.pop()

        jac_sqrt_ggn = einsum("vni,ni->vni", sqrt_ggn, jac)

        return einsum("vni,vni->i", jac_sqrt_ggn, jac_sqrt_ggn)

    @staticmethod
    def get_inputs(module: Module) -> List[Tensor]:
        """Get all inputs of ``MultiplyModule``'s forward pass."""
        layer_inputs = []

        i = 0
        while hasattr(module, f"input{i}"):
            layer_inputs.append(getattr(module, f"input{i}"))
            i += 1

        return layer_inputs

class MySimpleContainer(Module):
    """Container module that feeds outputs of submodules into other children.

    This is okay as long as there are no calls to ``nn.functional``'s in the
    ``forward`` method.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = MultiplyModule(0.5)
        self.layer2 = MultiplyModule(-1.5)
        self.layer3 = MultiplyModule(3.0)

    def forward(self, x1, x2, x3):
        out = self.layer1(x1, x2, x2)
        out = self.layer2(out, x2, x3)
        return self.layer3(out, x3, x3)

###############################################################################
#                              Set up toy problem                             #
###############################################################################
batch_size = 10

manual_seed(0)
x1 = rand(batch_size, 1)
x2 = rand(batch_size, 1)
x3 = rand(batch_size, 1)
label = rand(batch_size, 1)

model = MySimpleContainer()
loss_func = MSELoss()

###############################################################################
#                    Compute the GGN diagonal with autograd                   #
###############################################################################
output = model(x1, x2, x3)
loss = loss_func(output, label)

parameters = list(model.parameters())
num_params = sum(p.numel() for p in parameters)
diag_ggn_autograd = zeros(num_params)

# compute GGN column by column, extracting the diagonal element
for i in range(num_params):
    e_i = zeros(num_params)
    e_i[i] = 1.0
    e_i_list = vector_to_parameter_list(e_i, parameters)

    ggn_diag_i_list = ggn_vector_product(loss, output, model, e_i_list)
    diag_ggn_autograd[i] = parameters_to_vector(ggn_diag_i_list)[i]

###############################################################################
#                    Compute the GGN diagonal with BackPACK                   #
###############################################################################

# only extend sub-modules; BackPACK does not know ``MySimpleContainer``
for submodule in model.children():
    extend(submodule)
loss_func = extend(loss_func)

loss = loss_func(model(x1, x2, x3), label)

ext = DiagGGNExact()
# tell the extension for the GGN diagonal to use ``DiagGGNMultiplyModule`` when it
# encounters ``MultiplyModule``.
ext.set_module_extension(MultiplyModule, DiagGGNMultiplyModule())

with backpack(ext):
    loss.backward()

diag_ggn_backpack = parameters_to_vector([p.diag_ggn_exact for p in parameters])

###############################################################################
#                                   Compare                                   #
###############################################################################
assert allclose(diag_ggn_autograd, diag_ggn_backpack)

m-lyon commented 1 year ago

Thanks for this example; I'll have a play around with it and see if I can't intuit how to extend it to my problem.

In your case, the container Network is still 'well-behaved' in that it does not use any nn.functionals inside the forward pass which would build up boundaries for backpropagation via module hooks.

So while this is true for the simplistic example I gave, the real module is pretty involved. To give a clearer idea of the complexity of the module (called a PCConv) I'm working on, the implementation can be found here. There are many data manipulations going on, such as torch.cat, torch.unfold, torch.expand, torch.nn.functional.pad, torch.permute, torch.cartesian_prod, and things like this. On top of this, the network would be a deep network composed of many of these PCConv layers.

Hence I was hoping to be able to break the problem down, so that for each step I could figure out what I needed to implement for that given function call within a PCConv forward pass.

f-dangel commented 1 year ago

So while this is true for the simplistic example I gave, the real module is pretty involved. To give a clearer idea of the complexity of the module (called a PCConv) I'm working on

That's okay, but since you are implementing this as a Module I think it might be best to not break it further down (unless this can be done using standard containers such as Sequential with easier submodules). I also wanted to point you to BackPACK's custom_module module which contains a supported nn.Module for torch.permute. There is also a modular version of nn.functional.pad.

Am I correct that the above example at least solves your problem of handling the data flow of input tensors to the module?

m-lyon commented 1 year ago

Am I correct that the above example at least solves your problem of handling the data flow of input tensors to the module?

Yes, as far as I can tell from looking at the example, though I haven't tested this yet.

I also wanted to point you to BackPACK's custom_module module which contains a supported nn.Module for torch.permute. There is also a modular version of nn.functional.pad.

Thanks for pointing this out. So, as far as I understand (please correct me if I've misunderstood), I need to implement the _jac_t_mat_prod function for the PCConv. One point of confusion for me is that, looking at the Pad and Permute examples, the _jac_t_mat_prod of these modules does the reverse operation (i.e. removes the padding and reverses the permutation exactly). However, the _jac_t_mat_prod of a Linear layer doesn't reverse the operation: within LinearDerivatives._jac_t_mat_prod you end up with a Tensor of the same shape as the input, but not the same values. So I'm not sure what I should be doing to implement my own _jac_t_mat_prod for an arbitrary op.

Additionally, my PCConv layer has an nn.Linear submodule; presumably, within PCConv._jac_t_mat_prod, at the relevant point in the code I can just call LinearDerivatives._jac_t_mat_prod? Does this appropriately handle the weight and bias parameters within the nn.Linear layer?

f-dangel commented 1 year ago

One point of confusion for me is that, looking at the Pad and Permute examples, the _jac_t_mat_prod of these modules does the reverse operation (i.e. removes the padding and reverses the permutation exactly). However, the _jac_t_mat_prod of a Linear layer doesn't reverse the operation.

The operation performed in _jac_t_mat_prod is multiplication with the transpose Jacobian. For padding and permutation, this is unpadding and unpermuting. For the linear layer whose Jacobian is the weight matrix, the _jac_t_mat_prod is multiplication by the transpose weight.
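
To make that concrete, here is a small sketch of the underlying operation for a linear layer z = Wx + b (not BackPACK's actual LinearDerivatives code): the Jacobian with respect to the input is W, so the transpose-Jacobian product multiplies each backpropagated vector by Wᵀ.

from torch import einsum, rand

W = rand(3, 5)        # weight of a hypothetical Linear(5, 3)
mat = rand(7, 10, 3)  # [V, N, out]: V stacked vectors, one set per sample

# multiply each vector with the transposed Jacobian, i.e. with W^T
jac_t_mat = einsum("vno,oi->vni", mat, W)  # [V, N, in]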

I think the way to progress on this is to first formulate your layer in terms of modules.

For instance, it would be good to decide whether you want the pad, cat, unfold, permute operations to be nn.functional's within the forward pass of PCConv, or whether you can write PCConv as a Sequential of modules that each do one of these operations. Once you've fixed what will be an nn.Module and what won't, you can implement the _jac_t_mat_prod operations for each layer. (By the way, you don't have to implement separate derivative methods; you can simply implement the backpropagate method, which is what I did in the example above. In BackPACK we just have this abstraction of derivative operations so they can be reused across extensions.)

m-lyon commented 1 year ago

It seems like, given the number of operations in the forward pass, refactoring these into a Sequential module wouldn't gain me much time or simplicity over implementing a function that backpropagates through the steps of the forward pass.

For padding and permutation, this is unpadding and unpermuting. For the linear layer whose Jacobian is the weight matrix, the _jac_t_mat_prod is multiplication by the transpose weight.

I understand; in that case I'm unsure how to derive the Jacobian for the kind of matrix-manipulation operations like pad, permute, cat, etc. Do you have any resources you can point to to clear that up? Presumably if these were trivial you would have already included them in this package? (like you have with permute and pad)

f-dangel commented 1 year ago

Do you have any resources you can point to to clear that up?

You can try to take a look at the documentation of jac_t_mat_prod. All you have to do is derive the Jacobian and then multiply its transpose onto a vector. My thesis contains some examples of Jacobians and how to derive them (Table 2.2).
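
In formulas (my shorthand; see the docs for the exact shape conventions): for a module that maps each sample's input x_n to an output z_n = f(x_n), and a stack of V vectors M[v, n] shaped like the output,

$$
[\texttt{jac\_t\_mat\_prod}(M)][v, n] = (J_{x_n} f)^\top \, M[v, n],
\qquad
(J_{x_n} f)_{ij} = \frac{\partial [f(x_n)]_i}{\partial [x_n]_j}.
$$

For a permutation the Jacobian is a permutation matrix, for padding it is a selection matrix, and for a linear layer it is the weight matrix, which recovers the cases we discussed above.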

Presumably if these were trivial you would have already included them in this package?

BackPACK's focus is on standard DNN architectures. So if an operation is missing, chances are it can still be added at relatively little overhead. With the examples from the website we tried to simplify this procedure for others. But in the end you won't get around implementing multiplication by the Jacobian.

I am happy to give feedback and review your code if you decide to tackle this, but won't have enough time to write code. I would proceed as follows:

  1. Look at the documentation of jac_t_mat_prod
  2. Look at its implementation in autograd, which we use in the tests to compare with BackPACK's implementation
  3. Implement jac_t_mat_prod for your layer. Use the autograd implementation to verify it does the correct thing.

Best, Felix

m-lyon commented 1 year ago

Thank you for your suggestions. I've set up a dev environment with backpack and am using the existing test suite to test the jac_t_mat_prod methods I derive. I've started from a simple Module and am working my way up in complexity.

I think this is probably a naive question, but if you can verify the jac_t_mat_prod using autograd, then why can't one just use autograd to calculate the jac_t_mat_prod for arbitrary functions within BackPACK? Is it just a matter of efficiency? If so, are there cases where it is approximately as efficient?

f-dangel commented 1 year ago

Hi,

great to hear you're making progress! You are right that in principle one could use autograd to implement jac_t_mat_prod to support arbitrary functions in BackPACK. Some caveats are:

  1. You need to perform another forward pass because you do not have access to the layer's output (and the input-to-output graph) within a backward hook. It must be wrapped inside a torch.enable_grad() (otherwise no graph to differentiate through will be built up during a backward hook)
  2. jac_t_mat_prod performs multiple VJPs in parallel. With autograd you have to use a for loop, or functorch's vmap.

BackPACK uses PyTorch's Python API and might be slower than the vmap approach (which, however, needs another forward pass). But it should be faster than a for loop, since we exclusively rely on batched operations (no for loops) in the core.
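
For completeness, a rough sketch of caveat 2 using torch.func (the successor of functorch; this is not how BackPACK itself is implemented):

from torch import rand
from torch.func import vjp, vmap
from torch.nn.functional import relu

x = rand(10, 4)       # input, shape [N, ...]
mat = rand(3, 10, 4)  # [V, N, ...]: V vectors, each shaped like the output

# extra forward pass to build the graph, then batch the V VJPs with vmap
_, vjp_fn = vjp(relu, x)
(jac_t_mat,) = vmap(vjp_fn)(mat)  # [V, N, ...]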

I'd be happy to merge functionality that adds support for arbitrary layers and uses autograd internally if you come up with a general solution.

m-lyon commented 1 year ago

Hey Felix,

I've written the following functionality to compute the jac_t_mat_prod for an arbitrary model, for the first input:

# backpack/core/derivatives/model.py

from typing import List, Tuple, Optional

import torch

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.hessianfree.lop import transposed_jacobian_vector_product

class ArbitraryModelDerivatives(BaseDerivatives):
    def _jac_t_mat_prod(
        self,
        module: torch.nn.Module,
        g_inp: Tuple[torch.Tensor],
        g_out: Tuple[torch.Tensor],
        mat: torch.Tensor,
        subsampling: Optional[List[int]] = None,
    ) -> torch.Tensor:
        #  Just 1 input for now
        if not module.input0.requires_grad:
            raise RuntimeError('requires_grad needed for arbitrary jac_t_mat_prod')
        return torch.stack(transposed_jacobian_vector_product(module.output, module.input0, mat))

# test_simple.py

import torch

from backpack import extend
from backpack.core.derivatives.model import ArbitraryModelDerivatives

class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

def run_test():
    model = extend(MyModel())
    derivs = ArbitraryModelDerivatives()

    # Just one input in this case
    inputs = tuple([torch.arange(3, dtype=torch.float32).view(3, 1, 1).expand(3, 3, 3)])
    for inp in inputs:
        inp.requires_grad = True

    output = model(*inputs)

    # Use a matrix of ones to compare result with known answer
    mat = torch.ones(output.shape)
    res = derivs._jac_t_mat_prod(model, None, None, mat, None) # shape -> (1, 3, 3, 3)

    # Known answer
    ans = torch.cos(inputs[0]) # shape -> (3, 3, 3)

    print(torch.allclose(res, ans)) # prints 'True'

if __name__ == '__main__':
    run_test()

If I've made any incorrect assumptions or mistakes here, please let me know.

One thing I'm unsure about is the required dimensionality of mat or vec in this context. You mentioned previously that you'd have to do a for loop, presumably because _jac_t_mat_prod would actually just be a sequence of _jac_t_vec_prod calls. In this example I've not done that, and instead have just passed the Tensors straight into the torch.autograd.grad call within transposed_jacobian_vector_product. I feel like I must be missing something here, because this seems a rather simple solution?

f-dangel commented 1 year ago

Hi,

this looks like a good start! Let me clarify some of the points I meant in my previous post:

m-lyon commented 1 year ago

where V is some integer (not necessarily 1)

So what purpose does V serve? If the typical "batch" dimension (i.e. the size of the mini-batch during training) will be the first dimension of module.output, then what does V represent? Or am I misunderstanding?

m-lyon commented 1 year ago

I've made the following edits:

# backpack/core/derivatives/model.py

from typing import List, Tuple, Optional
from itertools import count

import torch

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.hessianfree.lop import transposed_jacobian_vector_product

class ArbitraryModelDerivatives(BaseDerivatives):
    '''Arbitrary Model Derivative'''

    def _get_num_inputs(self, module) -> List[str]:
        inputs = []
        for i in count():
            if hasattr(module, f'input{i}'):
                inputs.append(f'input{i}')
            else:
                break
        return inputs

    def _jac_t_mat_prod(
        self,
        module: torch.nn.Module,
        g_inp: Tuple[torch.Tensor],
        g_out: Tuple[torch.Tensor],
        mat: torch.Tensor,
        subsampling: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, ...]:
        # collect the transpose-Jacobian product for every stored input
        input_names = self._get_num_inputs(module)
        res = []
        for inp_name in input_names:
            inp = getattr(module, inp_name)
            if not inp.requires_grad:
                raise RuntimeError(f'requires_grad needed for module.{inp_name} jac_t_mat_prod')
            if subsampling is None:
                jtmp = torch.stack(transposed_jacobian_vector_product(module.output, inp, mat))
            else:
                raise NotImplementedError('Subsampling not currently implemented')
            res.append(jtmp)
        return tuple(res)

# backpack/extensions/backprop_extension.py

class BackpropExtension(ABC):
...

    def __get_module_extension(self, module: Module) -> Union[ModuleExtension, None]:
        module_extension = self.__module_extensions.get(module.__class__)
        if module_extension is None:
            module_extension = self._get_arbitrary_extension(module)
        ...

    def _get_arbitrary_extension(self, module):
        return None  # None in Abstract base class

# backpack/extensions/secondorder/diag_ggn/__init__.py

from . import model

class DiagGGN(SecondOrderBackpropExtension):
    ...

    def _get_arbitrary_extension(self, module):
        return model.DiagGGNArbitraryModel()
    ...

With all of this added, using a custom module that has various layers inside of it (including Linear layers) and running under a backpack.backpack(DiagGGNExact()) call, execution reaches backpack.extensions.module_extension.ModuleExtension.__call__ for a Linear layer (whose module extension is backpack.extensions.secondorder.diag_ggn.linear.DiagGGNLinear), and there the line

bp_quantity = self.__get_backproped_quantity(
    extension, module.output, delete_old_quantities
)

evaluates to None, and I'm not really sure why. Any suggestions?

f-dangel commented 1 year ago

Hi,

where V is some integer (not necessarily 1)

So what purpose does V serve? If the typical "batch" dimension (i.e. the size of the mini-batch during training) will be the first dimension of module.output, then what does V represent? Or am I misunderstanding?

For Hessian-related quantities, BackPACK backpropagates multiple vectors in parallel. These vectors are stacked, which yields the leading dimension V. All vectors along V are then identically processed.

f-dangel commented 1 year ago

I tweaked your example to make it work and commented on some of the details that relate to my previous posts:

"""Backpropagation through ReLU for GGN diagonal via ``torch.autograd``."""
from typing import List, Optional, Tuple

from torch import Tensor, allclose, enable_grad, rand, stack
from torch.autograd import grad
from torch.nn import Linear, Module, MSELoss, ReLU, Sequential
from torch.nn.functional import relu

from backpack import backpack, extend, extensions
from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule

class ArbitraryModelDerivatives(BaseDerivatives):
    """Arbitrary Model Derivative"""

    def __init__(self, forward_func) -> None:
        super().__init__()
        self.forward_func = forward_func

    def _jac_t_mat_prod(
        self,
        module: Module,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        subsampling: Optional[List[int]] = None,
    ) -> Tensor:
        if subsampling is not None:
            raise NotImplementedError("Subsampling not currently implemented")

        print("Using arbitrary model derivatives.")

        # regenerate computation graph for differentiation
        with enable_grad():
            # NOTE: Cannot use module(module.input0) since this triggers its
            # forward hook and messes up the backpropagation internals (the
            # internals rely on the memory address of module.input0, but the
            # old module.input0 will be overwritten during
            # module(module.input0)).
            re_input = module.input0.clone().detach().requires_grad_(True)
            re_output = self.forward_func(re_input)

            # V vectors of shape [*module.input0.shape]
            vjps = [grad(re_output, re_input, v, retain_graph=True)[0] for v in mat]

        return stack(vjps)  # shape [V, *module.input0.shape]

class DiagGGNReLUArbitrary(DiagGGNBaseModule):
    """Implements DiagGGN backpropagation for ReLU layer using ``torch.autograd``."""

    def __init__(self):
        super().__init__(derivatives=ArbitraryModelDerivatives(self.forward_func))

    @staticmethod
    def forward_func(input0):
        return relu(input0)

X, y = rand(10, 5), rand(10, 3)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 3))
loss_func = MSELoss()

model = extend(model)
loss_func = extend(loss_func)

# ground truth
with backpack(extensions.DiagGGNExact()):
    loss = loss_func(model(X), y)
    loss.backward()

diag_ggn = [p.diag_ggn_exact for p in model.parameters()]

# now using arbitrary derivatives under the hood
ext = extensions.DiagGGNExact()
ext.set_module_extension(
    ReLU,
    DiagGGNReLUArbitrary(),
    overwrite=True,  # force overwrite as ReLU already exists within BackPACK
)
with backpack(ext):
    loss = loss_func(model(X), y)
    loss.backward()

diag_ggn_arbitrary = [p.diag_ggn_exact for p in model.parameters()]

for diag, diag_arbitrary in zip(diag_ggn, diag_ggn_arbitrary):
    print(allclose(diag, diag_arbitrary))