Closed sansyo closed 2 years ago
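I hit a batch-size mismatch error when the batch size is odd (e.g. 31, 71). `perm.chunk(2)` splits an odd batch into two halves of unequal size (16 and 15 for B = 31), but both halves are indexed with `torch.randperm(B // 2)`, which yields only 15 indices, so one sample is silently dropped.
I suggest sizing each `randperm` call to its own half's batch size instead of `B // 2`.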
Before:

    import torch

    B = 31
    perm = torch.randn((B, 3, 9, 9))
    print(perm.shape)
    perm_b, perm_a = perm.chunk(2)           # halves of size 16 and 15
    perm_b = perm_b[torch.randperm(B // 2)]  # only 15 indices: one sample is lost
    perm_a = perm_a[torch.randperm(B // 2)]
    perm = torch.cat([perm_b, perm_a], 0)
    print(perm.shape)
Before output:

    torch.Size([31, 3, 9, 9])
    torch.Size([30, 3, 9, 9])
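The missing sample can be traced to how `chunk` splits an odd batch. A quick check, written here only as an illustration (the variable names `halves` and `sizes` are not from the original code), shows that the two halves differ in size while `B // 2` matches only the smaller one:

```python
import torch

B = 31
halves = torch.randn(B, 3, 9, 9).chunk(2)

# chunk(2) uses a chunk size of ceil(B / 2), so the first half is larger.
sizes = [h.shape[0] for h in halves]
print(sizes)   # [16, 15]

# B // 2 is 15, so indexing the 16-row half with randperm(B // 2)
# keeps only 15 of its rows.
print(B // 2)  # 15
```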
After:

    import torch

    perm = torch.randn((31, 3, 9, 9))
    print(perm.shape)
    perm_b, perm_a = perm.chunk(2)
    # Size each permutation from the half it indexes, so odd batches keep every sample.
    perm_b = perm_b[torch.randperm(perm_b.shape[0])]
    perm_a = perm_a[torch.randperm(perm_a.shape[0])]
    perm = torch.cat([perm_b, perm_a], 0)
    print(perm.shape)
After output:

    torch.Size([31, 3, 9, 9])
    torch.Size([31, 3, 9, 9])
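For reuse, the proposed change could be wrapped in a small helper. This is only a sketch: the function name `shuffle_within_halves` is hypothetical and not part of the repository, and it assumes a batch of at least two samples, since `chunk(2)` returns a single chunk for a batch of one and the unpacking would then fail.

```python
import torch


def shuffle_within_halves(x: torch.Tensor) -> torch.Tensor:
    """Shuffle each half of a batch independently along dim 0.

    Hypothetical helper; assumes x.shape[0] >= 2.
    """
    first, second = x.chunk(2)
    # Each permutation is sized from the half it indexes, so no rows are dropped.
    first = first[torch.randperm(first.shape[0])]
    second = second[torch.randperm(second.shape[0])]
    return torch.cat([first, second], 0)


# The output batch size matches the input for even and odd batches alike.
for batch_size in (30, 31, 71):
    x = torch.randn(batch_size, 3, 9, 9)
    assert shuffle_within_halves(x).shape == x.shape
```

Samples never cross between halves here; only their order within each half changes, which matches the behaviour of the original snippet.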