kuangliu / pytorch-retinanet

RetinaNet in PyTorch

IndexError: too many indices for tensor of dimension 1 #65

Open sunshine-zkf opened 5 years ago

sunshine-zkf commented 5 years ago

I use the `decode` function in encoder.py and always get the following error.

Where is the problem? I am using PyTorch 0.4.

gyfastas commented 5 years ago

When no bounding box is detected, `ids.nonzero().squeeze()` squeezes the tensor down to nothing, which leaves `ids` with shape `torch.Size([])`. Try this in `decode()` in encoder.py:

    ids = score > CLS_THRESH
    if ids.numel() <= 1:
        ids = ids.nonzero().squeeze(0)
    else:
        ids = ids.nonzero().squeeze()
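A small sketch of the shape problem, assuming a recent PyTorch release (behavior in 0.4 may differ). In recent versions, `nonzero()` on a 1-D mask returns shape (N, 1). A bare `squeeze()` removes every size-1 dimension, so an empty result stays (0,), but a single detection collapses to a 0-dim tensor; indexing `boxes` with it then yields one [4] box instead of a [1, 4] batch, and any later `[:, 0]`-style access fails with "too many indices for tensor of dimension 1". Squeezing only dim 1 keeps the indices 1-D for any count. The helper names `unsafe_indices` and `positive_indices` are hypothetical, not part of the repo.

```python
import torch


def unsafe_indices(score, thresh):
    # Pattern discussed in the thread: squeeze() removes *every* size-1 dim,
    # so a (1, 1) result from nonzero() becomes a 0-dim tensor.
    return (score > thresh).nonzero().squeeze()


def positive_indices(score, thresh):
    # nonzero() on a 1-D mask returns shape (N, 1); squeeze(1) removes only
    # that trailing dim, so the result has shape (N,) for N = 0, 1, or more.
    return (score > thresh).nonzero().squeeze(1)


if __name__ == "__main__":
    boxes = torch.rand(2, 4)              # two anchors, (x1, y1, x2, y2)
    score = torch.tensor([0.1, 0.9])      # exactly one score above 0.5
    bad = unsafe_indices(score, 0.5)      # 0-dim index tensor
    good = positive_indices(score, 0.5)   # 1-D index tensor with one entry
    print(boxes[good].shape)              # a [1, 4] batch of boxes
    print(boxes[bad].shape)               # a single [4] box: the row dim is gone
    try:
        boxes[bad][:, 0]                  # the kind of 2-D access an NMS step does
    except IndexError as err:
        print("IndexError:", err)
```

Using `squeeze(1)` (or `.view(-1)`) makes the number of detections irrelevant to the output rank, so no special case is needed for zero or one hit.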

sunshine-zkf commented 5 years ago


Thank you very much! I will try it!

KyuminHwang commented 5 years ago

@gyfastas Thank you for your comment. I also revised the code as you suggested, but the error still occurs. Could you suggest another solution?

gyfastas commented 5 years ago


:) I found that my solution did not work either. `ids = score > CLS_THRESH` is a boolean mask such as [0, 0, 1, 0, ...]. We can use `ids.sum()` to check whether it contains any 1, i.e. whether there is any positive sample. Try this:

    ids = score > CLS_THRESH
    if not ids.sum():
        # No positive samples: return a dummy box and label.
        return torch.tensor([[0., 0., 0., 0.]]), torch.tensor([[0]])
    # Otherwise, continue with NMS as before.
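One likely reason the first fix had no effect: `ids = score > CLS_THRESH` is a mask with one entry per anchor, so `ids.numel()` is the total anchor count and is essentially never <= 1; counting the True entries, as suggested above, is what tells you whether anything passed. Below is a minimal sketch combining that check with 1-D-safe indexing. `filter_and_nms` is a hypothetical helper, not the repo's `decode`; it assumes `boxes`, `scores` and `labels` are already per-anchor tensors of shapes [A, 4], [A] and [A], and takes the NMS routine as a parameter. Unlike the thread, it returns empty tensors rather than a dummy zero box when nothing is detected, which avoids a fake detection but means callers must handle K = 0.

```python
import torch


def filter_and_nms(boxes, scores, labels, cls_thresh, nms_fn):
    """Hypothetical helper: keep boxes whose score exceeds cls_thresh, then
    run nms_fn on them. Returns (boxes [K, 4], labels [K]).

    boxes:  FloatTensor [A, 4]; scores, labels: [A] (one entry per anchor).
    nms_fn: callable (boxes [N, 4], scores [N]) -> LongTensor of kept indices.
    """
    mask = scores > cls_thresh
    # mask.numel() is always A, so count the True entries instead.
    if int(mask.sum()) == 0:
        # No positives: return empty tensors that keep the usual trailing dims.
        return boxes.new_zeros((0, 4)), labels.new_zeros((0,))
    ids = mask.nonzero().squeeze(1)  # always 1-D, even for a single hit
    keep = nms_fn(boxes[ids], scores[ids])
    return boxes[ids][keep], labels[ids][keep]
```

To plug in a real NMS, one option (assuming torchvision is installed) is `nms_fn=lambda b, s: torchvision.ops.nms(b, s, 0.5)`, where 0.5 is only an example IoU threshold.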

I also recommend checking the dimensions of these variables.

KyuminHwang commented 5 years ago

@gyfastas Thank you for your comment! I also handled this error by returning `torch.tensor([[0.,0.,0.,0.]]), torch.tensor([[0]])`. Thank you :)