facebookresearch / DomainBed

DomainBed is a suite to test domain generalization algorithms

Possible mistake in `AbstractDANN`? #101

Closed: cwognum closed this issue 2 years ago

cwognum commented 2 years ago

I was going through the DANN implementation and there are a couple of things that seem off to me. My confusion mostly relates to lines 266-272:

disc_loss = F.cross_entropy(disc_out, disc_labels)
disc_softmax = F.softmax(disc_out, dim=1)
input_grad = autograd.grad(disc_softmax[:, disc_labels].sum(), [disc_input], create_graph=True)[0]
grad_penalty = (input_grad**2).sum(dim=1).mean(dim=0)
disc_loss += self.hparams['grad_penalty'] * grad_penalty
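
For reference, here is a self-contained toy version of the quoted lines; the tensor shapes, the nn.Linear discriminator, and the penalty weight of 1.0 are illustrative assumptions, not DomainBed's actual setup:

import torch
import torch.nn.functional as F
from torch import autograd, nn

torch.manual_seed(0)
batch, feat_dim, num_domains = 8, 16, 3
discriminator = nn.Linear(feat_dim, num_domains)               # stand-in discriminator
disc_input = torch.randn(batch, feat_dim, requires_grad=True)  # stand-in featurizer output
disc_labels = torch.randint(0, num_domains, (batch,))          # domain labels

disc_out = discriminator(disc_input)
disc_loss = F.cross_entropy(disc_out, disc_labels)

# The quoted penalty: squared gradient of selected softmax outputs w.r.t. the features.
# Note that disc_softmax[:, disc_labels] uses advanced indexing and yields a
# (batch, batch) matrix, one column per entry of disc_labels, rather than a
# per-example gather.
disc_softmax = F.softmax(disc_out, dim=1)
input_grad = autograd.grad(disc_softmax[:, disc_labels].sum(),
                           [disc_input], create_graph=True)[0]
grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)
disc_loss = disc_loss + 1.0 * grad_penalty  # 1.0 stands in for hparams['grad_penalty']
print(float(disc_loss), float(grad_penalty))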

A few things about this raise questions for me, and I don't see why it is implemented this way.

Something like this makes more sense to me (pseudo-code just to illustrate the idea):

disc_loss = F.cross_entropy(disc_out, disc_labels)

if discriminator_step:
    # Discriminator update: plain domain-classification loss.
    return disc_loss
else:
    # Featurizer/classifier update: task loss minus a penalty on the
    # gradient of the discriminator loss w.r.t. the features.
    input_grad = autograd.grad(disc_loss, [disc_input], create_graph=True)[0]
    grad_penalty = (input_grad**2).sum(dim=1).mean(dim=0)
    return classifier_loss + self.hparams['lambda'] * -grad_penalty
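
For concreteness, here is a runnable sketch of that alternative with toy modules; the featurizer, classifier, discriminator, shapes, and the lam weight are placeholders assumed for illustration, not DomainBed's actual API:

import torch
import torch.nn.functional as F
from torch import autograd, nn

torch.manual_seed(0)
featurizer = nn.Linear(32, 16)      # toy featurizer
classifier = nn.Linear(16, 10)      # toy label classifier
discriminator = nn.Linear(16, 3)    # toy domain discriminator
lam = 0.1                           # stand-in for hparams['lambda']

x = torch.randn(8, 32)
y = torch.randint(0, 10, (8,))
domains = torch.randint(0, 3, (8,))

def objective(discriminator_step: bool):
    disc_input = featurizer(x)
    disc_out = discriminator(disc_input)
    disc_loss = F.cross_entropy(disc_out, domains)

    if discriminator_step:
        # Discriminator update: plain domain-classification loss.
        return disc_loss

    # Featurizer/classifier update: task loss minus a penalty on the
    # gradient of the discriminator loss w.r.t. the features.
    classifier_loss = F.cross_entropy(classifier(disc_input), y)
    input_grad = autograd.grad(disc_loss, [disc_input], create_graph=True)[0]
    grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)
    return classifier_loss + lam * -grad_penalty

print(float(objective(True)), float(objective(False)))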
cwognum commented 2 years ago

I discussed it with my supervisor and while I can't quite wrap my head around it yet, I think it's ok! Sorry for the distraction.