f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Are customized loss functions supported? #299

Closed zzpustc closed 3 months ago

zzpustc commented 1 year ago

Hi!

Thanks for the contributions, I would like to know if customized loss functions are supported. Currently, it seems that only those defined in PyTorch are supported.

Best

fKunstner commented 1 year ago

For first-order extensions (e.g., individual gradients), all loss functions or arbitrary transformations are supported (see here).
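As a minimal sketch (not from this thread): with a first-order extension such as BatchGrad, only the model needs to be extended, and the loss can be an ordinary Python function as long as it sums or averages over the samples. The name my_loss below is illustrative, not BackPACK API.

import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = extend(torch.nn.Linear(3, 1))
X, y = torch.randn(10, 3), torch.randn(10, 1)

def my_loss(pred, target):
    # any hand-written loss that decomposes over samples works here
    return ((pred - target) ** 2).mean()

with backpack(BatchGrad()):
    my_loss(model(X), y).backward()

for p in model.parameters():
    # individual gradients stored by BatchGrad, shape [batch_size, *p.shape]
    print(p.grad_batch.shape)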

Transformations with parameters that need gradients (see making a custom module) or second-order quantities need a few more tweaks to propagate/extract the right quantities.

Do you have a specific use-case in mind?

zzpustc commented 1 year ago

I want to implement a normalized linear layer (e.g., NormedLinear) that can be supported by KFAC in BackPACK. What should I do? Are there any templates I can refer to?
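For reference, a common NormedLinear implementation (an assumption about which layer is meant here) L2-normalizes both the input and the weight columns before taking their product:

import torch
import torch.nn.functional as F

class NormedLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.Tensor(in_features, out_features))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def forward(self, x):
        # cosine-similarity style output: unit-norm inputs times unit-norm weight columns
        return F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))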

fKunstner commented 1 year ago

The docs have an example of how to implement a custom module for first-order extensions. Second-order extensions have a few more moving parts. Below is a script that implements the KFAC operations for a linear layer (no bias). kfac_linear_layer.txt


General imports

from typing import Tuple, List

import torch
from torch import Tensor, einsum

from backpack import backpack, extend
from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.extensions import KFAC
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Define a Module class with the new operation

class MyLinear(torch.nn.Module):
    """Torch.nn.Module class defining the operation"""

    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight)

Define a class containing derivative operations.

The only one that is strictly necessary is the product of the transposed Jacobian with an arbitrary matrix, jac_t_mat_prod.
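To spell out the case at hand: the bias-free MyLinear layer computes out = x @ weight.T, so applying the transposed output-input Jacobian to a matrix mat is just mat @ weight, which is what the einsum below implements. A quick autograd check of that identity (illustrative only, not part of the original script):

import torch

x = torch.randn(4, 3, requires_grad=True)
W = torch.randn(2, 3)
v = torch.randn(4, 2)
(x @ W.T).backward(v)                 # vector-Jacobian product w.r.t. x
assert torch.allclose(x.grad, v @ W)  # equals v @ W for a linear map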

class MyLinearDerivatives(BaseParameterDerivatives):
    """Partial derivatives for MyLinear module."""

    def hessian_is_zero(self, module: MyLinear) -> bool:
        return True

    def _jac_t_mat_prod(
        self,
        module: MyLinear,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        subsampling: List[int] = None,
    ) -> Tensor:
        """Batch-apply the transposed Jacobian of the output w.r.t. the input to the matrix `mat`.

        Args:
            module: MyLinear layer.
            g_inp: Gradients w.r.t. module input. Not required by the implementation.
            g_out: Gradients w.r.t. module output. Not required by the implementation.
            mat: Batch of ``V`` vectors of same shape as the layer output
                (``[N, *, out_features]``) to which the transposed output-input Jacobian
                is applied. Has shape ``[V, N, *, out_features]``; but if used with
                sub-sampling, ``N`` is replaced by ``len(subsampling)``.
            subsampling: Indices of active samples. ``None`` means all samples.

        Returns:
            Batched transposed Jacobian vector products. Has shape
            ``[V, N, *, in_features]``. If used with sub-sampling, ``N`` is replaced
            by ``len(subsampling)``.
        """
        return einsum("vn...o,oi->vn...i", mat, module.weight)

Then we can create a class that does the necessary operations for KFAC on a MyLinear Module.

The backpropagation is handled automatically by passing the MyLinearDerivatives object; the remaining bit is to extract the Kronecker factors for the weight.

class KFACMyLinear(HBPBaseModule):
    def __init__(self):
        super().__init__(derivatives=MyLinearDerivatives(), params=["weight"])

    def weight(self, ext, module, g_inp, g_out, backproped):
        if module.input0.dim() != 2:
            raise NotImplementedError(
                f"Only 2d inputs are supported by {ext.__class__.__name__} "
                + f"(got {module.input0.dim()})."
            )

        N = module.input0.size(0)
        flat_input = module.input0.reshape(N, -1)

        # output-side Kronecker factor from the backpropagated Hessian square root
        factor_from_sqrt = einsum("vni,vnj->ij", backproped, backproped)
        # input-side Kronecker factor: uncentered covariance of the layer inputs
        factor_from_input = einsum("ni,nj->ij", flat_input, flat_input) / N
        kron_factors = [factor_from_sqrt, factor_from_input]

        return kron_factors

Register the KFAC-MyLinear module in the KFAC extension

# register module-computation mapping
extended_kfac = KFAC()
extended_kfac.set_module_extension(MyLinear, KFACMyLinear())

Use KFAC with the new extension

batch_size = 10
hidden_size = 2
input_size = 3

inputs = torch.randn(batch_size, input_size, device=device)
targets = torch.randn((batch_size, 1), device=device)

model = torch.nn.Sequential(
    MyLinear(input_size, hidden_size),
    MyLinear(hidden_size, 1),
)
lossfunc = torch.nn.MSELoss(reduction="mean")

extend(model)
extend(lossfunc)

with backpack(extended_kfac):
    loss = lossfunc(model(inputs), targets)
    loss.backward()

for p in model.parameters():
    print("Parameter of shape ", p.shape)
    print("grad", p.grad)
    print("kronecker factors", p.kfac)
zzpustc commented 1 year ago

Thanks for your guidance! :)

zzpustc commented 1 year ago

If I run KFAC as above many times (e.g., repeating it every epoch), I sometimes get the error "RuntimeError: The size of tensor a (64) must match the size of tensor b (10) at non-singleton dimension 2" when executing loss.backward(). Can you provide some insights? Here are the error messages:

2023-01-04 17:07:04.638 loss.backward()
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
2023-01-04 17:07:04.638 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
2023-01-04 17:07:04.638 Variable._execution_engine.run_backward(
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/torch/utils/hooks.py", line 110, in hook
2023-01-04 17:07:04.638 res = user_hook(self.module, grad_input, self.grad_outputs)
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/backpack/__init__.py", line 209, in hook_run_extensions
2023-01-04 17:07:04.638 backpack_extension(module, g_inp, g_out)
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/backprop_extension.py", line 127, in __call__
2023-01-04 17:07:04.638 module_extension(self, module, g_inp, g_out)
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/module_extension.py", line 123, in __call__
2023-01-04 17:07:04.638 self.__save_backproped_quantity(extension, module_inp, bp_quantity)
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/module_extension.py", line 202, in __save_backproped_quantity
2023-01-04 17:07:04.638 extension.saved_quantities.save_quantity(
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/saved_quantities.py", line 31, in save_quantity
2023-01-04 17:07:04.638 save_value = accumulation_function(existing, quantity)
2023-01-04 17:07:04.638 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/secondorder/hbp/__init__.py", line 138, in accumulate_backpropagated_quantities
2023-01-04 17:07:04.638 return existing + other
2023-01-04 17:07:04.638 RuntimeError: The size of tensor a (64) must match the size of tensor b (10) at non-singleton dimension 2
fKunstner commented 1 year ago

Not sure.

Can you isolate the specific batch that is causing the issue? There shouldn't be a dependency across batch evaluations, so I'm not sure that running it multiple times is the problem.

It might be useful to set the debug flag, with backpack(extended_kfac, debug=True): (docs), to print what the backward pass does.
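A sketch of both suggestions, assuming the extended_kfac, model, and lossfunc objects from the example above and a hypothetical dataloader:

# run batch by batch with debug output and report the first batch that fails
for i, (x, y) in enumerate(dataloader):
    model.zero_grad()
    try:
        with backpack(extended_kfac, debug=True):
            lossfunc(model(x), y).backward()
    except RuntimeError as err:
        print(f"batch {i} raised: {err}")
        break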

zzpustc commented 1 year ago

If I fix the seed and test in different environments (Linux terminal, Docker container, etc.), the error happens at different points; if I test in a single environment, it always raises the error at the same point no matter how I change the seed. This suggests the error is unrelated to the data and is instead tied to something environment-dependent in the implementation. Here is the error message under debug mode:

2023-01-05 19:56:36.780 [DEBUG] Running extension <backpack.extensions.secondorder.hbp.KFAC object at 0x7f64f8342130> on CrossEntropyLoss()
2023-01-05 19:56:36.780 [DEBUG] Running extension hook on CrossEntropyLoss()
2023-01-05 19:56:36.780 [DEBUG] Running extension <backpack.extensions.secondorder.hbp.KFAC object at 0x7f64f8342130> on NormedLinear()
2023-01-05 19:56:36.780 [DEBUG] Running extension hook on NormedLinear()
2023-01-05 19:56:36.780 [DEBUG] Running extension <backpack.extensions.secondorder.hbp.KFAC object at 0x7f64f8342130> on CrossEntropyLoss()
2023-01-05 19:56:36.780 [DEBUG] Running extension hook on CrossEntropyLoss()
2023-01-05 19:56:36.780 [DEBUG] Running extension <backpack.extensions.secondorder.hbp.KFAC object at 0x7f64f8342130> on NormedLinear()
2023-01-05 19:56:36.780 [DEBUG] Running extension hook on NormedLinear()
2023-01-05 19:56:36.780 [DEBUG] Running extension <backpack.extensions.secondorder.hbp.KFAC object at 0x7f64f8342130> on CrossEntropyLoss()
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
2023-01-05 19:56:36.781 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
2023-01-05 19:56:36.781 Variable._execution_engine.run_backward(
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/torch/utils/hooks.py", line 110, in hook
2023-01-05 19:56:36.781 res = user_hook(self.module, grad_input, self.grad_outputs)
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/backpack/__init__.py", line 209, in hook_run_extensions
2023-01-05 19:56:36.781 backpack_extension(module, g_inp, g_out)
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/backprop_extension.py", line 127, in __call__
2023-01-05 19:56:36.781 module_extension(self, module, g_inp, g_out)
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/module_extension.py", line 123, in __call__
2023-01-05 19:56:36.781 self.__save_backproped_quantity(extension, module_inp, bp_quantity)
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/module_extension.py", line 202, in __save_backproped_quantity
2023-01-05 19:56:36.781 extension.saved_quantities.save_quantity(
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/saved_quantities.py", line 31, in save_quantity
2023-01-05 19:56:36.781 save_value = accumulation_function(existing, quantity)
2023-01-05 19:56:36.781 File "/usr/local/python/lib/python3.8/site-packages/backpack/extensions/secondorder/hbp/__init__.py", line 138, in accumulate_backpropagated_quantities
2023-01-05 19:56:36.781 return existing + other
2023-01-05 19:56:36.781 RuntimeError: The size of tensor a (64) must match the size of tensor b (10) at non-singleton dimension 2
fKunstner commented 1 year ago

Not obvious from the logs without the code. It's weird that the error appears after the call on CrossEntropyLoss(): this should be the start of a new backward pass, and there shouldn't be anything to call accumulation_function on. Do you have a cross-entropy loss in the middle of your network?

If you can isolate the issue in a minimal reproducible example I could try to give it a look.

zhangtj1996 commented 1 year ago

Can we write code for general losses that are not predefined in PyTorch?

For example,

def loss_func(x,y):
    return (x-y)**2

loss_func = extend(loss_func)

AttributeError: 'function' object has no attribute 'children'
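For context, extend operates on torch.nn.Module instances, which is why a plain function has no children attribute. A minimal sketch (not an official recipe) wraps the loss in a Module so that extend accepts it; second-order extensions would still need the corresponding loss derivatives:

import torch
from backpack import extend

class SquaredError(torch.nn.Module):
    # hypothetical wrapper class, illustrative only
    def forward(self, x, y):
        return ((x - y) ** 2).sum()

loss_func = extend(SquaredError())  # no AttributeError: this is a Module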

f-dangel commented 3 months ago

Closing with a pointer to #325 which discusses how to support custom loss functions with second-order extensions.