m-lyon opened this issue 1 year ago
Hi @m-lyon,
thanks for your question. May I ask which BackPACK extensions you are planning to use?
Of course, my apologies for not including that info.
I'm attempting to use the Laplace framework for a model; my current implementation uses the `DiagGGNExact` extension.
The `DiagGGNExact` is what we call a 'second-order' extension. This means that your model must entirely consist of layers known to BackPACK, and it must be provided as a collection of modules without overwriting `forward`. A first step to get there would be to rewrite your code as:
```python
custom_module = torch.nn.Sequential(
    torch.nn.Sequential(
        # layers of 'OtherCustomModule'
    ),
    torch.nn.Sequential(
        # layers of 'AnotherCustomModule'
    ),
    torch.nn.Sequential(
        # layers of 'OtherCustomModule'
    ),
)
```
Now you can fill in the layers that are already supported by BackPACK, as well as the ones you wish to implement yourself. You can take a look at this issue where I've outlined how to add support for a second-order extension.
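For orientation, once the model is a plain `Sequential` of supported layers, using the extension looks roughly like this (a minimal sketch with made-up layer sizes, mirroring the usage shown later in this thread):

```python
from torch import rand
from torch.nn import Linear, MSELoss, ReLU, Sequential

from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

# toy Sequential model built only from layers BackPACK supports
model = extend(Sequential(Linear(5, 4), ReLU(), Linear(4, 1)))
loss_func = extend(MSELoss())

loss = loss_func(model(rand(8, 5)), rand(8, 1))
with backpack(DiagGGNExact()):
    loss.backward()

# the GGN diagonal is stored on each parameter
diag_ggn = [p.diag_ggn_exact for p in model.parameters()]
```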
As an alternative, you could also consider looking into extending the Laplace library's other backend (ASDL).
Let me know if that helps.
Thank you for pointing me in the right direction. I'll take a look at the second-order extension example and see if I can't implement something myself. Unfortunately, the Laplace framework currently does not support ASDL for regression, only classification, so I'd assume that would be an equal amount of work.
Thanks
> This means that your model must entirely consist of layers known to BackPACK and it must be provided as a collection of modules without overwriting `forward`

It doesn't seem completely obvious how I would do this, because my network (and consequently the custom layers within my network) has several inputs. Therefore, without modifying `forward`, I'm not sure how I would do that.

Conceptually, my model looks something like this:
```python
class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = CustomLayer1()
        self.layer2 = CustomLayer2()
        self.layer3 = CustomLayer3()

    def forward(self, tensor1, tensor2, tensor3):
        out = self.layer1(tensor1, tensor2, tensor2)
        out = self.layer2(out, tensor2, tensor3)
        out = self.layer3(out, tensor3, tensor3)
        return out
```
Hey, thanks for the update!
You are right that this is indeed challenging for your model. The problem with the above forward pass is that `tensor2` and `tensor3` are used by multiple layers. How complex are `layer1`, `layer2`, and `layer3`? Would it be easy to fuse them into a single layer, or are they also deep layers?
They are somewhat complex. I think breaking the problem down into smaller parts would be much more feasible than fusing them into one, especially since the number of layers in reality is much greater than in this example.
I think something that is quite important to solving this problem, which I'm currently unsure how to do, is writing extensions for `nn.Module`s that have submodules instead of `nn.Parameter`s. For example, my `CustomLayer`s for the most part do some operations on the incoming data, then pass the resulting tensor on to an `nn.Linear` layer. Obviously, if there were only one input, I could use the `nn.Sequential` container to achieve this; however, as we discussed, that isn't the case.
I've taken a look at the `SumModule` extension with `DiagGGNSumModule`, and its derivatives class `SumModuleDerivatives`, to give me some clues on how I would implement this (since it takes multiple inputs and does not have any `nn.Parameter`s itself). Looking at this, it seems the only involved thing is implementing the `_jac_t_mat_prod` method? Am I on the right track here, or not really?
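To check my understanding, I sketched what this looks like for a plain sum outside of BackPACK (the helper name `sum_jac_t_mat_prod` is just mine): the output-input Jacobian w.r.t. every input is the identity, so the transposed-Jacobian product passes the matrix through unchanged to each input.

```python
from typing import Tuple

from torch import Tensor, allclose, autograd, ones_like, rand


def sum_jac_t_mat_prod(inputs: Tuple[Tensor, ...], mat: Tensor) -> Tuple[Tensor, ...]:
    """Transposed-Jacobian-matrix product of y = x_1 + ... + x_k.

    Each Jacobian dy/dx_i is the identity, so ``mat`` (shape ``[V, *y.shape]``)
    is passed through unchanged to every input.
    """
    return tuple(mat for _ in inputs)


# sanity check against autograd for a single vector (V = 1)
x1, x2 = rand(3, 2, requires_grad=True), rand(3, 2, requires_grad=True)
y = x1 + x2
v = ones_like(y)

vjp_x1, vjp_x2 = autograd.grad(y, (x1, x2), grad_outputs=v)
man_x1, man_x2 = sum_jac_t_mat_prod((x1, x2), v.unsqueeze(0))
assert allclose(vjp_x1, man_x1[0]) and allclose(vjp_x2, man_x2[0])
```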
> I think something that is quite important to solving this problem, which I'm currently unsure how to do, is writing extensions for `nn.Module`s that have submodules instead of `nn.Parameter`s.

BackPACK's design somewhat conflicts with this feature: the problem is that a module which does not act like a container (such as `Sequential`) can build up complicated graphs internally, through which it is hard to backpropagate extra information via `Module` hooks.
In your case, the container `Network` is still 'well-behaved' in that it does not use any `nn.functional`s inside the forward pass, which would build up boundaries for backpropagation via module hooks.
The short answer is that in such a case it is still possible to get the backpropagation working (extensive answer below).
I put together a self-consistent example below. It demonstrates how to add support for a layer `MultiplyModule` which takes multiple tensors as input. It multiplies all its inputs, then multiplies by its weight and returns the result (this leads to a simple Jacobian implementation). Then, I used a `MySimpleContainer` which contains multiple `MultiplyModule`s and coordinates their forward pass without calling any other `nn.functional`s (similar to your above example code). I double-check the GGN diagonal with `autograd` and it matches BackPACK's result.
I had to perform a slight adjustment to BackPACK's backpropagation mechanism. You will have to install from the `multiple-inputs` branch.
I think this should get you started. There might be additional pitfalls on the way, though. Let me know if you run into issues.
Here is the code:
"""Second-order extensions for modules with multiple inputs and slightly advanced control flow containers."""
from typing import List, Tuple
from torch import Tensor, allclose, einsum, manual_seed, rand, zeros
from torch.nn import Module, MSELoss, Parameter
from torch.nn.utils.convert_parameters import parameters_to_vector
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact
from backpack.extensions.module_extension import ModuleExtension
from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.utils.convert_parameters import vector_to_parameter_list
class MultiplyModule(Module):
"""Module that multiplies all its inputs with itself and a weight."""
def __init__(self, weight: float = 1.0):
super().__init__()
self.weight = Parameter(Tensor([weight]))
def forward(self, *inputs: Tensor):
"""Multiply all inputs, then multiply by the weight and return result."""
# accept batched scalars only for simplicity
assert len({i.shape[0] for i in inputs}) == 1
assert all(i.dim() == 2 and i.shape[1] == 1 for i in inputs)
result = self.weight
for i in inputs:
result = result * i
return result
class DiagGGNMultiplyModule(ModuleExtension):
"""Describes how to compute the GGN diagonal for a ``MultiplyModule``."""
def __init__(self):
super().__init__(params=["weight"])
def backpropagate(
self,
ext: DiagGGNExact,
module: MultiplyModule,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
bpQuantities: Tensor,
) -> Tuple[Tensor]:
"""Backprop GGN matrix square root from output to inputs of ``MultiplyModule``.
This multiplies the backpropagated object with output-input Jacobian for each
input.
Returns a tuple with the backpropagated GGN matrix square root for each input.
"""
sqrt_ggn = bpQuantities
inputs = self.get_inputs(module) # stored by BackPACK in the forward pass
backpropagate_to_inputs = []
# apply the output-input Jacobian for all inputs
for i in range(len(inputs)):
other = [inp for j, inp in enumerate(inputs) if j != i]
jac_inp_i = module.weight
for inp_other in other:
jac_inp_i = jac_inp_i * inp_other
backpropagate_to_inputs.append(sqrt_ggn * jac_inp_i)
# tuple signalizes each entry will be backpropped to its input
return tuple(backpropagate_to_inputs)
def weight(
self,
ext: DiagGGNExact,
module: MultiplyModule,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
bpQuantities: Tensor,
) -> Tensor:
"""Compute the GGN diagonal for the weight of ``MultiplyModule``."""
sqrt_ggn = bpQuantities
inputs = self.get_inputs(module)
jac = inputs.pop()
while inputs:
jac = jac * inputs.pop()
jac_sqrt_ggn = einsum("vni,ni->vni", sqrt_ggn, jac)
return einsum("vni,vni->i", jac_sqrt_ggn, jac_sqrt_ggn)
@staticmethod
def get_inputs(module: Module) -> List[Tensor]:
"""Get all inputs of ``MultiplyModule``'s forward pass."""
layer_inputs = []
i = 0
while hasattr(module, f"input{i}"):
layer_inputs.append(getattr(module, f"input{i}"))
i += 1
return layer_inputs
class MySimpleContainer(Module):
"""Container module that feeds outputs of submodules into other children.
This is okay as long as there are no calls to ``nn.functional``'s in the
``forward`` method.
"""
def __init__(self) -> None:
super().__init__()
self.layer1 = MultiplyModule(0.5)
self.layer2 = MultiplyModule(-1.5)
self.layer3 = MultiplyModule(3.0)
def forward(self, x1, x2, x3):
out = self.layer1(x1, x2, x2)
out = self.layer2(out, x2, x3)
return self.layer3(out, x3, x3)
###############################################################################
# Set up toy problem #
###############################################################################
batch_size = 10
manual_seed(0)
x1 = rand(batch_size, 1)
x2 = rand(batch_size, 1)
x3 = rand(batch_size, 1)
label = rand(batch_size, 1)
model = MySimpleContainer()
loss_func = MSELoss()
###############################################################################
# Compute the GGN diagonal with autograd #
###############################################################################
output = model(x1, x2, x3)
loss = loss_func(output, label)
parameters = list(model.parameters())
num_params = sum(p.numel() for p in parameters)
diag_ggn_autograd = zeros(num_params)
# compute GGN column by column, extracting the diagonal element
for i in range(num_params):
e_i = zeros(num_params)
e_i[i] = 1.0
e_i_list = vector_to_parameter_list(e_i, parameters)
ggn_diag_i_list = ggn_vector_product(loss, output, model, e_i_list)
diag_ggn_autograd[i] = parameters_to_vector(ggn_diag_i_list)[i]
###############################################################################
# Compute the GGN diagonal with BackPACK #
###############################################################################
# only extend sub-modules; BackPACK does not know ``MySimpleContainer``
for submodule in model.children():
extend(submodule)
loss_func = extend(loss_func)
loss = loss_func(model(x1, x2, x3), label)
ext = DiagGGNExact()
# tell the extension for the GGN diagonal to use ``DiagGGNMultiplyModule`` when it
# encounters ``MultiplyModule``.
ext.set_module_extension(MultiplyModule, DiagGGNMultiplyModule())
with backpack(ext):
loss.backward()
diag_ggn_backpack = parameters_to_vector([p.diag_ggn_exact for p in parameters])
###############################################################################
# Compare #
###############################################################################
assert allclose(diag_ggn_autograd, diag_ggn_backpack)
Thanks for this example, I'll have a play around with it and see if I can't intuit how to extend this to my problem.
> In your case, the container `Network` is still 'well-behaved' in that it does not use any `nn.functional`s inside the forward pass which would build up boundaries for backpropagation via module hooks.
So while this is true for the simplistic example I gave, the real module is pretty involved. To give a clearer idea of the complexity of the module (called a `PCConv`) I'm working on, the implementation can be found here. There are many manipulations of data going on, such as `torch.cat`, `torch.unfold`, `torch.expand`, `torch.nn.functional.pad`, `torch.permute`, `torch.cartesian_prod`, things like this. On top of this, the network would be a deep network comprised of many of these `PCConv` layers.

Hence I was hoping to be able to break down the problem, so that for each step I could figure out what I needed to implement for that given function call within a `PCConv` forward pass.
> So while this is true for the simplistic example I gave, the real module is pretty involved. To give a clearer idea of the complexity of the module (called a `PCConv`) I'm working on

That's okay, but since you are implementing this as a `Module`, I think it might be best not to break it down further (unless this can be done using standard containers such as `Sequential` with easier submodules). I also wanted to point you to BackPACK's `custom_module` module, which contains a supported `nn.Module` for `torch.permute`. There is also a modular version of `nn.functional.pad`.
Am I correct that the above example at least solves your problem of handling the data flow of input tensors to the module?
> Am I correct that the above example at least solves your problem of handling the data flow of input tensors to the module?
Yes, as far as I can tell from looking at the example, though I haven't tested this yet.
> I also wanted to point you to BackPACK's `custom_module` module which contains a supported `nn.Module` for `torch.permute`. There is also a modular version of `nn.functional.pad`.
Thanks for pointing this out. So, as far as I understand (please correct me if I've misunderstood), I need to implement the `_jac_t_mat_prod` function for the `PCConv`. One point of confusion for me is that when looking at the `Pad` and `Permute` examples, the `_jac_t_mat_prod` of these modules does the reverse operation (i.e. removes the padding and reverses the permutation exactly). However, the `_jac_t_mat_prod` of a `Linear` layer doesn't reverse the operation: within `LinearDerivatives._jac_t_mat_prod` you end up with a tensor of the same shape as the input, but not the same values. So I'm not sure what I'm doing to implement my own `_jac_t_mat_prod` for an arbitrary op.

Additionally, my `PCConv` layer has an `nn.Linear` submodule. Presumably, within `PCConv._jac_t_mat_prod`, at the relevant point in the code I can just call `LinearDerivatives._jac_t_mat_prod`? Does this appropriately handle the `weight` and `bias` parameters within the `nn.Linear` layer?
> One point of confusion for me is that when looking at the Pad and Permute examples, the _jac_t_mat_prod of these modules do the reverse operation, (i.e. remove pad & reverse the permutation exactly). However the _jac_t_mat_prod of a Linear layer doesn't reverse the operation.

The operation performed in `_jac_t_mat_prod` is multiplication with the transpose Jacobian. For padding and permutation, this is unpadding and unpermuting. For the linear layer, whose Jacobian is the weight matrix, the `_jac_t_mat_prod` is multiplication by the transpose weight.
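To make the linear case concrete, here is a small sanity check outside of BackPACK (a sketch only; the layer sizes, `V`, and the einsum are illustrative), using the convention `y = x @ W.T + b`:

```python
from torch import allclose, autograd, einsum, rand
from torch.nn import Linear

layer = Linear(5, 3)
x = rand(10, 5, requires_grad=True)
y = layer(x)

V = 4
mat = rand(V, *y.shape)  # V stacked vectors, each with the shape of the output

# transposed-Jacobian product: each vector is multiplied by the (transposed) weight
manual = einsum("vno,oi->vni", mat, layer.weight)

# autograd reference: one vector-Jacobian product per slice along V
reference = [autograd.grad(y, x, v, retain_graph=True)[0] for v in mat]
assert all(allclose(m, r) for m, r in zip(manual, reference))
```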
I think the way to progress on this is to first formulate your layer in terms of modules.
For instance, it would be good to decide whether you want the `pad, cat, unfold, permute` operations to be `nn.functional`s within the forward pass of `PCConv`, or whether you can write `PCConv` as a `Sequential` of modules that do each of these operations. Once you've fixed what will be an `nn.Module` and what won't, you can implement the `_jac_t_mat_prod` operations for each layer. (Btw, you don't have to implement separate methods, but can simply implement the `backpropagate` method, which is what I did in the above example. In BackPACK we just have this abstraction of derivative operations for recycling purposes among extensions.)
It seems like, given the number of operations in the forward pass, refactoring these into a `Sequential` module wouldn't gain me much time/simplicity over implementing a function that backpropagates the steps in the forward pass.
> For padding and permutation, this is unpadding and unpermuting. For the linear layer whose Jacobian is the weight matrix, the _jac_t_mat_prod is multiplication by the transpose weight.
I understand. In that case, I'm unsure how to derive the Jacobian for the kind of matrix manipulation operations like `pad`, `permute`, `cat`, etc. Do you have any resources you can point to to clear that up? Presumably, if these were trivial you would have already included them in this package (like you have with `permute` and `pad`)?
> Do you have any resources you can point to to clear that up?
You can try to take a look at the documentation of `jac_t_mat_prod`. All you have to do is derive the Jacobian and then multiply its transpose onto a vector. My thesis contains some examples of Jacobians and how to derive them (Table 2.2).
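As one worked example for such a reshaping-type operation (my own sketch, not code from BackPACK): `torch.cat` just stacks its inputs, so its Jacobian w.r.t. each input is a selection matrix, and multiplying by the transpose simply slices the backpropagated quantity apart again:

```python
from torch import allclose, autograd, cat, rand

x1 = rand(10, 2, requires_grad=True)
x2 = rand(10, 3, requires_grad=True)
y = cat([x1, x2], dim=1)  # shape [10, 5]

v = rand(10, 5)  # one vector in output space

# transposed-Jacobian products: undo the concatenation by slicing
manual_x1, manual_x2 = v[:, :2], v[:, 2:]

# autograd reference
ref_x1, ref_x2 = autograd.grad(y, (x1, x2), grad_outputs=v)
assert allclose(manual_x1, ref_x1) and allclose(manual_x2, ref_x2)
```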
> Presumably if these were trivial you would have already included them in this package?
BackPACK's focus is on standard DNN architectures. So if an operation is missing, chances are it can still be added at relatively little overhead. With the examples from the website we tried to simplify this procedure for others. But in the end you won't get around implementing multiplication by the Jacobian.
I am happy to give feedback and review your code if you decide to tackle this, but won't have enough time to write code. I would proceed as follows:

1. Implement `jac_t_mat_prod` with `autograd`, which we use in the tests to compare with BackPACK's implementation.
2. Implement `jac_t_mat_prod` for your layer. Use the `autograd` implementation to verify it does the correct thing.

Best, Felix
Thank you for your suggestions. I've set up a dev environment with `backpack` and am using the existing test suite to test the `jac_t_mat_prod` methods I derive. I've started from a simple `Module` and am working my way up in complexity.
I think this is probably a naive question, but if you can verify the `jac_t_mat_prod` using `autograd`, then why can't one just use `autograd` to calculate the `jac_t_mat_prod` for arbitrary functions within `backpack`? Is it just a matter of efficiency? If so, are there cases where it is approximately as efficient?
Hi,
great to hear you're making progress! You are right that in principle one could use `autograd` to implement `jac_t_mat_prod` to support arbitrary functions in BackPACK. Some caveats are:

- You need `torch.enable_grad()` (otherwise no graph to differentiate through will be built up during a backward hook).
- `jac_t_mat_prod` performs multiple VJPs in parallel. With `autograd` you have to use a `for` loop, or `functorch`'s `vmap`.

BackPACK uses PyTorch's Python API and might be slower than the `vmap` (which however needs another forward pass). But it should be faster than a for-loop, since we exclusively rely on batched operations (no `for` loops) in the core.
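For reference, batching the VJPs without a Python loop could look like this with `torch.func` (the successor of `functorch`, so this sketch assumes PyTorch >= 2.0; it is not what BackPACK does internally):

```python
from torch import allclose, cos, rand, sin
from torch.func import vjp, vmap

x = rand(10, 3)
_, vjp_fn = vjp(sin, x)  # VJP function of y = sin(x)

V = 5
mat = rand(V, 10, 3)  # V stacked vectors, each with the shape of the output

# all V vector-Jacobian products in one batched call
(batched,) = vmap(vjp_fn)(mat)

# elementwise Jacobian of sin is diag(cos(x))
assert allclose(batched, mat * cos(x))
```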
I'd be happy to merge functionality that adds support for arbitrary layers and uses `autograd` internally if you come up with a general solution.
Hey Felix,
I've written the following functionality to compute the `jac_t_mat_prod` for an arbitrary model, for the first input:
```python
# backpack/core/derivatives/model.py
from typing import List, Tuple, Optional

import torch

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.hessianfree.lop import transposed_jacobian_vector_product


class ArbitraryModelDerivatives(BaseDerivatives):
    def _jac_t_mat_prod(
        self,
        module: torch.nn.Module,
        g_inp: Tuple[torch.Tensor],
        g_out: Tuple[torch.Tensor],
        mat: torch.Tensor,
        subsampling: Optional[List[int]] = None,
    ) -> torch.Tensor:
        # Just 1 input for now
        if not module.input0.requires_grad:
            raise RuntimeError('requires_grad needed for arbitrary jac_t_mat_prod')
        return torch.stack(transposed_jacobian_vector_product(module.output, module.input0, mat))
```

```python
# test_simple.py
import torch

from backpack import extend
from backpack.core.derivatives.model import ArbitraryModelDerivatives


class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)


def run_test():
    model = extend(MyModel())
    derivs = ArbitraryModelDerivatives()

    # Just one input in this case
    inputs = tuple([torch.range(0, 2).view(3, 1, 1).expand(3, 3, 3)])
    for inp in inputs:
        inp.requires_grad = True

    output = model(*inputs)

    # Use a matrix of ones to compare result with known answer
    mat = torch.ones(output.shape)
    res = derivs._jac_t_mat_prod(model, None, None, mat, None)  # shape -> (1, 3, 3, 3)

    # Known answer
    ans = torch.cos(inputs[0])  # shape -> (3, 3, 3)
    print(torch.allclose(res, ans))  # prints 'True'


if __name__ == '__main__':
    run_test()
```
If I've made any incorrect assumptions here or have any mistakes, please let me know.

One thing I'm unsure about is the required dimensionality of `mat` or `vec` in this context. You mentioned previously that you'd have to do a for loop, presumably because `_jac_t_mat_prod` would actually just be a sequence of `_jac_t_vec_prod` calls. In this example I've not done that, and instead have just passed the tensors straight into the `torch.autograd.grad` call within `transposed_jacobian_vector_product`. I feel like I must be missing something here, because this seems a rather simple solution?
Hi,
this looks like a good start! Let me clarify some of the points I meant in my previous post:
Re for loop:

```python
return torch.stack(transposed_jacobian_vector_product(module.output, module.input0, mat))
```

Here, `mat` satisfies `mat.shape == module.output.shape`. BackPACK's `_jac_t_mat_prod` assumes an additional leading axis, that is `mat.shape == [V, *module.output.shape]`, where `V` is some integer (not necessarily 1). The for loop over `V` would be as follows:

```python
return torch.stack([transposed_jacobian_vector_product(module.output, module.input0, v)[0] for v in mat])
```
Re another forward pass (this might not be correct, but I remember running into this):

```python
res = derivs._jac_t_mat_prod(model, None, None, mat, None)  # shape -> (1, 3, 3, 3)
```

While this works in your test, `derivs._jac_t_mat_prod` will be called in a backward hook, and I think there the graph that relates `module.output` to `module.input0` has already been freed and you should get an error during `transposed_jacobian_vector_product`. You will have to construct a new graph `module.input0 -> module.output` with another forward pass inside `derivs._jac_t_mat_prod`.
> where V is some integer (not necessarily 1)

So what purpose does `V` serve? If the typical "batch" dimension (i.e. the size of the mini-batch during training) will be the first dimension of `module.output`, then what does `V` represent? Or am I misunderstanding?
I've made the following edits:

```python
# backpack.core.derivatives.model.py
from typing import List, Tuple, Optional
from itertools import count

import torch

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.hessianfree.lop import transposed_jacobian_vector_product


class ArbitraryModelDerivatives(BaseDerivatives):
    '''Arbitrary Model Derivative'''

    def _get_num_inputs(self, module) -> List[str]:
        inputs = []
        for i in count():
            if hasattr(module, f'input{i}'):
                inputs.append(f'input{i}')
            else:
                break
        return inputs

    def _jac_t_mat_prod(
        self,
        module: torch.nn.Module,
        g_inp: Tuple[torch.Tensor],
        g_out: Tuple[torch.Tensor],
        mat: torch.Tensor,
        subsampling: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, ...]:
        # Start with just 1 input
        input_names = self._get_num_inputs(module)
        res = []
        for inp_name in input_names:
            inp = getattr(module, inp_name)
            if not inp.requires_grad:
                raise RuntimeError(f'requires_grad needed for module.{inp_name} jac_t_mat_prod')
            if subsampling is None:
                jtmp = torch.stack(transposed_jacobian_vector_product(module.output, inp, mat))
            else:
                raise NotImplementedError('Subsampling not currently implemented')
            res.append(jtmp)
        return tuple(res)
```

```python
# backpack.extensions.backprop_extension.py
class BackpropExtension(ABC):
    ...

    def __get_module_extension(self, module: Module) -> Union[ModuleExtension, None]:
        module_extension = self.__module_extensions.get(module.__class__)
        if module_extension is None:
            module_extension = self._get_arbitrary_extension(module)
        ...

    def _get_arbitrary_extension(self, module):
        return None  # None in Abstract base class
```

```python
# backpack.extensions.secondorder.diag_ggn.__init__.py
from . import model


class DiagGGN(SecondOrderBackpropExtension):
    ...

    def _get_arbitrary_extension(self, module):
        return model.DiagGGNArbitraryModel()

    ...
```
Adding all this, using a custom module that has various layers inside of it, including `Linear` layers, and running with a `backpack.backpack(DiagGGNExact())` call, the runtime gets to a `backpack.extensions.module_extension.ModuleExtension.__call__` for a `Linear` layer, where the module extension has the signature `backpack.extensions.secondorder.diag_ggn.linear.DiagGGNLinear`. There, the line

```python
bp_quantity = self.__get_backproped_quantity(
    extension, module.output, delete_old_quantities
)
```

evaluates to `None`, and I'm not really sure why. Any suggestions?
Hi,
> > where V is some integer (not necessarily 1)
>
> So what purpose does `V` serve? If the typical "batch" dimension (i.e. the size of the mini-batch during training) will be the first dimension of `module.output`, then what does `V` represent? Or am I misunderstanding?
For Hessian-related quantities, BackPACK backpropagates multiple vectors in parallel. These vectors are stacked, which yields the leading dimension `V`. All vectors along `V` are then processed identically.
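In formulas (summarizing the idea with generic notation rather than BackPACK's internal names): the GGN and the square-root factorization that gets backpropagated are

$$
G(\theta) = (J_\theta f)^\top \,\nabla_f^2 \ell\, (J_\theta f)
          = \big[(J_\theta f)^\top S\big]\,\big[(J_\theta f)^\top S\big]^\top,
\qquad \nabla_f^2 \ell = S S^\top .
$$

Here $S$ has $V$ columns, and those $V$ vectors are pushed through each layer's transposed Jacobian in parallel; that is the leading axis of `mat`. For `MSELoss`, the exact factorization has one such vector per output dimension, so $V = C$.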
I tweaked your example to make it work and commented on some of the details that relate to my previous posts:
"""Backpropagation through ReLU for GGN diagonal via ``torch.autograd``."""
from typing import List, Optional, Tuple
from torch import Tensor, allclose, enable_grad, rand, stack
from torch.autograd import grad
from torch.nn import Linear, Module, MSELoss, ReLU, Sequential
from torch.nn.functional import relu
from backpack import backpack, extend, extensions
from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
class ArbitraryModelDerivatives(BaseDerivatives):
"""Arbitrary Model Derivative"""
def __init__(self, forward_func) -> None:
super().__init__()
self.forward_func = forward_func
def _jac_t_mat_prod(
self,
module: Module,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: Optional[List[int]] = None,
) -> Tensor:
if subsampling is not None:
raise NotImplementedError("Subsampling not currently implemented")
print("Using arbitrary model derivatives.")
# regenerate computation graph for differentiation
with enable_grad():
# NOTE: Cannot use module(module.input0) since this triggers its
# forward hook and messes up the backpropagation internals (the
# internals rely on the memory address of module.input0, but the
# old module.input0 will be overwritten during
# module(module.input0)).
re_input = module.input0.clone().detach().requires_grad_(True)
re_output = self.forward_func(re_input)
# V vectors of shape [*module.input0.shape]
vjps = [grad(re_output, re_input, v, retain_graph=True)[0] for v in mat]
return stack(vjps) # shape [V, *module.input0.shape]
class DiagGGNReLUArbitrary(DiagGGNBaseModule):
"""Implements DiagGGN backpropagation for ReLU layer using ``torch.autograd``."""
def __init__(self):
super().__init__(derivatives=ArbitraryModelDerivatives(self.forward_func))
@staticmethod
def forward_func(input0):
return relu(input0)
X, y = rand(10, 5), rand(10, 3)
model = Sequential(Linear(5, 4), ReLU(), Linear(4, 3))
loss_func = MSELoss()
model = extend(model)
loss_func = extend(loss_func)
# ground truth
with backpack(extensions.DiagGGNExact()):
loss = loss_func(model(X), y)
loss.backward()
diag_ggn = [p.diag_ggn_exact for p in model.parameters()]
# now using arbitrary derivatives under the hood
ext = extensions.DiagGGNExact()
ext.set_module_extension(
ReLU,
DiagGGNReLUArbitrary(),
overwrite=True, # force overwrite as ReLU already exists within BackPACK
)
with backpack(ext):
loss = loss_func(model(X), y)
loss.backward()
diag_ggn_arbitrary = [p.diag_ggn_exact for p in model.parameters()]
for diag, diag_arbitrary in zip(diag_ggn, diag_ggn_arbitrary):
print(allclose(diag, diag_arbitrary))
I have a somewhat complicated `torch.nn.Module`; let's say for argument's sake its structure is a bit like this:

Whilst `OtherCustomModule` and `AnotherCustomModule` are themselves composed of some custom functionality, there are some standard layers within them like `nn.Linear`, but there's other stuff going on too.

I've read that as long as the direct children are standard torch modules like `nn.Linear`, `backpack` can detect that and deal with it; however, that isn't the case here.

Looking at the example custom module docs with `ScaleModuleBatchGrad`, I'm not sure how I can implement my own class here, since `self.layer1` etc. are `nn.Module`s, not `nn.Parameter`s?