f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

KFAC support in BatchNorm (eval mode) #259

Open pyun-ram opened 2 years ago

pyun-ram commented 2 years ago

Hi,

Thanks for the repo! This is really nice work. I am planning to compute KFAC with BackPACK, but it raises the following error:

NotImplementedError: Extension saving to kfac does not have an extension for Module <class 'torch.nn.modules.batchnorm.BatchNorm2d'>

My network is as follows:

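import torch.nn as nn
from backpack import backpack, extend
from backpack.extensions import KFAC

model = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=3),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(8, 4, 3, stride=3),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(36, 10))  # 36 = 4 * 3 * 3 for 28x28 inputs
loss_func = nn.CrossEntropyLoss()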

When calculating the KFAC with:

    model_ = extend(model.eval())
    logits = model_(X)
    loss = extend(loss_func)(logits, Y)
    with backpack(KFAC(mc_samples=1000)):
        loss.backward()

It raises the NotImplementedError above. Is computing KFAC for a network with BN layers in the middle supported by BackPACK? It seems like it should be, since it works for ResNet.

Thanks

f-dangel commented 2 years ago

Hi, thanks for your question!

Just to make sure I'm getting it right: You want to compute KFAC for the Conv2d and Linear layers in your network, or do you want to compute KFAC for the parameters of the BatchNorm2d layer? (For the latter, I'm not sure if KFAC is defined)

pyun-ram commented 2 years ago

Indeed, I want to compute the KFAC for the parameters of Conv2d, Linear and BatchNorm2d layer. Is it possible to achieve this?

f-dangel commented 2 years ago

BackPACK can compute KFAC for Linear and Conv2d layers, but not for BatchNorm2d. I don't know how the KFAC papers deal with batch normalization. Do you know? If so, one could implement this missing feature.

Sadly, there is no easy way to tell BackPACK to ignore the parameters of batch norm layers, because it tries to compute its quantities on all parameters that have requires_grad=True.

If you want to get KFAC for the supported layers, you will have to set p.requires_grad=False for the BN parameters. But then you also won't get their gradient.
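A minimal sketch of that workaround in plain PyTorch, assuming the BN layers are the standard `nn.BatchNorm1d/2d/3d` modules. The helper name `freeze_batchnorm_params` is hypothetical and not part of BackPACK; it only flips `requires_grad` on the BN affine parameters.

```python
import torch.nn as nn


def freeze_batchnorm_params(model: nn.Module) -> None:
    """Set requires_grad=False on the affine parameters of all BatchNorm layers."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    for module in model.modules():
        if isinstance(module, bn_types):
            for p in module.parameters():
                p.requires_grad_(False)


# The network from the original question
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, stride=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 4, 3, stride=3),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(36, 10),
)
freeze_batchnorm_params(model)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['0.weight', '0.bias', '3.weight', '3.bias', '7.weight', '7.bias']
```

The BN layers still take part in the forward pass; only their γ and β are excluded from gradient (and therefore curvature) computation.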

pyun-ram commented 2 years ago

Thanks for the prompt reply! Yep, I haven't seen any paper discussing the KFAC computation for BN either...

To get the KFAC for the supported layers (Linear and Conv2d), I found a way to bypass the NotImplementedError.

# In hbp/__init__.py, register an extension for BatchNorm2d:
class HBP(SecondOrderBackpropExtension):
    def __init__(
        self,
        curv_type,
        loss_hessian_strategy,
        backprop_strategy,
        ea_strategy,
        savefield="hbp",
    ):
        ...
        super().__init__(
            savefield=savefield,
            fail_mode="ERROR",
            module_exts={
                ...
                Identity: custom_module.HBPScaleModule(),
                BatchNorm2d: batchnorm_nd.HBPBatchNormNd(),
            },
        )

# HBPBatchNormNd (module batchnorm_nd) is defined as:
from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule

class HBPBatchNormNd(HBPBaseModule):
    def __init__(self):
        super().__init__(BatchNormNdDerivatives(), params=None)

With this modification, it runs without raising the error. I am not quite sure whether this is the right approach. Do you have any advice?

f-dangel commented 2 years ago

Hi,

that workaround looks good! Indeed, this will ignore the BN parameters, while keeping BackPACK's backpropagation through the layer for KFAC intact.

pyun-ram commented 2 years ago

That's a great relief! :)

f-dangel commented 2 years ago

One way to get started on this would be to add support for KFAC in BatchNorm in evaluation mode.

I will outline below what needs to be done (this may not be 100% technically accurate).

Pull requests welcome.


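Let's assume a BatchNorm1d layer that takes an input X of shape [N, C] and maps it to an output Z of shape [N, C]. The parameters γ and β are both of shape [C]. For simplicity, let X denote the input after normalization with the running mean and variance, which are constants in evaluation mode. The forward pass (in evaluation mode) is then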

Z[n, :] = γ ⊙ X[n, :] + β        n = 1, ... , N

(where ⊙ is elementwise multiplication). This looks a bit like a Linear layer with weights W = diag(γ) and bias b = β.

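We don't really need a Kronecker factorization here, because the curvature blocks for γ and β are both only of shape [C, C]. So instead we compute the MC-sampled Fisher/GGN blocks directly. Let S[m, n, :] (m = 1, ..., M) be the MC-sampled square root of the loss Hessian, backpropagated to Z[n, :]. Since ∂Z[n, c] / ∂γ[c] = X[n, c] and ∂Z[n, c] / ∂β[c] = 1,

G_γ = Σ_{m, n} (X[n, :] ⊙ S[m, n, :]) (X[n, :] ⊙ S[m, n, :])ᵀ        G_β = Σ_{m, n} S[m, n, :] S[m, n, :]ᵀ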
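A minimal PyTorch sketch of these two blocks, assuming S (shape [M, N, C]) holds the backpropagated MC-sampled Hessian square root with respect to Z and X is the already normalized input. The function name `bn_eval_ggn_blocks` is hypothetical and not part of BackPACK. It contracts the sample and batch axes with einsum and cross-checks the result against explicitly built per-sample Jacobians.

```python
import torch


def bn_eval_ggn_blocks(X: torch.Tensor, S: torch.Tensor):
    """MC-sampled GGN blocks for gamma and beta of an eval-mode BN layer.

    X: normalized layer input, shape [N, C]
    S: backpropagated Hessian square root w.r.t. the output Z, shape [M, N, C]
    """
    # dZ[n, c] / d gamma[c] = X[n, c] (diagonal Jacobian), so J^T v is elementwise
    JTv_gamma = S * X.unsqueeze(0)  # [M, N, C]
    # dZ[n, c] / d beta[c] = 1, so J^T v is S itself
    JTv_beta = S  # [M, N, C]
    G_gamma = torch.einsum("mnc,mnd->cd", JTv_gamma, JTv_gamma)  # [C, C]
    G_beta = torch.einsum("mnc,mnd->cd", JTv_beta, JTv_beta)  # [C, C]
    return G_gamma, G_beta


torch.manual_seed(0)
M, N, C = 3, 5, 4
X = torch.randn(N, C, dtype=torch.float64)
S = torch.randn(M, N, C, dtype=torch.float64)
G_gamma, G_beta = bn_eval_ggn_blocks(X, S)

# Reference: accumulate (J^T s)(J^T s)^T with explicit Jacobians
G_gamma_ref = torch.zeros(C, C, dtype=torch.float64)
G_beta_ref = torch.zeros(C, C, dtype=torch.float64)
for n in range(N):
    J_gamma = torch.diag(X[n])  # dZ[n, :] / d gamma, [C, C]
    J_beta = torch.eye(C, dtype=torch.float64)  # dZ[n, :] / d beta, [C, C]
    for m in range(M):
        v = J_gamma.T @ S[m, n]
        G_gamma_ref += torch.outer(v, v)
        w = J_beta.T @ S[m, n]
        G_beta_ref += torch.outer(w, w)
```

Using einsum avoids materializing the [C, C] diagonal Jacobians, which the reference loop builds only for checking.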

pyun-ram commented 2 years ago

Thanks for your guidance! :) I have opened a pull request, which also adds a simple test case that checks the result and the mode check. One more question: why is it not necessary to divide kfac_gamma by JTv.shape[0], the number of MC samples, when computing kfac_gamma?

kfac_gamma = einsum("mnc,mnd->cd", JTv, JTv) # [C, C]
f-dangel commented 1 year ago

Hi,

thanks for the PR; apologies, you will have to be patient with my review.

Regarding your question: good point! The factor 1 / sqrt(M), where M is the number of MC samples, is inserted by the loss function when it creates the MC-approximated Hessian square root, which is then backpropagated through all layers. Squaring it yields the desired 1 / M.
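To illustrate why no extra division is needed, here is a sketch of an MC-sampled Hessian square root for cross-entropy, assuming reduction='sum' (BackPACK's actual implementation may differ in details, e.g. an extra 1/N factor for reduction='mean'; the name `mc_sqrt_hessian_ce` is hypothetical). Because each sampled vector is already scaled by 1/sqrt(M), summing the outer products over the sample axis, as the einsum does, averages over the M samples and approximates the exact Hessian diag(p) - p pᵀ.

```python
import math

import torch


def mc_sqrt_hessian_ce(logits: torch.Tensor, mc_samples: int) -> torch.Tensor:
    """MC-sampled square root of the cross-entropy Hessian w.r.t. the logits.

    logits: [N, K]; returns S of shape [M, N, K] with sum_m S[m, n] S[m, n]^T ~ Hessian.
    Assumes reduction='sum' (no additional 1/N factor).
    """
    probs = torch.softmax(logits, dim=1)  # [N, K]
    # Draw M labels per datum from the model's predictive distribution
    samples = torch.multinomial(probs, mc_samples, replacement=True)  # [N, M]
    onehot = torch.nn.functional.one_hot(samples.T, num_classes=logits.shape[1])  # [M, N, K]
    # The 1/sqrt(M) factor is applied once, here, inside the loss function
    return (probs.unsqueeze(0) - onehot.to(probs.dtype)) / math.sqrt(mc_samples)


torch.manual_seed(0)
logits = torch.randn(2, 3, dtype=torch.float64)
S = mc_sqrt_hessian_ce(logits, mc_samples=100_000)  # [M, N, K]
# Sum over MC samples without dividing by M, exactly like the kfac_gamma einsum
H_mc = torch.einsum("mnc,mnd->ncd", S, S)  # [N, K, K]
p = torch.softmax(logits, dim=1)
H_exact = torch.diag_embed(p) - torch.einsum("nc,nd->ncd", p, p)
print((H_mc - H_exact).abs().max())  # small; shrinks roughly like 1/sqrt(M)
```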

For CrossEntropyLoss this happens here in the code (M denotes the number of MC samples).

Best, Felix

fredguth commented 3 months ago

Is there a way to make BackPACK ignore the modules it does not support? I want to use it with models I did not implement myself (timm models, for example).

f-dangel commented 3 months ago

Hi,

are you asking this question w.r.t. KFAC? If you want to use a first-order extension, you can simply extend the layers that are supported by BackPACK. If you want to use a second-order extension, all layers must be supported by BackPACK, as otherwise it cannot backpropagate the additional information through the compute graph.

Best, Felix

fredguth commented 2 months ago

I mean a second-order extension. I just want an estimate and don't care if the value is not precise; I only want to track how it changes during training. Can I simply remove a module if it does not change the dimensions?

f-dangel commented 2 months ago

Hey,

I'm not sure I follow what exactly you want to do. If you remove the BN layers and all remaining layers are supported by BackPACK, you can use second-order extensions. But then your network will also behave differently, because you eliminated the BN layers.

Best, Felix

fredguth commented 2 months ago

I only use BackPACK's second-order extensions to measure the Fisher information of the weights during training; the results are not used for training itself. Every n steps, before the next training step, I use the Hessian to measure the information, save the result, and clear the gradients for the next step. It is not a problem for me if the information for the batchnorm layers is not measured.

fKunstner commented 2 months ago

If I understand your setup correctly, that will still be difficult without implementing the batchnorm operation. There is no option to disable the extension on the batchnorm parameters only, because BackPACK still needs to backpropagate the second-order information through the batchnorm layer to compute the quantities for the parameters of the earlier layers.

Here's a workaround that could work without having to code the batchnorm extension.

Say we start with the network

import torch
from torch.nn import BatchNorm1d, Linear, Sequential

net1 = Sequential(
  Linear(1, 1),
  BatchNorm1d(1),
  Linear(1, 1),
)

We can make a second network that does the same operation as net1 (if net1 is in eval mode) using only Linear layers,

net2 = Sequential(
  Linear(1,1), 
  Linear(1,1), 
  Linear(1,1), 
  Linear(1,1),
)

To make them the same, we need to map the weights from net1 to net2. For the linear layers, we just copy the data

net2[0].weight.data = net1[0].weight.data
net2[0].bias.data = net1[0].bias.data
net2[3].weight.data = net1[2].weight.data
net2[3].bias.data = net1[2].bias.data

And we should be able to implement the batchnorm operation with 2 linear layers by remapping them as follows (needs a double check)

# Implement the normalization
# x -> (x - running_mean) / sqrt(running_var + eps) = (1 / sqrt(running_var + eps)) * x - running_mean / sqrt(running_var + eps)
# and then the affine part x -> weight * x + bias, each as a Linear layer with a diagonal weight matrix

bnlayer = net1[1]

with torch.no_grad():
    scale = 1 / torch.sqrt(bnlayer.running_var + bnlayer.eps)
    net2[1].weight.data = torch.diag(scale)
    net2[1].bias.data = -bnlayer.running_mean * scale  # note the minus sign
    net2[2].weight.data = torch.diag(bnlayer.weight)
    net2[2].bias.data = bnlayer.bias.data.clone()
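To double-check the mapping numerically, here is a sketch that generalizes it to C features (not just 1) and compares both networks in eval mode. `map_weights` mirrors the name used in the thread but is a hypothetical helper written here for illustration. Note the minus sign in the bias of net2[1], which follows from the normalization formula, and the diagonal weight matrices, which are needed for the shapes to work out.

```python
import torch
from torch.nn import BatchNorm1d, Linear, Sequential


def map_weights(net1: Sequential, net2: Sequential) -> None:
    """Hypothetical helper: copy net1 = (Linear, BatchNorm1d, Linear) into net2 = 4 x Linear."""
    bn = net1[1]
    with torch.no_grad():
        net2[0].weight.copy_(net1[0].weight)
        net2[0].bias.copy_(net1[0].bias)
        # Normalization: x -> x / sqrt(var + eps) - mean / sqrt(var + eps)
        scale = 1.0 / torch.sqrt(bn.running_var + bn.eps)
        net2[1].weight.copy_(torch.diag(scale))
        net2[1].bias.copy_(-bn.running_mean * scale)
        # Affine part: x -> gamma * x + beta
        net2[2].weight.copy_(torch.diag(bn.weight))
        net2[2].bias.copy_(bn.bias)
        net2[3].weight.copy_(net1[2].weight)
        net2[3].bias.copy_(net1[2].bias)


torch.manual_seed(0)
D_in, C, D_out = 3, 4, 2
net1 = Sequential(Linear(D_in, C), BatchNorm1d(C), Linear(C, D_out))
net2 = Sequential(Linear(D_in, C), Linear(C, C), Linear(C, C), Linear(C, D_out))

# Give the BN layer non-trivial running statistics and affine parameters
net1.train()
for _ in range(10):
    net1(torch.randn(16, D_in) * 3 + 1)
with torch.no_grad():
    net1[1].weight.uniform_(0.5, 2.0)
    net1[1].bias.uniform_(-1.0, 1.0)

net1.eval()  # the equivalence only holds in eval mode
map_weights(net1, net2)
x = torch.randn(8, D_in)
print(torch.allclose(net1(x), net2(x), atol=1e-5))  # True
```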

Now we can extend net2 with BackPACK to compute KFAC.

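Instead of doing

extend(net1)
loss_func = extend(CrossEntropyLoss())

...

with backpack(KFAC()):
    loss_func(net1(X), y).backward()

we do

extend(net2)
loss_func = extend(CrossEntropyLoss())

...

map_weights(net1, net2)
with backpack(KFAC()):
    loss_func(net2(X), y).backward()
inverse_map_grad_and_kfac(net2, net1)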

where inverse_map_grad_and_kfac would map the _.grad and _.kfac attributes of the parameters of net2 to the right parameters of net1
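As a hedged sketch of the `.grad` half of `inverse_map_grad_and_kfac`, assuming net1 has C features and the layer layout above, the hypothetical helper `inverse_map_grad` copies gradients from net2 back onto net1 and checks them against the gradients net1 produces on its own in eval mode. Mapping the `.kfac` factors back onto γ is not covered here.

```python
import torch
from torch.nn import BatchNorm1d, CrossEntropyLoss, Linear, Sequential


def inverse_map_grad(net2: Sequential, net1: Sequential) -> None:
    """Hypothetical helper: copy .grad of net2 (4 x Linear) onto net1 (Linear, BN, Linear)."""
    bn = net1[1]
    net1[0].weight.grad = net2[0].weight.grad.clone()
    net1[0].bias.grad = net2[0].bias.grad.clone()
    # net2[1] only encodes the running statistics (buffers), so its gradient is dropped
    # net2[2].weight = diag(gamma), hence d loss / d gamma[c] = d loss / d W[c, c]
    bn.weight.grad = torch.diagonal(net2[2].weight.grad).clone()
    bn.bias.grad = net2[2].bias.grad.clone()
    net1[2].weight.grad = net2[3].weight.grad.clone()
    net1[2].bias.grad = net2[3].bias.grad.clone()


torch.manual_seed(0)
C = 4
net1 = Sequential(Linear(3, C), BatchNorm1d(C), Linear(C, 2)).eval()
net2 = Sequential(Linear(3, C), Linear(C, C), Linear(C, C), Linear(C, 2))
with torch.no_grad():
    bn = net1[1]
    # Non-trivial running statistics and scale
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 2.0)
    # Weight mapping: normalization and affine part as diagonal Linear layers
    scale = 1.0 / torch.sqrt(bn.running_var + bn.eps)
    pairs = [
        (net2[0], net1[0].weight, net1[0].bias),
        (net2[1], torch.diag(scale), -bn.running_mean * scale),
        (net2[2], torch.diag(bn.weight), bn.bias),
        (net2[3], net1[2].weight, net1[2].bias),
    ]
    for layer, weight, bias in pairs:
        layer.weight.copy_(weight)
        layer.bias.copy_(bias)

x, y = torch.randn(8, 3), torch.randint(0, 2, (8,))
loss_func = CrossEntropyLoss()

# Gradients computed through net2, then mapped back onto net1
loss_func(net2(x), y).backward()
inverse_map_grad(net2, net1)
mapped = {name: p.grad.clone() for name, p in net1.named_parameters()}

# Reference: gradients of net1 itself (eval mode)
net1.zero_grad()
loss_func(net1(x), y).backward()
for name, p in net1.named_parameters():
    print(name, torch.allclose(p.grad, mapped[name], atol=1e-5))
```

Only the diagonal of `net2[2].weight.grad` is used; the off-diagonal entries correspond to weights that do not exist in the BN layer.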

(Although operations on _.data shouldn't be tracked by autodiff, maybe put all of this in a torch.no_grad() block to make sure gradients don't get propagated from one network to the other?)

fredguth commented 2 months ago

Thanks a lot for your thoughts! I will study this... :-)