davidtvs / PyTorch-ENet

PyTorch implementation of ENet
MIT License

Unable to use pre-trained models #55

Closed rdthrow closed 3 years ago

rdthrow commented 3 years ago

Hi, I created a simple test script using the predict function you have provided. When I load the pre-trained model and run it on a sample Cityscapes image, I get a completely garbage segmentation mask.

I looked at the output prediction tensor here:

images = images.to('cuda')

# Make predictions!
model.eval()
with torch.no_grad():
    predictions = model(images)

The predictions for a group of nearby pixels (e.g. from a section of road) appear completely random for a typical Cityscapes image. Another question: is the output tensor supposed to be a normalized probability vector? I see negative values in this tensor:

(Pdb) predictions.data[:,:,5,5]
tensor([[ 0.2902, -0.3427,  0.1710,  0.0719, -0.6740, -0.4032, -0.0204, -0.0356,
          0.4871,  0.4981,  0.1949,  0.4358]], device='cuda:0')
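On the normalization question: negative values suggest the network returns raw per-class scores (logits) rather than probabilities. A minimal sketch of converting them, assuming the output has shape (N, C, H, W) with the classes on dimension 1; the tensor below is random stand-in data, not real model output:

```python
import torch
import torch.nn.functional as F

# Stand-in for the model output: batch of 1, 12 class channels, 4x4 pixels.
torch.manual_seed(0)
logits = torch.randn(1, 12, 4, 4)

# Softmax over the class dimension turns scores into per-pixel probabilities.
probs = F.softmax(logits, dim=1)

# Softmax preserves ordering, so the predicted class is the same either way.
labels_from_logits = logits.argmax(dim=1)
labels_from_probs = probs.argmax(dim=1)
```

Since the argmax is unchanged, the softmax is only needed when the probabilities themselves are of interest, for example to inspect per-pixel confidence.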

I'm using:

torch.version '1.9.0+cu102'

The following is my complete test script:

import sys
import numpy as np
import torch
import models.enet
import torch.onnx
from PIL import Image
from transforms import PILToLongTensor
import utils
from collections import OrderedDict
import torchvision.transforms as transforms
import transforms as ext_transforms
import matplotlib.pyplot as plt

color_encoding = OrderedDict([
        ('unlabeled', (0, 0, 0)),
        ('road', (128, 64, 128)),
        ('sidewalk', (244, 35, 232)),
        ('building', (70, 70, 70)),
        ('wall', (102, 102, 156)),
        ('fence', (190, 153, 153)),
        ('pole', (153, 153, 153)),
        ('traffic_light', (250, 170, 30)),
        ('traffic_sign', (220, 220, 0)),
        ('vegetation', (107, 142, 35)),
        ('terrain', (152, 251, 152)),
        ('sky', (70, 130, 180)),
        ('person', (220, 20, 60)),
        ('rider', (255, 0, 0)),
        ('car', (0, 0, 142)),
        ('truck', (0, 0, 70)),
        ('bus', (0, 60, 100)),
        ('train', (0, 80, 100)),
        ('motorcycle', (0, 0, 230)),
        ('bicycle', (119, 11, 32))
])

def predict(model, images, class_encoding):
    images = images.to('cuda')

    # Make predictions!
    model.eval()
    with torch.no_grad():
        predictions = model(images)

    # Predictions holds one score per class in "num_classes" channels.
    # Convert it to a single class index per pixel by taking the argmax over channels
    _, predictions = torch.max(predictions.data, 1)

    label_to_rgb = transforms.Compose([
        ext_transforms.LongTensorToRGBPIL(class_encoding),
        transforms.ToTensor()
    ])
    color_predictions = utils.batch_transform(predictions.cpu(), label_to_rgb)
    utils.imshow_batch(images.data.cpu(), color_predictions)

weights = torch.load(sys.argv[1])
mymodel = models.enet.ENet(19)

mymodel.load_state_dict(weights, strict=False)

import torchvision
img = torchvision.io.read_image(sys.argv[2])

img = torch.reshape(img,(1,3,1024,2048))
mymodel.to('cuda')
predict(mymodel,img.float(),color_encoding)
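One thing worth checking in a script like this is the input range. torchvision.io.read_image returns a uint8 tensor of shape (C, H, W) with values in [0, 255], whereas transforms.ToTensor yields floats in [0, 1]. A sketch, assuming the model was trained on ToTensor-style inputs (an assumption, not confirmed in this thread), with a random tensor standing in for the decoded image:

```python
import torch

# Stand-in for torchvision.io.read_image(path): uint8, shape (C, H, W).
img = torch.randint(0, 256, (3, 64, 128), dtype=torch.uint8)

# Add the batch dimension explicitly instead of reshaping to a fixed size,
# and scale to [0, 1] (assumed to match the training-time preprocessing).
batch = img.unsqueeze(0).float() / 255.0
```

Using unsqueeze(0) rather than a hard-coded reshape keeps the pixel layout intact for any image size.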
rdthrow commented 3 years ago

Sorry, my bad. Was loading the model incorrectly. Closing.
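The thread does not say what the loading mistake was. One common cause of this symptom, sketched here purely as an assumption: the file holds a checkpoint dictionary with the weights nested under a key such as "state_dict" (a hypothetical layout for this example), and strict=False lets load_state_dict skip every mismatched key without raising, so the model keeps its random initialization. TinyNet is a made-up stand-in for the real network.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

trained = TinyNet()
# Assumed checkpoint layout: weights nested under "state_dict".
checkpoint = {"epoch": 10, "state_dict": trained.state_dict()}

# Wrong: passing the whole checkpoint. strict=False hides the mismatch,
# but the returned report lists what was skipped.
model = TinyNet()
report = model.load_state_dict(checkpoint, strict=False)
print("missing:", report.missing_keys)        # parameters left untouched
print("unexpected:", report.unexpected_keys)  # checkpoint keys ignored

# Right: pass the nested state dict and keep strict=True (the default)
# so any mismatch raises instead of passing silently.
model.load_state_dict(checkpoint["state_dict"])
```

Printing the keys of the loaded object before calling load_state_dict is a quick way to see which layout a given file uses.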