deeplearning-wisc / gradnorm_ood

On the Importance of Gradients for Detecting Distributional Shifts in the Wild
Apache License 2.0

Question about code implementation of Gradnorm #2

Closed · KacperKubara closed this issue 2 years ago

KacperKubara commented 2 years ago

Hi,

I have a question about the implementation of the Gradnorm in this repository. Mainly, the following line: https://github.com/deeplearning-wisc/gradnorm_ood/blob/7c5805c3c64570569f96445c4a062f4e0aa2e58e/test_ood.py#L136

In the above line, is dim=-1 equivalent to the label dimension? If yes, I am a bit confused about how it works when looking at this explanation from the paper: [image: GradNorm equation from the paper]

In the code, you do torch.sum over the label dimension, so from shape [b_size, num_classes] you get [b_size]. Then torch.mean() averages over the batch size, not over the number of classes as in the paper (so the equation effectively has 1/b instead of 1/C).

Is my reasoning correct here, or did I make a mistake somewhere?

EDIT: In the paper you make this comment: "In practice, GradNorm can be conveniently implemented by calculating the cross-entropy loss between the predicted softmax probability and a uniform vector as the target." Perhaps the above is related to this?
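To make the shape question concrete, here is a minimal sketch of how I read that line (the uniform target and the shapes are my assumptions, not copied from the repo):

```python
import torch
import torch.nn as nn

# Illustrative shapes only; the uniform target follows the paper's description.
b_size, num_classes = 4, 1000
outputs = torch.randn(b_size, num_classes)               # logits, shape [b_size, num_classes]
targets = torch.ones(b_size, num_classes) / num_classes  # uniform target vector

logsoftmax = nn.LogSoftmax(dim=-1)
per_sample = torch.sum(-targets * logsoftmax(outputs), dim=-1)  # shape [b_size]
loss = torch.mean(per_sample)                            # scalar: averages over the batch
print(per_sample.shape, loss.shape)                      # torch.Size([4]) torch.Size([])
```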

iurgnauh commented 2 years ago

Hi,

This is a great question! One important thing here is that for GradNorm the batch size will always be 1, because we need the parameter gradients for each individual sample rather than an aggregated gradient over a batch. See the following lines: https://github.com/deeplearning-wisc/gradnorm_ood/blob/7c5805c3c64570569f96445c4a062f4e0aa2e58e/test_ood.py#L220-L221

Therefore, the input to torch.sum always has shape [1, num_classes], and its output has shape [1]. That line is a little redundant, because torch.mean does nothing on a single-element tensor; I will update the code later. We can also ignore the 1/C factor, because it is the same constant for every sample and therefore does not change the ranking of the scores.
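To make that concrete, here is a minimal sketch of how a per-sample score is computed under this setup (the linear head, the feature size, and the variable names are illustrative, not taken verbatim from the repo):

```python
import torch
import torch.nn as nn

# A linear head standing in for the network's last layer; the feature size is an assumption.
num_classes = 1000
fc = nn.Linear(2048, num_classes)
logsoftmax = nn.LogSoftmax(dim=-1)

feature = torch.randn(1, 2048)                       # batch size is always 1 for GradNorm
targets = torch.ones(1, num_classes) / num_classes   # uniform target, as in the paper

fc.zero_grad()
outputs = fc(feature)
loss = torch.mean(torch.sum(-targets * logsoftmax(outputs), dim=-1))
# torch.sum(..., dim=-1) already yields a [1] tensor here, so torch.mean is a no-op.
loss.backward()

# GradNorm score: L1 norm of the gradient w.r.t. the last-layer weights.
score = torch.sum(torch.abs(fc.weight.grad)).item()
print(score)
```

Note that multiplying the loss by a constant such as 1/C would scale every sample's score by the same factor, so the relative ordering used for OOD detection is unaffected.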

Hope this answers your question.

KacperKubara commented 2 years ago

Ah got it, thanks!