facebookresearch / DomainBed

DomainBed is a suite to test domain generalization algorithms
MIT License

Wrong batch size being used #7

Closed: SirRob1997 closed this issue 4 years ago

SirRob1997 commented 4 years ago

If I run ERM on PACS with the following command:

python -m domainbed.scripts.train --data_dir=datasets/ --algorithm ERM --dataset PACS --hparams='{"resnet18": "True"}'

and print the shapes of all_x and all_y in the ERM algorithm's update method:

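def update(self, minibatches):
    all_x = torch.cat([x for x, y in minibatches])
    all_y = torch.cat([y for x, y in minibatches])
    print(all_x.shape)  # debug: shape of the concatenated inputs
    print(all_y.shape)  # debug: shape of the concatenated labels
    # ... remainder of ERM.update() unchanged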

I get the following output for the arguments, hyperparameters, and tensor shapes:

Args:
        algorithm: ERM
        checkpoint_freq: None
        data_dir: ../../projects/DomainBed/datasets/
        dataset: PACS
        holdout_fraction: 0.2
        hparams: {"resnet18": "True"}
        hparams_seed: 0
        output_dir: train_output
        seed: 0
        skip_model_save: False
        steps: None
        test_envs: [0]
        trial_seed: 0
HParams:
        batch_size: 32
        class_balanced: False
        data_augmentation: True
        groupdro_eta: 0.01
        irm_lambda: 100.0
        irm_penalty_anneal_iters: 500
        lr: 5e-05
        mixup_alpha: 0.2
        mldg_beta: 1.0
        mlp_depth: 3
        mlp_dropout: 0.0
        mlp_width: 256
        mmd_gamma: 1.0
        mtl_ema: 0.99
        resnet18: True
        resnet_dropout: 0.0
        sag_w_adv: 0.1
        weight_decay: 0.0
torch.Size([96, 3, 224, 224])
torch.Size([96])

Since the first dimension of all_x and all_y is the batch size (all_x has shape batch_size x C x H x W), shouldn't that dimension be 32, the batch_size given in the hyperparameters, rather than the 96 that the shapes suggest?

Maybe I missed something, since 96 is exactly 32 multiplied by 3.
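The factor of 3 can be traced back to the run configuration printed above. A minimal check, assuming PACS consists of four domains (art painting, cartoon, photo, sketch) and that test_envs: [0] holds out exactly one of them; the domain labels and the num_train_envs helper are illustrative only, not DomainBed identifiers:

```python
# Worked check of the "factor of 3".
# Assumption: PACS has four domains; the labels and helper below are
# illustrative, not DomainBed identifiers.
PACS_DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]


def num_train_envs(all_envs, test_envs):
    """Count environments not held out by test_envs (given as indices)."""
    return sum(1 for i in range(len(all_envs)) if i not in test_envs)


batch_size = 32  # HParams: batch_size
test_envs = [0]  # Args: test_envs

n_train = num_train_envs(PACS_DOMAINS, test_envs)
print(n_train)               # 3
print(batch_size * n_train)  # 96
```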

lopezpaz commented 4 years ago

The batch_size hyperparameter is the batch size per training environment. In your example, a batch of 32 examples from each of the 3 training environments amounts to a total of 96 examples in all_x.
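A stdlib-only sketch of the per-environment batching described above, assuming one minibatch of batch_size examples is drawn per training environment and the minibatches are concatenated in update(). Lists stand in for tensors, and sample_minibatch, concat_minibatches, and per_env_batch_size are hypothetical helpers, not DomainBed functions:

```python
# Stdlib-only sketch: lists stand in for tensors, and every helper here is
# hypothetical (not part of DomainBed).


def sample_minibatch(env_id, batch_size):
    """Fake one environment's minibatch: x holds (env_id, index) tags, y dummy labels."""
    x = [(env_id, i) for i in range(batch_size)]
    y = [0] * batch_size
    return x, y


def concat_minibatches(minibatches):
    """Concatenate along the batch dimension, like torch.cat([x for x, y in minibatches])."""
    all_x = [example for x, _ in minibatches for example in x]
    all_y = [label for _, y in minibatches for label in y]
    return all_x, all_y


def per_env_batch_size(total, n_envs):
    """Per-environment batch size for a desired total (rounds down if not divisible)."""
    return total // n_envs


batch_size = 32         # per-environment batch size from HParams
train_envs = [1, 2, 3]  # PACS environments left after holding out test_envs=[0]

minibatches = [sample_minibatch(env, batch_size) for env in train_envs]
all_x, all_y = concat_minibatches(minibatches)
print(len(all_x), len(all_y))  # 96 96

# How many rows of all_x came from each environment
counts = {env: sum(1 for e, _ in all_x if e == env) for env in train_envs}
print(counts)  # {1: 32, 2: 32, 3: 32}

print(per_env_batch_size(32, len(train_envs)))  # 10
```

If a total of 32 examples per step is wanted with three training environments, 32 is not divisible by 3, so under this sketch the closest per-environment setting (10) gives 30 examples per step.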