DevinCheung closed this issue 2 years ago.
To find the cause, I cloned the latest code with git and reran ERM and CORAL on TerraIncognita with "n_hparams=20" and "n_trials=3", leaving the code unchanged. The results with the "training-domain validation set" model selection method are as follows:
-------- Dataset: TerraIncognita, model selection method: training-domain validation set
Algorithm    L100          L38           L43           L46           Avg
ERM          51.9 +/- 1.4  42.9 +/- 1.4  56.3 +/- 0.5  37.0 +/- 0.8  47.0
CORAL        52.1 +/- 2.6  41.9 +/- 2.2  56.2 +/- 0.5  37.1 +/- 1.7  46.8
-------- Averages, model selection method: training-domain validation set
Algorithm    TerraIncognita  Avg
ERM          47.0 +/- 0.7    47.0
CORAL        46.8 +/- 0.8    46.8
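For readers unfamiliar with the criterion used in these tables: under "training-domain validation set" selection, a fraction of each training environment (`holdout_fraction`, 0.2 in the run shown in this thread) is held out for validation, the hyperparameter configuration with the best average validation accuracy on the training environments is chosen, and its accuracy on the test environment is what gets reported. The following is a simplified sketch of that idea, not DomainBed's own model-selection code: the run layout (`"val"`/`"test"` dictionaries), the function name, and the numbers are hypothetical, and details such as which checkpoint is used are ignored.

```python
from statistics import mean
from typing import Dict, List

# Hypothetical record of one finished run: per-environment accuracy on the
# held-out validation split ("val") and on the test split ("test").
Run = Dict[str, Dict[int, float]]


def select_training_domain_val(runs: List[Run], test_env: int) -> float:
    """Return the test-env accuracy of the run whose average validation
    accuracy over the training environments (all envs except test_env)
    is highest."""

    def train_val_acc(run: Run) -> float:
        # The test environment's validation accuracy is deliberately skipped.
        return mean(acc for env, acc in run["val"].items() if env != test_env)

    best = max(runs, key=train_val_acc)
    return best["test"][test_env]


# Two hypothetical hyperparameter configurations, test environment 2.
runs = [
    {"val": {0: 0.80, 1: 0.70, 2: 0.99, 3: 0.75}, "test": {2: 0.50}},
    {"val": {0: 0.85, 1: 0.72, 2: 0.10, 3: 0.78}, "test": {2: 0.55}},
]
print(select_training_domain_val(runs, test_env=2))  # prints 0.55
```

The second run wins even though the first has a much higher validation accuracy on environment 2, because this criterion never looks at the test domain.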
The reproduced results differ considerably from those reported in the DomainBed paper, and I am not sure where the problem lies. Thanks for any help with this!
Here is the log header of one of these runs:
Environment:
  Python: 3.6.13
  PyTorch: 1.10.1
  Torchvision: 0.11.2
  CUDA: 11.3
  CUDNN: 8200
  NumPy: 1.19.2
  PIL: 8.3.1
Args:
  algorithm: CORAL
  checkpoint_freq: None
  data_dir: ../../../datasets
  dataset: TerraIncognita
  holdout_fraction: 0.2
  hparams: None
  hparams_seed: 13
  output_dir: TI_CORAL_output/0c34459f16b9139dd5cd1551d848107f
  save_model_every_checkpoint: False
  seed: 641058751
  skip_model_save: False
  steps: None
  task: domain_generalization
  test_envs: [2]
  trial_seed: 0
  uda_holdout_fraction: 0
HParams:
  batch_size: 40
  class_balanced: False
  data_augmentation: True
  lr: 0.0001653813153854724
  mmd_gamma: 0.6095584318401025
  nonlinear_classifier: False
  resnet18: False
  resnet_dropout: 0.5
  weight_decay: 2.7643974709171963e-05
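The HParams block above is not hand-picked: with "n_hparams=20", each run draws its hyperparameters at random from a seeded distribution (`hparams_seed: 13` here). If, as assumed below, the draw also depends on `trial_seed`, each trial chooses among a different random set of configurations, which adds to the trial-to-trial spread. A minimal sketch of this kind of random search follows; the log-uniform ranges are assumptions that merely bracket the values above (the real distributions are defined in DomainBed's `hparams_registry.py` and may differ), and `sample_hparams` is a hypothetical helper.

```python
import random


def sample_hparams(hparams_seed: int, trial_seed: int) -> dict:
    """Draw one random hyperparameter configuration (hypothetical ranges)."""
    # Combine both seeds into one integer so each (hparams_seed, trial_seed)
    # pair always yields the same configuration.
    rng = random.Random(hparams_seed * 1_000_003 + trial_seed)
    return {
        # Log-uniform draws: sample the exponent uniformly, then exponentiate.
        "lr": 10 ** rng.uniform(-5, -3.5),
        "batch_size": int(2 ** rng.uniform(3, 5.5)),
        "weight_decay": 10 ** rng.uniform(-6, -2),
        "mmd_gamma": 10 ** rng.uniform(-1, 1),
    }


# A 20-configuration sweep for one trial, analogous to n_hparams=20.
configs = [sample_hparams(h, trial_seed=0) for h in range(20)]
```

Sampling the exponent rather than the value itself spreads the configurations evenly across orders of magnitude, which is the usual choice for learning rates and weight decay.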
Unfortunately, many factors can play a role, down to the hardware, the PyTorch version (and the versions of other libraries), irreducible nondeterminism in some operations, multiprocessing, and so on. We are trying to secure the resources to run an entire sweep ourselves, but we cannot promise that amount of compute for now.
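Bit-exact agreement across machines is not realistic, as the reply says, but run-to-run noise on a single machine can at least be reduced. A common PyTorch pattern is sketched below, assuming it is called at the top of the training script; `seed_everything` is a generic, hypothetical helper, not a DomainBed function. NumPy and PyTorch are imported optionally so the sketch also runs without them.

```python
import random


def seed_everything(seed: int, deterministic: bool = True) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and, optionally, request
    deterministic cuDNN behaviour."""
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    if deterministic:
        # Use deterministic cuDNN kernels and disable autotuning, which may
        # select different convolution algorithms from run to run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
```

Even with this, some CUDA operations have no deterministic implementation, DataLoader worker processes keep their own random state, and a different GPU, CUDA/cuDNN build, or library version can still change the numbers. `torch.use_deterministic_algorithms(True)` makes PyTorch raise an error when a nondeterministic operation is hit, which helps locate such operations; on CUDA 10.2 and later it also requires setting the `CUBLAS_WORKSPACE_CONFIG` environment variable.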
Hi, when I reproduce ERM and CORAL, the results fluctuate much more, i.e., "collect_results.py" reports a much larger "std". The mean accuracy is also somewhat lower than reported, especially on PACS, TerraIncognita, and VLCS. What could be the reason for this? Thanks!
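On the "std" question: assuming the "+/-" printed by "collect_results.py" is a standard error over trial seeds (population standard deviation divided by the square root of the number of trials), it rests on only "n_trials" values, 3 in the rerun above, so the error estimate is itself noisy and can easily come out larger than in the paper. A small sketch of that computation; `mean_and_stderr` is a hypothetical helper and the accuracies are made up for illustration.

```python
import math
from statistics import mean, pstdev


def mean_and_stderr(accs):
    """Return (mean, standard error) in percent for per-trial accuracies
    given as fractions in [0, 1].

    Standard error here = population std (no ddof correction) / sqrt(n).
    """
    m = 100 * mean(accs)
    se = 100 * pstdev(accs) / math.sqrt(len(accs))
    return m, se


# Hypothetical test-environment accuracies from three trial seeds.
m, se = mean_and_stderr([0.50, 0.52, 0.54])
print(f"{m:.1f} +/- {se:.1f}")  # prints "52.0 +/- 0.9"
```

With three trials, one unlucky seed shifts both the mean and the error bar noticeably; running more trials is the most direct way to tighten both.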