Closed · LinRui9531 closed this issue 5 years ago
The CamVid labels from the authors of SegNet already come with the class numbers instead of RGB color. The label should be of dimension [H, W]. Note also that this is what the loss function expects to get as the label as described here.
So, unless you're using a dataset other than SegNet's CamVid, you won't need to go from [3, H, W] to [classnumber, H, W].
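For context, here is a minimal sketch of the shapes the cross-entropy loss expects; all tensor sizes below are illustrative, not taken from this repo:

```python
import torch
import torch.nn as nn

B, C, H, W = 2, 12, 4, 6                  # batch, classes, height, width (illustrative)
logits = torch.randn(B, C, H, W)          # network output: one score per class per pixel
target = torch.randint(0, C, (B, H, W))   # class indices, NOT RGB: values in [0, C-1]

# Works because target is [B, H, W] with integer class labels.
loss = nn.CrossEntropyLoss()(logits, target)
```

If `target` were an RGB image of shape [B, 3, H, W] instead, this call would fail with a shape error.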
Thanks! I will download the CamVid dataset from SegNet again and try to run your code.
Hi. I am trying to use a different dataset, so do I need to translate from [3, H, W] to [classnumber, H, W]? Could you tell me how to do that?
@xxxxxxxiao, you need to convert from [3, H, W] to [H, W]; this is what the loss function expects. Each pixel value satisfies 0 ≤ target[i] ≤ C−1, where C is the number of classes.
As to how you make the conversion, I would start with the following:
from PIL import Image
import numpy as np

# Load the RGB label image as an [H, W, 3] array.
img = np.array(Image.open("path"))

# Palette: the index in this list is the class number. Fill in your dataset's colors.
color_to_class = [(50, 50, 50), ...]

# Class-index target of shape [H, W].
target = np.zeros(img.shape[:-1], dtype=np.int64)
for c, color in enumerate(color_to_class):
    # A pixel belongs to class c only when all three channels match the color.
    target[np.all(img == color, axis=-1)] = c
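To sanity-check the mapping, here is a tiny self-contained example; the 2×2 image and the two-color palette are made up for illustration:

```python
import numpy as np

# Hypothetical 2x2 RGB label image using two colors.
img = np.array([[[50, 50, 50], [0, 0, 0]],
                [[0, 0, 0],    [50, 50, 50]]], dtype=np.uint8)
color_to_class = [(0, 0, 0), (50, 50, 50)]  # class 0, class 1

# Same conversion: match all three channels, then write the class index.
target = np.zeros(img.shape[:-1], dtype=np.int64)
for c, color in enumerate(color_to_class):
    target[np.all(img == color, axis=-1)] = c

print(target)  # [[1 0]
               #  [0 1]]
```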
When I try to train your ENet on my device, I run the following command:
python main.py -m train --save-dir ./camvid_model/ --name ENet --dataset camvid --dataset-dir CamVid/ --with-unlabeled --imshow-batch
But I run into the following problem:

Traceback (most recent call last):
  File "main.py", line 291, in <module>
    loaders, w_class, class_encoding = load_dataset(dataset)
  File "main.py", line 110, in load_dataset
    color_labels = utils.batch_transform(labels, label_to_rgb)
  File "/home/amax/linrui/PyTorch-ENet-master/utils.py", line 21, in batch_transform
    transf_slices = [transform(tensor) for tensor in torch.unbind(batch)]
  File "/usr/local/lib/python2.7/dist-packages/torchvision/transforms/transforms.py", line 49, in __call__
    img = t(img)
  File "/home/amax/linrui/PyTorch-ENet-master/transforms.py", line 91, in __call__
    color_tensor[channel].masked_fill_(mask, color_value)
RuntimeError: expand(torch.ByteTensor{[3, 360, 480]}, size=[360, 480]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)
The label.size() is [B, 3, H, W]. How do I translate it to [B, classnumber, H, W]? I didn't find this in your code. Thank you!
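One way to batch-convert an RGB label tensor to class indices before the loss is sketched below; this is not from the repo, and `color_to_class` stands in for your dataset's palette:

```python
import torch

def rgb_batch_to_class(labels, color_to_class):
    """Convert an RGB label batch [B, 3, H, W] to class indices [B, H, W]."""
    B, _, H, W = labels.shape
    target = torch.zeros(B, H, W, dtype=torch.long)
    for c, color in enumerate(color_to_class):
        # Broadcast the color over the batch and spatial dimensions.
        color = torch.tensor(color, dtype=labels.dtype).view(1, 3, 1, 1)
        mask = (labels == color).all(dim=1)  # [B, H, W]: True where all channels match
        target[mask] = c
    return target

# All-black labels as a placeholder input; palette colors are made up.
labels = torch.zeros(2, 3, 360, 480, dtype=torch.uint8)
target = rgb_batch_to_class(labels, [(0, 0, 0), (50, 50, 50)])
print(target.shape)  # torch.Size([2, 360, 480])
```

Note the result is [B, H, W], not [B, classnumber, H, W]: the loss wants one integer class per pixel, not a one-hot channel dimension.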