yhenon / pytorch-retinanet

Pytorch implementation of RetinaNet object detection.
Apache License 2.0
2.14k stars, 665 forks

Little bug in train.py #156

Open AstrayChao opened 4 years ago

AstrayChao commented 4 years ago

A .cuda() call seems to be missing.

File: train.py, line 126

classification_loss, regression_loss = retinanet([data['img'].cuda().float(), data['annot']])

I think data['annot'] should be data['annot'].cuda()
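
The suggestion above can be sketched in a device-agnostic way. This is only an illustration, not the repository's code: the helper name move_batch_to_device is hypothetical, and the dummy batch (its shapes and the five-value annotation row) is an assumption standing in for whatever the data loader in train.py actually yields.

```python
import torch


def move_batch_to_device(batch, device):
    """Return a new dict with every tensor in `batch` moved to `device`.

    Hypothetical helper for illustration; non-tensor values are passed through.
    """
    return {
        key: value.to(device) if torch.is_tensor(value) else value
        for key, value in batch.items()
    }


# Fall back to the CPU when no GPU is available, so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy batch: the keys match the issue ('img', 'annot'); the shapes are assumed.
data = {
    "img": torch.zeros(2, 3, 64, 64),
    "annot": torch.tensor([[[10.0, 10.0, 30.0, 30.0, 0.0]]] * 2),
}

batch = move_batch_to_device(data, device)

# Both tensors now live on the same device.
print(batch["img"].device, batch["annot"].device)
```

With a helper like this, the call on line 126 could pass batch['img'].float() and batch['annot'], so the image and annotation tensors are always on the same device regardless of whether a GPU is present.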

renmengyuan commented 4 years ago

What happens if we do not call .cuda()?