facebookresearch / DomainBed

DomainBed is a suite to test domain generalization algorithms

Error in discriminator regularisation in `AbstractDANN` #105

Closed prockenschaub closed 2 years ago

prockenschaub commented 2 years ago

While working my way through the AbstractDANN class, I think I came across a mistake in the calculation of the domain discriminator's loss.

https://github.com/facebookresearch/DomainBed/blob/2ed9edf781fe4b336c2fb6ffe7ca8a7c6f994422/domainbed/algorithms.py#L270-L274

If I understand the code correctly, here we want to regularise the discriminator with the square of the gradients that flow back to the featurizer. The code above therefore aims to get each sample's predicted probabilities per domain (line 270), take the cross-entropy loss and calculate its gradient with respect to the featurizer output (line 271), square and average the gradient (line 273), and finally add it to the loss (line 274).
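To make this easier to follow without the repo open, here is a minimal, self-contained sketch of that pattern as I read it. The tensor names, the 16-dimensional feature size and the Linear(16, 2) discriminator are illustrative stand-ins, not copied from DomainBed, and the indexing on the penalty line is exactly the one this issue questions.

import torch
import torch.nn.functional as F
from torch import autograd

torch.manual_seed(0)
disc_input = torch.randn(128, 16, requires_grad=True)    # stand-in for the featurizer output
disc_out = torch.nn.Linear(16, 2)(disc_input)             # discriminator logits, 2 domains
disc_labels = torch.cat([torch.zeros(64), torch.ones(64)]).long()

disc_softmax = F.softmax(disc_out, dim=1)                  # per-domain probabilities ("line 270")
# the indexing questioned below; its sum is differentiated w.r.t. the featurizer output ("line 271")
input_grad = autograd.grad(disc_softmax[:, disc_labels].sum(),
                           [disc_input], create_graph=True)[0]
grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)    # square and average the gradient ("line 273")
# the penalty is then added to the discriminator loss ("line 274")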

Problem

In line 271, in order to get the probability for the true domain, we index the softmax as disc_softmax[:, disc_labels], with (I think) the intention of getting disc_softmax[i, 0] if sample i's domain is 0, disc_softmax[i, 1] if its domain is 1, and so on. Unfortunately, this is not what happens. Running this line, e.g. on ColoredMNIST with a batch size of 64 (two training environments, so 128 samples in the minibatch), gives:

disc_softmax[:, disc_labels]
#> tensor([[0.5825, 0.5825, 0.5825,  ..., 0.4175, 0.4175, 0.4175],
#>        [0.5513, 0.5513, 0.5513,  ..., 0.4487, 0.4487, 0.4487],
#>        [0.5800, 0.5800, 0.5800,  ..., 0.4200, 0.4200, 0.4200],
#>        ...,
#>        [0.5869, 0.5869, 0.5869,  ..., 0.4131, 0.4131, 0.4131],
#>        [0.5305, 0.5305, 0.5305,  ..., 0.4695, 0.4695, 0.4695],
#>        [0.5332, 0.5332, 0.5332,  ..., 0.4668, 0.4668, 0.4668]],
#>       grad_fn=<IndexBackward0>)

disc_softmax[:, disc_labels].shape
#> torch.Size([128, 128])

That is, rather than selecting one entry per sample, the indexing builds a 128x128 matrix: disc_softmax[:, 0] copied 64 times (columns 0-63), followed by disc_softmax[:, 1] copied 64 times (columns 64-127). It is actually a mistake I have made myself many, many times, because it always intuitively feels like this should work.
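A tiny standalone reproduction of the pitfall, using random probabilities of the same shape (only the shapes and the column-copy behaviour matter here):

import torch

disc_softmax = torch.rand(128, 2).softmax(dim=1)
disc_labels = torch.cat([torch.zeros(64), torch.ones(64)]).long()

disc_softmax[:, disc_labels].shape
#> torch.Size([128, 128])

# the first 64 columns are all copies of disc_softmax[:, 0]
torch.equal(disc_softmax[:, disc_labels][:, 0], disc_softmax[:, 0])
#> True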

Solution

Instead, I think what we want is

disc_softmax.gather(dim=-1, index=disc_labels[:, None]).shape
#> torch.Size([128, 1])
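As a quick sanity check (again on illustrative random tensors), gather picks exactly the per-sample probability of the true domain, i.e. it agrees with explicit row/label indexing:

import torch

disc_softmax = torch.rand(128, 2).softmax(dim=1)
disc_labels = torch.randint(0, 2, (128,))

gathered = disc_softmax.gather(dim=-1, index=disc_labels[:, None]).squeeze(-1)
manual = disc_softmax[torch.arange(128), disc_labels]   # disc_softmax[i, disc_labels[i]] for each i
torch.equal(gathered, manual)
#> True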

We probably also want to take the log before we sum to make it a true cross-entropy, but I'm not sure this makes much of a difference here. Overall, we could replace it more transparently with

F.cross_entropy(disc_out, disc_labels, reduction='sum')

(note that the sign is different from the original code, but I think it doesn't matter because the gradient is squared afterwards).
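For concreteness, a sketch of how the penalty computation could look with that replacement, reusing the same illustrative stand-in tensors as the sketch above; this is not a tested patch against the repo:

import torch
import torch.nn.functional as F
from torch import autograd

disc_input = torch.randn(128, 16, requires_grad=True)    # stand-in for the featurizer output
disc_out = torch.nn.Linear(16, 2)(disc_input)             # discriminator logits
disc_labels = torch.randint(0, 2, (128,))

disc_ce = F.cross_entropy(disc_out, disc_labels, reduction='sum')
input_grad = autograd.grad(disc_ce, [disc_input], create_graph=True)[0]
grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)   # unchanged: square and average the gradient
# grad_penalty is then added to the discriminator loss as before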

Let me know if I got anything wrong here or if there's something I've missed :)