mmasana / FACIL

Framework for Analysis of Class-Incremental Learning with 12 state-of-the-art methods and 3 baselines.
https://arxiv.org/pdf/2010.15277.pdf
MIT License

LwM - no gradient in attention distillation loss #37

Open fszatkowski opened 1 year ago

fszatkowski commented 1 year ago

Hi, when experimenting with LwM in FACIL, I noticed that the method behaves the same regardless of the value of the gamma parameter, which controls the attention distillation loss. On closer investigation, I found that during training the attention maps returned by GradCAM have no gradient. You can check this yourself with a debugger at this line: https://github.com/mmasana/FACIL/blob/e9d816c0c649db91bde1568300a8ba3045651ffd/src/approach/lwm.py#L126 When we later use these attention maps to compute the attention distillation loss, the loss has no gradient, so its contribution to the gradient update is ignored. Therefore, LwM in FACIL effectively does LwF with extra, unused computation.
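
A minimal sketch of the symptom described above. It assumes a toy conv layer standing in for the GradCAM target layer; the modules and the `detaching_hook` name are hypothetical, not FACIL code. A forward hook that stores `output.detach()` yields tensors that are cut out of the autograd graph, so anything built from them has `requires_grad=False`:

```python
import torch
import torch.nn as nn

# Toy network (hypothetical): `features` plays the role of the GradCAM target layer.
features = nn.Conv2d(3, 4, kernel_size=3, padding=1)
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2))

stored = {}

def detaching_hook(module, inputs, output):
    # Mirrors the behaviour described in the issue: stored activations are detached.
    stored["activations"] = output.detach()

handle = features.register_forward_hook(detaching_hook)
x = torch.randn(2, 3, 8, 8)
logits = head(features(x))
handle.remove()

# An "attention map" built from the stored activations inherits no graph.
att_map = torch.relu(stored["activations"].sum(dim=1, keepdim=True))
print(logits.requires_grad)   # True
print(att_map.requires_grad)  # False
print(att_map.grad_fn)        # None
```

Any loss computed only from such maps is a constant as far as autograd is concerned.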

I think the issue is in the GradCAM class. At line 226 the activations are detached, and later, at line 255, gradients are disabled when computing the attention maps. I think this class should have an option to preserve gradients when computing attention maps, and that option should be enabled for the forward pass of the current net. Then the attention maps for the current net will have requires_grad=True, and consequently the attention loss will contribute to the weight updates.
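
A rough sketch of the proposed option, written as a standalone function rather than FACIL's `GradCAM` class. The names `gradcam_attention` and `keep_graph` are hypothetical. Instead of hooks, it obtains the gradients with `torch.autograd.grad`, which returns them without accumulating into the parameters' `.grad`. The channel weights come from detached gradients, as in the hooks discussed here, and the activations stay attached to the graph only when `keep_graph=True`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gradcam_attention(features, head, x, class_idx, keep_graph=False):
    """Grad-CAM map for `class_idx` (hypothetical helper, not FACIL's GradCAM)."""
    activations = features(x)                 # target-layer output, (B, C, H, W)
    score = head(activations)[:, class_idx].sum()
    # d(score)/d(activations); retain_graph=True keeps the graph for a later backward().
    (grads,) = torch.autograd.grad(score, activations, retain_graph=True)
    # Gradients are only used as constant channel weights, (B, C, 1, 1).
    weights = F.adaptive_avg_pool2d(grads.detach(), 1)
    # Keep the activations in the graph only when the map must be differentiable.
    acts = activations if keep_graph else activations.detach()
    return F.relu((weights * acts).sum(dim=1, keepdim=True))  # (B, 1, H, W)

features = nn.Conv2d(3, 4, kernel_size=3, padding=1)
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2))
x = torch.randn(2, 3, 8, 8)

print(gradcam_attention(features, head, x, 0).requires_grad)                   # False
print(gradcam_attention(features, head, x, 0, keep_graph=True).requires_grad)  # True
```

With `keep_graph=True`, a loss on the returned map backpropagates into the layers that produced the activations, which is what the attention distillation term needs for the current model.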

mmasana commented 1 year ago

Hi @fszatkowski,

I remember discussing this approach and the gradients/loss before, so maybe we missed something. A first change that has not yet been pushed to main is the one found in this commit. We also had some discussion about it in this issue.

Could you check whether either of those helps with the issue? My first guess is that one may not want to propagate the gradients through the GradCAM itself, but instead generate a loss from the attention maps that is backpropagated through the ResNet model (the same weights that the CE loss modifies). However, it's been a while, so if these links do not help, let me know and we'll dive into it again and fix it.

In our experiments, LwM is better than LwF by a significant margin. That seems to indicate that some difference between the two methods is indeed at play. But it could also be due to some other reason, or maybe we introduced the error when cleaning the code for public release.

Let me know if that helped. I'm quite interested in solving this if it is indeed an issue!

fszatkowski commented 1 year ago

The commit you linked changes torch.norm to torch.nn.functional.normalize, but I don't think it addresses the main issue, which is that the attention loss has no gradient. Since the hooks in GradCAM detach both the gradients and the activations, the attention maps computed later have no gradients, and there is no way to backpropagate anything through the loss computed from them.
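
For reference, here is a sketch of an LwM-style attention distillation loss using `F.normalize`. It assumes the formulation from the LwM paper (L1 distance between L2-normalized, vectorized attention maps) and a batch-mean reduction. The exact reduction in FACIL's `lwm.py` may differ. Because it is built only from differentiable ops, it passes gradients back whenever at least one input has `requires_grad=True`. When neither input does, the result is a constant:

```python
import torch
import torch.nn.functional as F

def attention_distillation_loss(att_old, att_new):
    """L1 distance between L2-normalized, flattened attention maps (sketch)."""
    old = F.normalize(att_old.flatten(start_dim=1), p=2, dim=1)  # (B, H*W), unit L2 norm
    new = F.normalize(att_new.flatten(start_dim=1), p=2, dim=1)
    return (old - new).abs().sum(dim=1).mean()                  # mean over the batch (assumed)

att_old = torch.rand(2, 1, 4, 4)                       # old-model map: no gradient needed
att_new = torch.rand(2, 1, 4, 4, requires_grad=True)   # current-model map: must carry a graph

print(attention_distillation_loss(att_old, att_old.clone()).requires_grad)  # False
loss = attention_distillation_loss(att_old, att_new)
loss.backward()
print(att_new.grad.shape)  # torch.Size([2, 1, 4, 4])
```

The normalization makes the loss insensitive to the overall scale of each map, so only the spatial distribution of attention is compared.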

Here you call detach on both the gradients and the activations: https://github.com/mmasana/FACIL/blob/e9d816c0c649db91bde1568300a8ba3045651ffd/src/approach/lwm.py#LL223C1-L229C53 So when GradCAM is called, it returns attention maps that carry no gradients: https://github.com/mmasana/FACIL/blob/e9d816c0c649db91bde1568300a8ba3045651ffd/src/approach/lwm.py#L255-L261 I think the attention loss function itself is okay. However, since it is computed on two tensors without gradients, it simply adds a scalar with zero derivative to the loss. When you call loss.backward() in the training loop, the part that comes from the attention loss doesn't backpropagate.

I work on a slightly modified fork of FACIL, so my results might differ from those of the version in this repository. Could you please run the LwM code twice with different coefficients for the attention map loss (for example, 0 and some other value)? I think you will get the same final results regardless of the attention loss coefficient.
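
The proposed check can be mimicked on a toy model. In this sketch (a hypothetical stand-in, not the FACIL training loop), `feats` plays the role of the attention maps and `feats.pow(2).mean()` is a placeholder attention term. When that term is computed from detached tensors, the gradients are identical for `gamma=0` and `gamma=10`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(8, 3)
x = torch.randn(4, 8)
y = torch.tensor([0, 1, 2, 0])

def grads_for(gamma, detach_att):
    """Weight gradient of CE + gamma * (placeholder attention term)."""
    model.zero_grad()
    logits = model(x)
    feats = logits.detach() if detach_att else logits  # stand-in for the attention maps
    att_loss = feats.pow(2).mean()                      # hypothetical attention term
    loss = F.cross_entropy(logits, y) + gamma * att_loss
    loss.backward()
    return model.weight.grad.clone()

print(torch.equal(grads_for(0.0, True), grads_for(10.0, True)))    # True: gamma is ignored
print(torch.equal(grads_for(0.0, False), grads_for(10.0, False)))  # False: gamma matters
```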

mmasana commented 1 year ago

I see; the way .detach() is called could indeed block the gradients from flowing. I'll first try to reproduce what you propose with the --gamma parameter to check it out.

mmasana commented 1 year ago

You are correct: that loss indeed seems to have no effect. No gradients flow from it, so changing the parameter has no effect and the method reduces to LwF. I'll need to check some of the older dev branches to see when we introduced the bug (or forgot to update the method with the fix), since the older spreadsheet files from the original experiments do show a difference when changing gamma.

Thanks for the help! If you happen to already have a hotfix for the issue, please do propose it to speed things up.

fszatkowski commented 1 year ago

I simply tried removing the activations.detach() call from the hooks and making the torch.no_grad() in the GradCAM pass conditional:

```python
import contextlib

import torch
import torch.nn.functional as F


class GradCAM:
    ...
    def __enter__(self):
        # register hooks to collect activations and gradients
        def forward_hook(module, input, output):
            if self.retain_graph:
                # keep activations attached so the attention map stays differentiable
                self.activations = output
            else:
                self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # gradients only serve as channel weights, so they stay detached
            self.gradients = grad_output[0].detach()
    ...

    def __call__(self, input, class_indices=None, return_outputs=False, adapt_bn=False):
        ...
        # disable autograd only when the attention map does not need a gradient
        with torch.no_grad() if not self.retain_graph else contextlib.nullcontext():
            weights = F.adaptive_avg_pool2d(self.gradients, 1)
            att_map = (weights * self.activations).sum(dim=1, keepdim=True)
            att_map = F.relu(att_map)
            del self.activations
            del self.gradients
            return (att_map, model_output) if return_outputs else att_map
```

Then in the training loop:

```python
...
attmap_old, outputs_old = gradcam_old(images, return_outputs=True)
with GradCAM(self.model, self.gradcam_layer, retain_graph=True) as gradcam:
    attmap = gradcam(images)  # this uses an eval() pass
...
```

But the results I got with it were far from the scores in your paper, so I think this is still not working as intended.
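
The conditional `torch.no_grad()` in the GradCAM snippet only decides whether the attention-map computation is recorded by autograd. Below is a minimal sketch of that pattern, with `maybe_no_grad` as a hypothetical helper. `contextlib.nullcontext()` (Python 3.7+) is the standard no-op context manager, and `contextlib.suppress()` with no arguments behaves the same way here:

```python
import contextlib

import torch

def maybe_no_grad(keep_graph):
    # torch.no_grad() stops graph recording; nullcontext() changes nothing.
    return contextlib.nullcontext() if keep_graph else torch.no_grad()

w = torch.randn(3, requires_grad=True)
for keep_graph in (False, True):
    with maybe_no_grad(keep_graph):
        out = (w * 2).sum()
    print(keep_graph, out.requires_grad)
# False False
# True True
```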