kuangliu / pytorch-retinanet

RetinaNet in PyTorch

Batch size 1 crashes in datagen.py #7

Closed mubastan closed 7 years ago

mubastan commented 7 years ago

These lines in datagen.py crash when the batch size is 1:

max_size, _ = torch.IntTensor([im.size() for im in imgs]).max(0)
max_h, max_w = max_size[1], max_size[2]

njtuzzy commented 7 years ago

Yes, I have hit the same problem. If the batch size is 1, tensor.max(0) returns a tensor of size 1x3, so a condition may need to be added to handle that case.

kuangliu commented 7 years ago

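Yeah, I noticed that. I think it's due to the inconsistent behavior of PyTorch's max function (see https://github.com/pytorch/pytorch/issues/2317): when the input has a single row, the reduced dimension is kept (size 1x3), but with two or more rows it is squeezed away (size 3).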

>>> torch.Tensor([[1,2,3]]).max(0)
(
  1  2  3
 [torch.FloatTensor of size 1x3], 
  0  0  0
 [torch.LongTensor of size 1x3])

>>> torch.Tensor([[1,2,3],[3,4,5]]).max(0)
(
  3
  4
  5
 [torch.FloatTensor of size 3], 
  1
  1
  1
 [torch.LongTensor of size 3])

Anyway, I'll update the code.
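
One way to sidestep the shape difference is to avoid tensor.max altogether and take the maximum height and width with plain Python. The following is only a minimal sketch, not the fix that was actually committed to the repository: the helper name max_height_width is hypothetical, and it assumes each image size is a (C, H, W) triple, as the indexing in the lines quoted above suggests.

```python
def max_height_width(sizes):
    """Return (max_h, max_w) over a list of (C, H, W) image sizes.

    Hypothetical helper: it uses Python's built-in max, so the result
    has the same form for a batch of one image as for a larger batch.
    """
    if not sizes:
        raise ValueError("sizes must contain at least one (C, H, W) entry")
    max_h = max(s[1] for s in sizes)  # largest height in the batch
    max_w = max(s[2] for s in sizes)  # largest width in the batch
    return max_h, max_w


# Assumed usage with a list of image tensors `imgs`:
#   max_h, max_w = max_height_width([tuple(im.size()) for im in imgs])
print(max_height_width([(3, 300, 400)]))                 # batch size 1
print(max_height_width([(3, 300, 400), (3, 500, 200)]))  # batch size 2
```

Because the result is an ordinary tuple of ints, the code that follows does not depend on whether max(0) keeps or squeezes the reduced dimension.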