davidtvs / PyTorch-ENet

PyTorch implementation of ENet
MIT License

About Training #25

Closed LinRui9531 closed 5 years ago

LinRui9531 commented 5 years ago

When I try to train your ENet on my device, I run the following command:

python main.py -m train --save-dir ./camvid_model/ --name ENet --dataset camvid --dataset-dir CamVid/ --with-unlabeled --imshow-batch

But I get the following error:

Traceback (most recent call last):
  File "main.py", line 291, in <module>
    loaders, w_class, class_encoding = load_dataset(dataset)
  File "main.py", line 110, in load_dataset
    color_labels = utils.batch_transform(labels, label_to_rgb)
  File "/home/amax/linrui/PyTorch-ENet-master/utils.py", line 21, in batch_transform
    transf_slices = [transform(tensor) for tensor in torch.unbind(batch)]
  File "/usr/local/lib/python2.7/dist-packages/torchvision/transforms/transforms.py", line 49, in __call__
    img = t(img)
  File "/home/amax/linrui/PyTorch-ENet-master/transforms.py", line 91, in __call__
    color_tensor[channel].masked_fill_(mask, color_value)
RuntimeError: expand(torch.ByteTensor{[3, 360, 480]}, size=[360, 480]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)

The label.size() is [B, 3, H, W]. How do I convert it to [B, classnumber, H, W]? I couldn't find this in your code. Thank you.

davidtvs commented 5 years ago

The CamVid labels from the authors of SegNet already come with class numbers instead of RGB colors. Each label should have dimensions [H, W]. Note also that this is what the loss function expects as the target, as described here.

So, as long as you're using the CamVid dataset from SegNet, you don't need to convert from [3, H, W] to [classnumber, H, W].
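For reference, a minimal sketch of the shapes involved, assuming a standard torch.nn.CrossEntropyLoss (the sizes below are just illustrative):

import torch
import torch.nn as nn

batch, num_classes, h, w = 2, 12, 360, 480

# Network output: one score per class per pixel -> [B, C, H, W]
logits = torch.randn(batch, num_classes, h, w)

# Label: one integer class index per pixel, no channel dimension -> [B, H, W]
target = torch.randint(0, num_classes, (batch, h, w))

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)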

LinRui9531 commented 5 years ago

Thank you~ I will download the CamVid dataset from SegNet again and try to run your code.

xxxxxxxiao commented 5 years ago

Hi. I am trying to use a different dataset, so do I need to convert from [3, h, w] to [classnumber, h, w]? And could you tell me how to do that?

davidtvs commented 5 years ago

@xxxxxxxiao you need to convert from [3, h, w] to [h, w]; this is what the loss function expects. Each pixel value must satisfy 0 ≤ targets[i] ≤ C−1, where C is the number of classes.

As to how you make the conversion, I would start with the following:

from PIL import Image
import numpy as np

# Load the RGB label image as an array of shape [H, W, 3]
img = np.array(Image.open("path"))

# One RGB color per class, in class-index order
color_to_class = [(50, 50, 50), ...]

# Class-index map of shape [H, W]
target = np.zeros(img.shape[:-1], dtype=np.int64)
for c, color in enumerate(color_to_class):
    # A pixel belongs to class c only if all three channels match the color
    target[np.all(img == color, axis=-1)] = c
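As a quick sanity check, a self-contained toy example (the colors and shapes here are made up) shows the loop producing an [H, W] index map:

import numpy as np

# Hypothetical 2x2 RGB label with two classes: (50, 50, 50) and (0, 128, 0)
img = np.array([[[50, 50, 50], [0, 128, 0]],
                [[0, 128, 0], [50, 50, 50]]], dtype=np.uint8)
color_to_class = [(50, 50, 50), (0, 128, 0)]

target = np.zeros(img.shape[:-1], dtype=np.int64)
for c, color in enumerate(color_to_class):
    target[np.all(img == color, axis=-1)] = c

print(target)  # [[0 1]
               #  [1 0]]

The resulting array can then be wrapped with torch.from_numpy(target).long() before being passed to the loss.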