GMvandeVen / continual-learning

PyTorch implementation of various methods for continual learning (XdG, EWC, SI, LwF, FROMP, DGR, BI-R, ER, A-GEM, iCaRL, Generative Classifier) in three different scenarios.
MIT License

one little confusion about the loss_fn_kd function #18

Closed. libo-huang closed this issue 2 years ago.

libo-huang commented 2 years ago

Many thanks for your impressive project. I am a bit confused about the .detach() in the code below, https://github.com/GMvandeVen/continual-learning/blob/a02db26d3b10754abdc4a549bdcde6af488c94e0/utils.py#L35

which is defined in https://github.com/GMvandeVen/continual-learning/blob/a02db26d3b10754abdc4a549bdcde6af488c94e0/utils.py#L18

Referring to the blog post PyTorch .detach() method, .detach() makes targets_norm a fixed constant inside the KD_loss, so backpropagation will not update any parameters along the branch that produced targets_norm.
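For readers following along, here is a minimal sketch of the pattern being discussed (not the exact code from utils.py; the function name, temperature handling, and shapes are illustrative). The .detach() ensures no gradient flows back into whatever network produced target_scores:

```python
import torch.nn.functional as F

def kd_loss_sketch(scores, target_scores, T=2.0):
    """Soft-target distillation loss between student logits (scores) and
    teacher logits (target_scores); the teacher side is detached so it is
    treated as a constant during backpropagation."""
    log_p_student = F.log_softmax(scores / T, dim=1)
    p_teacher = F.softmax(target_scores / T, dim=1).detach()  # no gradient through the teacher branch
    # Cross-entropy with soft targets, scaled by T^2 as is common for knowledge distillation.
    return -(p_teacher * log_p_student).sum(dim=1).mean() * T ** 2
```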

However, in your other project, brain-inspired-replay, the same loss function loss_fn_kd uses

 targets_norm = torch.cat([targets_norm, zeros_to_add], dim=1)

as shown in line 29, where no .detach() is applied.

Although I have tested both versions and they produce the same results, I am still confused about how the second version works.
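For context, a minimal illustration of what that concatenation does (the shapes here are made up, not taken from the repository): it pads the teacher's softened probabilities with zero-probability columns for classes the teacher has not seen, so they match the current model's output size, and no .detach() is applied along the way:

```python
import torch

# Illustrative shapes: the stored teacher covers 6 classes, the current model has 10.
targets_norm = torch.softmax(torch.randn(32, 6), dim=1)       # teacher probabilities, shape [32, 6]
zeros_to_add = torch.zeros(32, 10 - 6)                         # zero probability for the 4 new classes
targets_norm = torch.cat([targets_norm, zeros_to_add], dim=1)  # shape [32, 10]; rows still sum to 1
```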

GMvandeVen commented 2 years ago

Hi @HLBayes, I'm sorry for the very late reply! This difference between the two repositories is indeed confusing. In both repositories, the .detach() operation is actually not needed. As you point out in your comment, .detach() stops backpropagation by returning a tensor that is no longer part of the computation graph. However, in both repositories the targets_norm variable already had no gradients being tracked, because it was computed under with torch.no_grad(): (as for example here: link). Sorry for the confusion!
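To make this concrete, a small self-contained check along these lines (the model and shapes are illustrative) shows why the two versions behave identically: a tensor produced inside torch.no_grad() is already outside the autograd graph, so calling .detach() on it has no further effect.

```python
import torch
import torch.nn as nn

teacher = nn.Linear(4, 3)   # stand-in for the stored/previous model
x = torch.randn(8, 4)

with torch.no_grad():
    target_scores = teacher(x)               # computed outside the autograd graph

print(target_scores.requires_grad)           # False
print(target_scores.grad_fn)                 # None
# .detach() on such a tensor is a no-op for autograd purposes:
print(target_scores.detach().requires_grad)  # False
```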