facebookresearch / DomainBed

DomainBed is a suite to test domain generalization algorithms
MIT License

ColoredMNIST with IRM Implementation is different from the original IRM papers #112

Closed lccurious closed 1 year ago

lccurious commented 1 year ago

The original IRM paper uses the MLP below for ColoredMNIST, which is very different from the convolutional network used in DomainBed. The reported results also differ substantially: accuracy is 66.9 on test_env[0.9] in the original IRM paper, but 10.1 on test_env[0.9] in the DomainBed paper. Is this difference in implementation reasonable for a comparison?

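import torch.nn as nn

# `flags` is not defined in this excerpt; it is assumed to be the parsed
# command-line namespace that provides `grayscale_model` and `hidden_dim`.
class MLP(nn.Module):
  def __init__(self):
    super(MLP, self).__init__()
    # Inputs are 14x14 images with 2 color channels; the grayscale
    # variant sums the channels, so its first layer sees 14 * 14 features.
    if flags.grayscale_model:
      lin1 = nn.Linear(14 * 14, flags.hidden_dim)
    else:
      lin1 = nn.Linear(2 * 14 * 14, flags.hidden_dim)
    lin2 = nn.Linear(flags.hidden_dim, flags.hidden_dim)
    # A single output logit, suitable for a binary cross-entropy loss.
    lin3 = nn.Linear(flags.hidden_dim, 1)
    for lin in [lin1, lin2, lin3]:
      nn.init.xavier_uniform_(lin.weight)
      nn.init.zeros_(lin.bias)
    self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True), lin3)

  def forward(self, input):
    if flags.grayscale_model:
      # Collapse the two color channels into one by summing them.
      out = input.view(input.shape[0], 2, 14 * 14).sum(dim=1)
    else:
      # Flatten both color channels into a single feature vector.
      out = input.view(input.shape[0], 2 * 14 * 14)
    out = self._main(out)
    return out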

The original IRM implementation trains this MLP with binary_cross_entropy_with_logits, whereas this repository uses a general conv network. Is there a reason for this? With the DomainBed implementation, IRM does not seem to be well suited to the task: the benchmark score for ColoredMNIST[0.9] is 10.1, compared with 66.9 in the IRM paper.
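For readers unfamiliar with how a single-logit binary cross-entropy loss enters the IRM objective, here is a minimal sketch of an IRMv1-style penalty: the squared gradient of the per-environment risk with respect to a dummy scale multiplying the logits, evaluated at scale 1. It is written in plain Python with the gradient worked out by hand, which is an assumption made for self-containment; the function names (`bce_with_logits`, `irmv1_penalty`) are hypothetical and are not taken from either codebase.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_with_logits(logits, targets, scale=1.0):
    """Mean binary cross-entropy of (scale * logit) against 0/1 targets."""
    total = 0.0
    for z, y in zip(logits, targets):
        s = scale * z
        # softplus(s) - y * s, with softplus written in a stable form
        total += max(s, 0.0) + math.log1p(math.exp(-abs(s))) - y * s
    return total / len(logits)


def irmv1_penalty(logits, targets):
    """Squared derivative of the mean BCE w.r.t. the dummy scale at scale = 1.

    d/dw mean(softplus(w*z) - y*w*z) at w = 1 equals mean(z * (sigmoid(z) - y)).
    """
    grad = sum(z * (sigmoid(z) - y) for z, y in zip(logits, targets)) / len(logits)
    return grad ** 2


# Hypothetical logits and labels for one training environment.
example_logits = [2.0, -1.0, 0.5]
example_targets = [1, 0, 1]
penalty = irmv1_penalty(example_logits, example_targets)
```

In a training loop, a penalty of this kind would be computed per environment, averaged, and added to the average risk with a weight. An autograd framework can compute the same gradient automatically instead of using the closed form.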

trzhang0116 commented 1 year ago

Hi lccurious, I think using a ConvNet rather than an MLP for ColoredMNIST is reasonable because the dataset consists of image inputs. Another reason may be fairness: all methods in the benchmark should use the same architecture, so it would not be appropriate to use one that is tailored to a specific method (e.g., IRM). As for the discrepancy between the test results, see Appendix B.1 of the DomainBed paper, which indicates that it is largely due to different model selection methods rather than different architectures.
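To illustrate how the model selection method alone can change a reported score, here is a small sketch that picks one run out of several using two different criteria: validation accuracy on held-out data from the training environments, and validation accuracy on held-out data from the test environment. The `Run` class, the run names, and every number below are invented for illustration; they are not results from either paper.

```python
from dataclasses import dataclass


@dataclass
class Run:
    name: str             # hypothetical label for a hyperparameter configuration
    train_val_acc: float  # accuracy on held-out splits of the training environments
    test_val_acc: float   # accuracy on a held-out split of the test environment
    test_acc: float       # accuracy that would be reported for the test environment


def select_training_domain(runs):
    """Pick the run with the best validation accuracy on the training environments."""
    return max(runs, key=lambda r: r.train_val_acc)


def select_test_domain(runs):
    """Pick the run with the best validation accuracy on the test environment."""
    return max(runs, key=lambda r: r.test_val_acc)


# Made-up numbers: a model that relies on color looks best on the training
# environments but fails once the color/label correlation flips at test time.
runs = [
    Run("color_reliant", train_val_acc=0.85, test_val_acc=0.12, test_acc=0.11),
    Run("mixed", train_val_acc=0.80, test_val_acc=0.35, test_acc=0.34),
    Run("shape_reliant", train_val_acc=0.72, test_val_acc=0.61, test_acc=0.60),
]

chosen_by_train = select_training_domain(runs)
chosen_by_test = select_test_domain(runs)
print(chosen_by_train.name, chosen_by_train.test_acc)
print(chosen_by_test.name, chosen_by_test.test_acc)
```

With the same set of trained models, the two criteria select different runs and therefore report very different test accuracies, without any change in architecture.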