NUAA-XSF closed this issue 2 years ago.
Hi @NUAA-XSF,
Happy you like our work! The gradients are saved during the backward pass (line 229 of lwm.py):
def __enter__(self):
    # register hooks to collect activations and gradients
    def forward_hook(module, input, output):
        self.activations = output.detach()

    def backward_hook(module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    # hook to final layer
    self.fhandle = self.model_layer.register_forward_hook(forward_hook)
    self.bhandle = self.model_layer.register_backward_hook(backward_hook)
    return self
They are then used in the line weights = F.adaptive_avg_pool2d(self.gradients, 1) of the code you posted.
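For readers who want to see the hook mechanism in isolation, here is a minimal runnable sketch on a hypothetical toy network. GradCAMHooks, target and model are illustrative names, not the lwm.py class. It swaps in register_full_backward_hook, the non-deprecated replacement for register_backward_hook, and then applies the same gradient-weighting step as the weights line above, under the assumption that the attention map is Grad-CAM style (channel weights times activations, then ReLU).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GradCAMHooks:
    """Hypothetical helper mirroring the __enter__ above; not the lwm.py class."""

    def __init__(self, model_layer):
        self.model_layer = model_layer

    def __enter__(self):
        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        self.fhandle = self.model_layer.register_forward_hook(forward_hook)
        # register_full_backward_hook replaces the deprecated register_backward_hook
        self.bhandle = self.model_layer.register_full_backward_hook(backward_hook)
        return self

    def __exit__(self, *exc):
        self.fhandle.remove()
        self.bhandle.remove()

torch.manual_seed(0)
target = nn.Conv2d(4, 4, kernel_size=3, padding=1)  # layer whose maps we inspect
model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), target, nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5))

x = torch.randn(2, 3, 8, 8)
with GradCAMHooks(target) as hooks:
    logits = model(x)
    logits.max(dim=1).values.sum().backward()  # fires the backward hook

weights = F.adaptive_avg_pool2d(hooks.gradients, 1)           # (B, C, 1, 1)
attention = F.relu((weights * hooks.activations).sum(dim=1))  # (B, H, W)
print(attention.shape, attention.requires_grad)  # torch.Size([2, 8, 8]) False
```

Note that because both hooks call .detach(), the resulting map carries no autograd history.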
Was this what you were asking? Or did you mean something else?
@mmasana Thank you for your reply.
def backward_hook(module, grad_input, grad_output):
    self.gradients = grad_output[0].detach()
I understand that this gradient is used to generate the attention map. What I really want to ask is: when the two attention maps from M_t and M_{t-1} are used to compute the attention distillation loss, are any gradients produced when this loss is backpropagated? I can't find the relevant code.
# retain_graph = False in your code (line 251 in lwm.py).
# I think retain_graph should be True, because the attention
# distillation loss is computed afterwards.
score.mean().backward(retain_graph=self.retain_graph)
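To make the question concrete, here is a minimal sketch assuming a Grad-CAM-style map built from gradient-weighted activations. The network, gradcam_map and the differentiable flag are hypothetical and not taken from lwm.py. In this sketch, detaching activations and gradients (as the hooks do) turns the map into a constant, so a loss on it cannot send gradients into the network. Computing the gradient with torch.autograd.grad(..., create_graph=True) keeps it differentiable. retain_graph=True on its own only keeps the existing graph's buffers alive; it does not reattach tensors that were detached.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical toy network: `features` plays the role of the hooked layer.
features = nn.Conv2d(3, 4, kernel_size=3, padding=1)
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5))

def gradcam_map(x, differentiable):
    acts = features(x)                                # (B, C, H, W)
    logits = head(acts)                               # (B, K)
    idx = logits.argmax(dim=1)
    score = logits.gather(1, idx.unsqueeze(1)).sum()  # top logit of each sample
    # create_graph=True records how the gradient itself was computed,
    # so a loss on the resulting map can be backpropagated later.
    grads, = torch.autograd.grad(score, acts, create_graph=differentiable)
    if not differentiable:
        # mimic the hooks: detached tensors carry no autograd history
        acts, grads = acts.detach(), grads.detach()
    weights = F.adaptive_avg_pool2d(grads, 1)         # (B, C, 1, 1)
    return F.relu((weights * acts).sum(dim=1))        # (B, H, W)

x = torch.randn(2, 3, 8, 8)
print(gradcam_map(x, differentiable=False).requires_grad)  # False
print(gradcam_map(x, differentiable=True).requires_grad)   # True
```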
@mmasana
There are two other things that confuse me (line 247 in lwm.py):
if class_indices is None:
    class_indices = logits.argmax(dim=1)
score = logits[:, class_indices].squeeze()
self.model.zero_grad()
score.mean().backward(retain_graph=self.retain_graph)
I think there is a problem with the third line of code. For example:
logits = torch.tensor([[1,2,3,4],[8,7,6,5]]).float()
class_indices = logits.argmax(dim=1) # class_indices:tensor([3, 0])
score = logits[:, class_indices].squeeze() # score:tensor([[4., 1.], [5., 8.]])
The result is confusing because it contains other values that are not the maximum.
Another thing I don't understand is why score.mean() is used.
@NUAA-XSF thanks for pointing this out. What you mention about the third line of code does look strange, so I will check it. I have not looked at this code in some time, but I remember we based the implementation on this other repository. This paper was already a bit tricky to reproduce because no code was released and the hyperparameters of the attention-distillation loss (γ in the original paper) were not disclosed.
The result is confusing because it contains other values that are not the maximum.
It seems to return, for every row, the logits at all of the batch's argmax indices, instead of selecting one value per sample.
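The difference can be checked on the example tensor posted above. This sketch contrasts the current indexing with a per-sample selection using gather (or, equivalently, advanced indexing with a row index); it is a suggested fix, not a change confirmed in lwm.py.

```python
import torch

logits = torch.tensor([[1., 2., 3., 4.], [8., 7., 6., 5.]])
class_indices = logits.argmax(dim=1)  # tensor([3, 0])

# Current code: columns 3 and 0 are taken for *every* row -> shape (2, 2)
wrong = logits[:, class_indices]      # [[4., 1.], [5., 8.]]

# Per-sample selection: one logit per row -> shape (2,)
right = logits.gather(1, class_indices.unsqueeze(1)).squeeze(1)   # [4., 8.]
# Equivalent advanced indexing with an explicit row index
also_right = logits[torch.arange(logits.size(0)), class_indices]  # [4., 8.]
```

gather keeps the batch dimension explicit, so squeeze(1) cannot accidentally drop a batch of size 1, unlike a bare squeeze().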
Another thing I don’t understand is why use score.mean()
This is just averaging over the batch before doing the backward pass. If it were a sum, the scale of the gradients would depend on the batch size: a smaller batch (e.g. the last batch of an epoch) would be backpropagated at a smaller scale than a "full" batch. However, since the score tensor seems to contain more values than it should, the mean might not be necessary once that part is modified.
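The batch-size argument can be seen with a single hypothetical parameter w and one score per sample. With sum the gradient grows with the batch size; with mean it does not. This toy setup is purely illustrative.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # hypothetical single parameter

def grad_of(batch_size, reduce):
    w.grad = None                  # reset accumulated gradient
    x = torch.ones(batch_size)
    score = w * x                  # one score per sample
    reduce(score).backward()
    return w.grad.item()

print(grad_of(32, torch.sum), grad_of(8, torch.sum))    # 32.0 8.0
print(grad_of(32, torch.mean), grad_of(8, torch.mean))  # 1.0 1.0
```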
Finally, another user mentioned that in the distillation loss we should use torch.nn.functional.normalize instead of torch.norm. We will update it to be:
def attention_distillation_loss(self, attention_map1, attention_map2):
    # flatten each map to (batch, H*W) and L2-normalize it per sample
    attention_map1 = torch.nn.functional.normalize(attention_map1.view(attention_map1.size(0), -1), p=2, dim=1, eps=1e-12)
    attention_map2 = torch.nn.functional.normalize(attention_map2.view(attention_map2.size(0), -1), p=2, dim=1, eps=1e-12)
    # L1 distance between the normalized maps, averaged over the batch
    return torch.norm(attention_map2 - attention_map1, p=1, dim=1).mean()
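A standalone version of the updated loss, written as a free function instead of a method (our assumption, for easy testing), shows what the normalization buys: each map is L2-normalized per sample before the L1 distance, so rescaling a map leaves the loss essentially unchanged. The random maps below are stand-ins, not real attention maps.

```python
import torch
import torch.nn.functional as F

def attention_distillation_loss(attention_map1, attention_map2):
    # flatten each map to (B, H*W) and L2-normalize it per sample
    a1 = F.normalize(attention_map1.reshape(attention_map1.size(0), -1), p=2, dim=1, eps=1e-12)
    a2 = F.normalize(attention_map2.reshape(attention_map2.size(0), -1), p=2, dim=1, eps=1e-12)
    # L1 distance between the normalized maps, averaged over the batch
    return torch.norm(a2 - a1, p=1, dim=1).mean()

torch.manual_seed(0)
m_prev = torch.rand(4, 7, 7)  # stand-in for the M_{t-1} maps
m_curr = torch.rand(4, 7, 7)  # stand-in for the M_t maps

print(attention_distillation_loss(m_prev, 3.0 * m_prev))  # close to 0: scale is ignored
print(attention_distillation_loss(m_prev, m_curr))        # clearly positive
```

reshape is used here instead of view so the function also accepts non-contiguous inputs.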
I thought it would also be good to mention it in case it is useful to you. Let me know if you have any other issues.
Hello! Thank you for your nice work. I have a question: the LwM (Learning without Memorizing) paper uses an attention distillation loss. In your code (lwm.py):
I feel that this code does not seem to produce gradients when backpropagating. Looking forward to your reply. Thank you.