KaiyangZhou / Dassl.pytorch

A PyTorch toolbox for domain generalization, domain adaptation and semi-supervised learning.
MIT License

Error occurring with odd batch size in crossdomain mode #33

Closed: sansyo closed this issue 2 years ago

sansyo commented 2 years ago

I ran into a batch-size mismatch error when the batch size is odd (e.g. 31, 71).

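So I suggest sizing each `torch.randperm` call by the chunk it shuffles instead of by `B // 2`. When `B` is odd, `chunk(2)` splits the batch into pieces of `B // 2 + 1` and `B // 2` rows (16 and 15 for `B = 31`), so indexing the first piece with `torch.randperm(B // 2)` silently drops one sample and the concatenated batch ends up one row short.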
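A quick standalone check of why the chunk size matters: `Tensor.chunk(2)` returns unequal pieces when the dimension is odd, with the first piece being the larger one. This snippet is only an illustration and is not part of the proposed patch.

```python
import torch

B = 31
first, second = torch.arange(B).chunk(2)
# For odd B, chunk(2) yields B // 2 + 1 rows and then B // 2 rows
print(first.shape[0], second.shape[0], B // 2)  # prints: 16 15 15
```

Because `torch.randperm(B // 2)` only produces 15 indices, it can never cover all 16 rows of `first`.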

Before:

```python
import torch

B = 31
perm = torch.randn((B, 3, 9, 9))
print(perm.shape)
perm_b, perm_a = perm.chunk(2)  # odd B: chunks of 16 and 15 rows
perm_b = perm_b[torch.randperm(B // 2)]  # keeps only 15 of the 16 rows
perm_a = perm_a[torch.randperm(B // 2)]
perm = torch.cat([perm_b, perm_a], 0)
print(perm.shape)
```

Output before the fix:

```
torch.Size([31, 3, 9, 9])
torch.Size([30, 3, 9, 9])
```

After:

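```python
import torch

perm = torch.randn((31, 3, 9, 9))
print(perm.shape)
perm_b, perm_a = perm.chunk(2)  # chunks of 16 and 15 rows
# Shuffle each chunk using its own length, so no row is dropped
perm_b = perm_b[torch.randperm(perm_b.shape[0])]
perm_a = perm_a[torch.randperm(perm_a.shape[0])]
perm = torch.cat([perm_b, perm_a], 0)
print(perm.shape)
```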

Output after the fix:

```
torch.Size([31, 3, 9, 9])
torch.Size([31, 3, 9, 9])
```
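The same fix can be checked on an index tensor instead of on the features themselves. The sketch below wraps the "after" logic in a hypothetical helper, `halves_shuffled_perm` (not a Dassl function), and confirms that the result is a true permutation of `0..B-1` for both even and odd batch sizes, i.e. no sample is dropped or duplicated.

```python
import torch


def halves_shuffled_perm(B: int) -> torch.Tensor:
    """Hypothetical helper: index version of the "after" snippet.

    Splits the indices 0..B-1 into two halves with chunk(2) and shuffles
    each half using its own length. Assumes B >= 2, because chunk(2) on a
    single element returns only one chunk.
    """
    perm = torch.arange(B)
    perm_b, perm_a = perm.chunk(2)
    perm_b = perm_b[torch.randperm(perm_b.shape[0])]
    perm_a = perm_a[torch.randperm(perm_a.shape[0])]
    return torch.cat([perm_b, perm_a], 0)


for B in (30, 31, 71):
    perm = halves_shuffled_perm(B)
    # Sorting a valid permutation must give back exactly 0..B-1
    assert torch.equal(torch.sort(perm).values, torch.arange(B))
    print(B, perm.shape[0])
```

Indexing a batch with such an index tensor (e.g. `x[perm]`) reorders the samples while keeping the batch size at `B`.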