mmasana / FACIL

Framework for Analysis of Class-Incremental Learning with 12 state-of-the-art methods and 3 baselines.
https://arxiv.org/pdf/2010.15277.pdf
MIT License

Can the attention distillation loss in LwM produce gradients? #9

Closed: NUAA-XSF closed this issue 2 years ago

NUAA-XSF commented 3 years ago

Hello! Thank you for your nice work. I have a question: the LwM (Learning without Memorizing) paper uses an attention distillation loss. In your code (lwm.py):

# in class GradCAM
def __call__(self, input, class_indices=None, return_outputs=False):
    # pass input & backpropagate for selected class
    if input.dim() == 3:
        input = input.view([1] + list(input.size()))
    self.model.eval()
    model_output = self.model(input)
    logits = torch.cat(model_output, dim=1)
    if class_indices is None:
        class_indices = logits.argmax(dim=1)
    score = logits[:, class_indices].squeeze()
    self.model.zero_grad()
    score.mean().backward(retain_graph=self.retain_graph)
    model_output = [o.detach() for o in model_output]

    # create map based on gradients and activations
    with torch.no_grad():
        weights = F.adaptive_avg_pool2d(self.gradients, 1)
        att_map = (weights * self.activations).sum(dim=1, keepdim=True)
        att_map = F.relu(att_map)
        del self.activations
        del self.gradients
        return (att_map, model_output) if return_outputs else att_map

It seems to me that an attention map computed this way cannot produce gradients during backpropagation. Looking forward to your reply. Thank you.
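
A minimal sketch of the concern, using a toy model (hypothetical, not the FACIL network) and torch.autograd.grad in place of the hooks: a map assembled from detached tensors inside torch.no_grad() has no grad_fn, so a loss computed from it cannot send gradients back to the model's parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for a feature extractor + classifier (hypothetical, not the FACIL model)
features = nn.Conv2d(3, 4, kernel_size=3, padding=1)
head = nn.Linear(4, 5)

x = torch.randn(2, 3, 8, 8)
acts = features(x)                       # B x C x H x W activations
logits = head(acts.mean(dim=(2, 3)))     # B x num_classes
score = logits.max(dim=1).values.sum()

# Gradient of the score w.r.t. the activations, detached like in the backward hook
grads = torch.autograd.grad(score, acts)[0].detach()

# Same pattern as GradCAM.__call__: detached inputs combined under no_grad
with torch.no_grad():
    weights = F.adaptive_avg_pool2d(grads, 1)
    att_map = F.relu((weights * acts.detach()).sum(dim=1, keepdim=True))

# False: a loss built on att_map has no path back to the conv weights
print(att_map.requires_grad)
```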

mmasana commented 3 years ago

Hi @NUAA-XSF, happy you like our work! The gradients are saved when doing the backward pass (line 229 of lwm.py):

def __enter__(self):
    # register hooks to collect activations and gradients
    def forward_hook(module, input, output):
        self.activations = output.detach()

    def backward_hook(module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    # hook to final layer
    self.fhandle = self.model_layer.register_forward_hook(forward_hook)
    self.bhandle = self.model_layer.register_backward_hook(backward_hook)
    return self

and then used in the line weights = F.adaptive_avg_pool2d(self.gradients, 1) of the code you posted. Is this what you were asking, or did you mean something else?
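
The hook mechanism described above can be tried in isolation with a small sketch. It assumes a toy model and a hypothetical ActivationGradientRecorder helper (not FACIL's GradCAM class), and it uses register_full_backward_hook, which recent PyTorch versions recommend over the deprecated register_backward_hook. The hooked layer is the second conv, so its input already requires grad.

```python
import torch
import torch.nn as nn

class ActivationGradientRecorder:
    """Hypothetical helper mirroring the hooks in GradCAM.__enter__ (not FACIL's API)."""

    def __init__(self, layer):
        self.activations = None
        self.gradients = None
        self.fhandle = layer.register_forward_hook(self._forward_hook)
        self.bhandle = layer.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, module, inputs, output):
        self.activations = output.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        # grad_output[0] is d(score)/d(layer output), same shape as the activations
        self.gradients = grad_output[0].detach()

    def remove(self):
        self.fhandle.remove()
        self.bhandle.remove()

model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(),
    nn.Conv2d(4, 6, 3, padding=1), nn.ReLU(),  # index 2: the hooked "final" conv layer
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(6, 5),
)
recorder = ActivationGradientRecorder(model[2])
logits = model(torch.randn(2, 3, 8, 8))
logits[:, 0].mean().backward()
print(recorder.activations.shape, recorder.gradients.shape)  # torch.Size([2, 6, 8, 8]) twice
recorder.remove()
```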

NUAA-XSF commented 3 years ago

@mmasana Thank you for your reply.

def backward_hook(module, grad_input, grad_output):
    self.gradients = grad_output[0].detach()

I understand that this gradient is used to generate the attention map. What I really want to ask is: when the two attention maps from M_t and M_{t-1} are used to compute the attention distillation loss, can gradients flow back through this loss? I can't find the relevant code.

# retain_graph is False in your code (line 251 in lwm.py).
# I think retain_graph should be True, because the attention distillation loss
# will be calculated later.
score.mean().backward(retain_graph=self.retain_graph)
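
Setting retain_graph=True alone would probably not be enough, because GradCAM.__call__ builds the map from detached activations and gradients under torch.no_grad(). One possible way to make the distillation loss differentiable, sketched here under stated assumptions and not taken from FACIL, is to compute the Grad-CAM gradients with torch.autograd.grad(..., create_graph=True) and keep the activations attached. The split into hypothetical features/head modules and the random stand-in for the M_{t-1} map are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def differentiable_gradcam(features, head, x, class_indices):
    """Hypothetical Grad-CAM variant whose output stays in the autograd graph.

    features: module producing B x C x H x W activations
    head: module mapping those activations to B x num_classes logits
    """
    acts = features(x)
    logits = head(acts)
    # One score per sample: the logit of that sample's own class index
    score = logits.gather(1, class_indices.view(-1, 1)).sum()
    # create_graph=True keeps the gradient itself differentiable (double backward)
    grads = torch.autograd.grad(score, acts, create_graph=True)[0]
    weights = F.adaptive_avg_pool2d(grads, 1)                  # B x C x 1 x 1
    return F.relu((weights * acts).sum(dim=1, keepdim=True))   # B x 1 x H x W

torch.manual_seed(0)
features_t = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU())
head_t = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5))
x = torch.randn(2, 3, 8, 8)
cls = torch.tensor([1, 3])

att_t = differentiable_gradcam(features_t, head_t, x, cls)
att_old = torch.rand_like(att_t)  # stand-in for the detached map of the old model M_{t-1}
loss = (att_t.flatten(1) - att_old.flatten(1)).abs().sum(dim=1).mean()
loss.backward()
print(features_t[0].weight.grad is not None)  # True: the loss reaches the conv weights
```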

NUAA-XSF commented 3 years ago

@mmasana There are two other things that confuse me (line 247 in lwm.py):

if class_indices is None:
    class_indices = logits.argmax(dim=1)
score = logits[:, class_indices].squeeze()
self.model.zero_grad()
score.mean().backward(retain_graph=self.retain_graph)

I think there is a problem with the third line of code. For example:

import torch

logits = torch.tensor([[1,2,3,4],[8,7,6,5]]).float()
class_indices = logits.argmax(dim=1)  # class_indices: tensor([3, 0])
score = logits[:, class_indices].squeeze()  # score: tensor([[4., 1.], [5., 8.]])

The result is confusing because it contains other values that are not the maximum. Another thing I don’t understand is why score.mean() is used.

mmasana commented 3 years ago

@NUAA-XSF Thanks for pointing this out. What you mention about the third line of code does look strange indeed, so I will check it out. I have not looked at this code in some time, but I remember we based the implementation on this other repository. This paper was already a bit tricky because no code was available and the hyperparameter of the attention distillation loss (γ in the original paper) was not disclosed.

> The result is confusing because it contains other values that are not the maximum.

It seems it returns, for every sample, the logits at the argmax class of every sample in the batch (a batch-size × batch-size matrix), instead of selecting one entry per sample.
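
Using the example tensor from the question, the sketch below contrasts the indexing in lwm.py with per-sample selection. Pairing row i with class_indices[i] (via advanced indexing or torch.gather) is one standard way to get a single score per sample; whether FACIL adopted exactly this fix is not stated here.

```python
import torch

logits = torch.tensor([[1, 2, 3, 4], [8, 7, 6, 5]]).float()
class_indices = logits.argmax(dim=1)                  # tensor([3, 0])

# Indexing as in lwm.py: every row is indexed with every sample's class -> B x B
buggy = logits[:, class_indices]                      # tensor([[4., 1.], [5., 8.]])

# One entry per sample: pair row i with class_indices[i]
rows = torch.arange(logits.size(0))
score = logits[rows, class_indices]                   # tensor([4., 8.])

# Equivalent formulation with gather
score_gather = logits.gather(1, class_indices.unsqueeze(1)).squeeze(1)
print(buggy, score, score_gather)
```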

> Another thing I don’t understand is why use score.mean()

This is just averaging over the batch before the backward pass. With a sum, a smaller batch (e.g. the last batch of an epoch) would be backpropagated at a smaller scale than a "full" batch. However, since the score tensor seems to contain more values than it should, the averaging might not be necessary once that part is fixed.
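
The scaling argument can be checked with a toy linear score (an assumption for illustration, not code from lwm.py). For a weight shared by all samples, a sum makes the gradient grow with the batch size while a mean does not. For the per-sample activation gradients that Grad-CAM pools, the effect is reversed: the mean scales each sample's gradient by 1/B.

```python
import torch

def param_grad(batch_size, reduction):
    """Gradient w.r.t. a weight vector shared by all samples (toy linear score)."""
    w = torch.ones(3, requires_grad=True)
    x = torch.ones(batch_size, 3)          # identical samples: only the batch size differs
    per_sample = x @ w                     # one score per sample, shape (batch_size,)
    total = per_sample.sum() if reduction == "sum" else per_sample.mean()
    total.backward()
    return w.grad

def activation_grad(batch_size, reduction):
    """Gradient w.r.t. the first sample's own activations (what Grad-CAM pools)."""
    acts = torch.ones(batch_size, 3, requires_grad=True)
    per_sample = acts.sum(dim=1)
    total = per_sample.sum() if reduction == "sum" else per_sample.mean()
    total.backward()
    return acts.grad[0]

print(param_grad(8, "sum"), param_grad(2, "sum"))              # tensor([8., 8., 8.]) tensor([2., 2., 2.])
print(param_grad(8, "mean"), param_grad(2, "mean"))            # tensor([1., 1., 1.]) tensor([1., 1., 1.])
print(activation_grad(8, "mean"), activation_grad(2, "mean"))  # tensor([0.1250, 0.1250, 0.1250]) tensor([0.5000, 0.5000, 0.5000])
```

Since the 1/B factor on the activation gradients is a positive constant per batch, the ReLU and the L2 normalization in the attention distillation loss should cancel it out (up to the eps in normalize).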

Finally, another user mentioned that in the distillation loss we should use torch.nn.functional.normalize instead of torch.norm. We will update it to:

import torch

def attention_distillation_loss(self, attention_map1, attention_map2):
    # flatten each map and L2-normalize it per sample
    attention_map1 = torch.nn.functional.normalize(attention_map1.view(attention_map1.size(0), -1), p=2, dim=1, eps=1e-12)
    attention_map2 = torch.nn.functional.normalize(attention_map2.view(attention_map2.size(0), -1), p=2, dim=1, eps=1e-12)
    # L1 distance between the normalized maps, averaged over the batch
    return torch.norm(attention_map2 - attention_map1, p=1, dim=1).mean()
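
A standalone version of the updated loss (without self, and using reshape instead of view) can be exercised on random maps. The map shapes and the roles of old_map/new_map are assumptions for illustration. Because both maps are L2-normalized first, the loss depends only on where the attention is, not on its overall magnitude, and it stays differentiable with respect to the new map.

```python
import torch
import torch.nn.functional as F

def attention_distillation_loss(attention_map1, attention_map2):
    """Standalone sketch of the updated loss above (no `self`), for experimentation."""
    a1 = F.normalize(attention_map1.reshape(attention_map1.size(0), -1), p=2, dim=1)
    a2 = F.normalize(attention_map2.reshape(attention_map2.size(0), -1), p=2, dim=1)
    # L1 distance between the L2-normalized, flattened maps, averaged over the batch
    return torch.norm(a2 - a1, p=1, dim=1).mean()

torch.manual_seed(0)
old_map = torch.rand(4, 1, 7, 7)                      # e.g. detached maps of M_{t-1}
new_map = torch.rand(4, 1, 7, 7, requires_grad=True)  # e.g. maps of M_t

scaled_loss = attention_distillation_loss(old_map, 3.0 * old_map)
print(scaled_loss)  # close to 0: only the spatial distribution matters, not the scale

loss = attention_distillation_loss(old_map, new_map)
loss.backward()
print(new_map.grad.shape)  # torch.Size([4, 1, 7, 7])
```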

I thought it would also be good to mention it in case it is useful for you. Let me know if you have any other issues.