KaiyangZhou / mixstyle-release

Domain Generalization with MixStyle (ICLR'21)
MIT License

'Cross-domain' mode of the MixStyle module may malfunction when using nn.DataParallel. #19

Open michaelssf opened 1 year ago

michaelssf commented 1 year ago

The 'cross-domain' mode of the MixStyle module may malfunction when using nn.DataParallel, because samples from the two source domains are scattered separately onto different GPUs. The original implementation of the 'cross-domain' version of MixStyle can be found at [1]. When the network is trained on 2 GPUs with nn.DataParallel, the scatter [2] of the inputs sends the first half of the batch, which is drawn from one source domain, to GPU:0 and the rest to GPU:1. Each model replica on GPU:0 or GPU:1 therefore sees only half of the original batch, and all of those samples come from the same source domain. This defeats the intended mixing of styles across domains.

So if you want to accurately reproduce the results of the original paper, or port the MixStyle code to your own method, please train the model on a single GPU, or modify the code yourself to suit the DataParallel setting.

I have discussed this issue with the author and run experiments to verify the malfunction. When trained on 4 GPUs with nn.DataParallel, the average accuracy on the PACS benchmark drops to 78.0%, which is close to the result of the oracle version without MixStyle modules inserted.
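To make the failure concrete, here is a simplified, framework-free sketch that only tracks domain labels. It assumes, as the cross-domain mode in [1] does, that each batch is ordered as [domain A | domain B], and it approximates the DataParallel scatter as splitting the batch into contiguous, equal-sized chunks along the batch dimension (true when the batch size divides evenly). `cross_domain_perm` is a simplified reading of the permutation logic in [1], not the actual code. `interleave_for_scatter` is a hypothetical workaround, not something proposed by the author, that reorders the batch so every chunk is itself [A part | B part]. The script measures the fraction of samples whose MixStyle partner comes from the other domain.

```python
import random


def cross_domain_perm(batch_size, rng):
    """Simplified cross-domain permutation (assumes batch = [domain A | domain B]).

    Reversing the indices makes the first half point into the second half and
    vice versa; each half is then shuffled independently.
    """
    perm = list(range(batch_size - 1, -1, -1))
    half = batch_size // 2
    perm_b, perm_a = perm[:half], perm[half:]
    rng.shuffle(perm_b)
    rng.shuffle(perm_a)
    return perm_b + perm_a


def scatter(batch, num_gpus):
    """Approximate DataParallel's scatter: contiguous, equal-sized chunks along dim 0."""
    size = len(batch) // num_gpus
    return [batch[i * size:(i + 1) * size] for i in range(num_gpus)]


def mixed_pairs(chunk, rng):
    """Pair each sample in one replica's chunk with its MixStyle partner."""
    perm = cross_domain_perm(len(chunk), rng)
    return [(chunk[i], chunk[j]) for i, j in enumerate(perm)]


def interleave_for_scatter(domain_a, domain_b, num_gpus):
    """Hypothetical workaround: reorder so each contiguous chunk is [A part | B part]."""
    out = []
    for a, b in zip(scatter(domain_a, num_gpus), scatter(domain_b, num_gpus)):
        out.extend(a + b)
    return out


def cross_ratio(pairs):
    """Fraction of samples mixed with a sample from the other domain."""
    return sum(x != y for x, y in pairs) / len(pairs)


rng = random.Random(0)
domain_a, domain_b = ["A"] * 8, ["B"] * 8
batch = domain_a + domain_b  # 16 samples, first half from domain A

single_gpu = mixed_pairs(batch, rng)
two_gpus = [p for chunk in scatter(batch, 2) for p in mixed_pairs(chunk, rng)]
fixed = [p for chunk in scatter(interleave_for_scatter(domain_a, domain_b, 2), 2)
         for p in mixed_pairs(chunk, rng)]

print(cross_ratio(single_gpu))  # 1.0
print(cross_ratio(two_gpus))    # 0.0
print(cross_ratio(fixed))       # 1.0
```

On a single GPU every sample is mixed with the other domain; after a 2-way scatter no sample is, which matches the behavior described above. The interleaving idea only holds if `num_gpus` matches the number of devices DataParallel actually uses and the batch size divides evenly by 2 × `num_gpus`.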

[1] https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/modeling/ops/mixstyle.py#L109
[2] https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html