davidtvs / PyTorch-ENet

PyTorch implementation of ENet
MIT License

RuntimeError: weight tensor should be defined either for all or no classes #35

Closed ouening closed 3 years ago

ouening commented 4 years ago

Hi, when I trained on my own dataset (similar to CamVid, but with 4 classes), this error occurred:

>>>> [Epoch: 0] Training
Traceback (most recent call last):
  File "main.py", line 306, in <module>
    model = train(train_loader, val_loader, w_class, class_encoding)
  File "main.py", line 191, in train
    epoch_loss, (iou, miou) = train.run_epoch(args.print_step)
  File "/media/gaoya/disk/Applications/pytorch/SemanticSegmentation/PyTorch-ENet-master/train.py", line 47, in run_epoch
    loss = self.criterion(outputs, labels)
  File "/media/gaoya/disk/Applications/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/gaoya/disk/Applications/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 916, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/media/gaoya/disk/Applications/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1995, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/media/gaoya/disk/Applications/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1826, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: weight tensor should be defined either for all or no classes at /tmp/pip-req-build-58y_cjjl/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:27

How can I solve it?

davidtvs commented 4 years ago

All I can say from the error is that your class weight tensor doesn't have the expected size. Quoting the PyTorch docs: "weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C", where C is the number of classes.
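To make the size requirement concrete, here is a minimal sketch (not taken from this repository) that runs on the CPU. The tensor shapes and the 4-class setup are assumptions chosen to mirror the question; the point is only that the weight tensor must have one entry per output channel of the model.

```python
import torch
import torch.nn as nn

num_classes = 4  # assumption: 4 classes, as in the question

# Fake network output (N, C, H, W) and fake labels (N, H, W), for illustration only
outputs = torch.randn(2, num_classes, 8, 8)
labels = torch.randint(0, num_classes, (2, 8, 8))

# Weight of size C: works
good_criterion = nn.CrossEntropyLoss(weight=torch.ones(num_classes))
loss = good_criterion(outputs, labels)

# Weight of size C - 1: raises a RuntimeError about the weight tensor
bad_criterion = nn.CrossEntropyLoss(weight=torch.ones(num_classes - 1))
try:
    bad_criterion(outputs, labels)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)
```

The exact wording of the error message may differ between PyTorch versions and between CPU and CUDA.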

noamzilo commented 3 years ago

Same problem for me, for batches whose labels don't include all possible classes. What's the correct way to use weight?

davidtvs commented 3 years ago

You should specify the weights of all classes for every batch, regardless of which labels appear in any specific batch. PyTorch itself picks the weights corresponding to the class labels that are present when computing the loss. See the NLLLoss docs.
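As a rough illustration of that point, the sketch below builds a batch that contains only classes 0 and 1 while still passing weights for all 4 classes. The weight values and shapes are made up. It also recomputes the loss by hand, assuming the default `reduction="mean"`, which averages the per-pixel losses weighted by the weight of each pixel's target class.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 4  # assumption: 4 classes, as in the question
weight = torch.tensor([1.0, 2.0, 3.0, 4.0])  # made-up weights, one per class

outputs = torch.randn(2, num_classes, 8, 8)  # (N, C, H, W) logits
labels = torch.randint(0, 2, (2, 8, 8))      # this batch only contains classes 0 and 1

# Works even though classes 2 and 3 never appear in the labels
loss = F.cross_entropy(outputs, labels, weight=weight)

# Manual check: each pixel's loss is scaled by the weight of its own label
per_pixel = F.cross_entropy(outputs, labels, reduction="none")
pixel_weight = weight[labels]
manual = (pixel_weight * per_pixel).sum() / pixel_weight.sum()
matches = torch.allclose(loss, manual, atol=1e-5)
```

The weights for classes 2 and 3 are simply never indexed for this batch, so they have no effect on its loss.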

ghost commented 3 years ago

If you're a real genius like me, you simply set the wrong number of classes:

class_weights = enet_weighing(train_loader, 3)  # whoops - 3 is the highest index of a zero-based array; the argument is the number of classes, so it should be 4!
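One way to catch this kind of off-by-one before training starts is a small guard like the one below. `check_class_weights` is a hypothetical helper, not part of this repository or of PyTorch; it only compares the number of weights against the number of classes the model is built for.

```python
import torch


def check_class_weights(class_weights, num_classes):
    """Hypothetical guard: fail early if there is not exactly one weight per class."""
    n_weights = len(class_weights)
    if n_weights != num_classes:
        raise ValueError(
            f"got {n_weights} class weights but the model has {num_classes} classes; "
            "num_classes is the highest label index + 1"
        )
    return torch.as_tensor(class_weights, dtype=torch.float)


# Labels 0..3 mean 4 classes, so 4 weights pass the check
weights = check_class_weights([1.0, 2.0, 3.0, 4.0], num_classes=4)

# 3 weights for 4 classes is rejected with a readable message
try:
    check_class_weights([1.0, 2.0, 3.0], num_classes=4)
    rejected = False
except ValueError:
    rejected = True
```

Calling something like this right after computing the weights gives a clearer message than the RuntimeError raised deep inside the loss.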