davidtvs / PyTorch-ENet

PyTorch implementation of ENet
MIT License

Cannot run main.py in test mode - 1only batches of spatial targets supported (3D tensors) #56

Closed rashedkoutayni closed 2 years ago

rashedkoutayni commented 2 years ago

When I run the main script for testing using: python3 main.py -m test --save-dir save/ENet_CamVid/ --name ENet --dataset camvid --dataset-dir ../../DATASETS/CamVid/ --batch-size 1

I get an error message: RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 3, 360, 480]

This is triggered from test.py:51 when it tries to calculate the loss: loss = self.criterion(outputs, labels)

Any clue on how to solve this issue?
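
For context on the error message: the loss expects one class index per pixel, i.e. targets of shape [N, H, W], whereas [1, 3, 360, 480] has a three-channel dimension and looks like an RGB-encoded label image. One possible cause (an assumption, not confirmed in this thread) is that the label images on disk are color images rather than class-index maps. Below is a minimal sketch of converting such a label to indices. The helper rgb_label_to_indices, the truncated PALETTE, and the ignore value 255 are hypothetical and not part of PyTorch-ENet.

```python
import numpy as np

# Truncated, illustrative palette (class name -> RGB); the full CamVid
# encoding used in this repository has 12 entries.
PALETTE = {
    'sky': (128, 128, 128),
    'building': (128, 0, 0),
    'road': (128, 64, 128),
    'unlabeled': (0, 0, 0),
}


def rgb_label_to_indices(label_rgb, palette, ignore_index=255):
    """Map an (H, W, 3) RGB label image to an (H, W) array of class indices.

    Hypothetical helper. Pixels whose color is not in the palette
    receive ignore_index.
    """
    label_rgb = np.asarray(label_rgb)
    indices = np.full(label_rgb.shape[:2], ignore_index, dtype=np.int64)
    for class_idx, color in enumerate(palette.values()):
        # True where all three channels match this class color
        mask = np.all(label_rgb == np.array(color), axis=-1)
        indices[mask] = class_idx
    return indices


# A 2x2 RGB label: sky, building / road, a color missing from the palette
label = np.array([
    [(128, 128, 128), (128, 0, 0)],
    [(128, 64, 128), (1, 2, 3)],
], dtype=np.uint8)

index_map = rgb_label_to_indices(label, PALETTE)
print(index_map.shape)  # two dimensions: one class index per pixel
```

A batch of such maps stacks to [N, H, W], which is the 3D target shape the error message asks for.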

rashedkoutayni commented 2 years ago

By the way, to help others get started faster, I wrote an inference script that performs semantic segmentation (CamVid style) with the provided pretrained model and visualizes the result without loading any dataset:

import sys
from collections import OrderedDict

import cv2
import torch
import torchvision.transforms as transforms

# Modules from this repository
import transforms as ext_transforms
from models.enet import ENet
import utils

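# Use the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')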
model_path = 'save/ENet_CamVid/ENet'

class_encoding = OrderedDict([
    ('sky', (128, 128, 128)),
    ('building', (128, 0, 0)),
    ('pole', (192, 192, 128)),
    ('road', (128, 64, 128)),
    ('pavement', (60, 40, 222)),
    ('tree', (128, 128, 0)),
    ('sign_symbol', (192, 128, 128)),
    ('fence', (64, 64, 128)),
    ('car', (64, 0, 128)),
    ('pedestrian', (64, 64, 0)),
    ('bicyclist', (0, 128, 192)),
    ('unlabeled', (0, 0, 0))
]) 

label_to_rgb = transforms.Compose([
    ext_transforms.LongTensorToRGBPIL(class_encoding),
    transforms.ToTensor() 
])

# Run only if this module is being run directly
if __name__ == '__main__':
    model = ENet(12).to(device)
    model.eval()

    # Load the previously saved model state into the ENet model
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

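    in_image = cv2.imread(sys.argv[1])
    if in_image is None:
        sys.exit('Could not read image: ' + sys.argv[1])
    # OpenCV loads BGR in HWC layout; convert to RGB, NCHW, floats in [0, 1]
    in_image = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB)
    in_image = in_image.transpose(2, 0, 1)
    in_image = torch.from_numpy(in_image).unsqueeze(0)
    in_image = in_image.to(device).float() / 255.
    # Inference only, so skip gradient tracking
    with torch.no_grad():
        out_image = model(in_image)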

    # The network output has "num_classes" channels of per-class scores.
    # Take the index of the highest score along the channel dimension
    # to get a single class index per pixel.
    _, predictions = torch.max(out_image, 1)

    color_predictions = utils.batch_transform(predictions.cpu(), label_to_rgb)
    utils.imshow_batch(in_image.data.cpu(), color_predictions)

This code can be run simply by passing an input image as a parameter to the script: python3 run_inference.py PATH/TO/IMAGE.png

Note that you can change model_path to point to your own copy of the pretrained ENet model.
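
The prediction step in the script (taking the highest-scoring class per pixel, then mapping class indices to colors) can also be sketched without the repository helpers. This is an illustrative NumPy version under stated assumptions: scores_to_color and the three-entry COLORS table are made up for the example, and the real script relies on utils.batch_transform and LongTensorToRGBPIL instead.

```python
import numpy as np

# Illustrative three-class palette; class index i maps to COLORS[i]
COLORS = np.array([
    (128, 128, 128),  # sky
    (128, 0, 0),      # building
    (128, 64, 128),   # road
], dtype=np.uint8)


def scores_to_color(scores, colors):
    """Convert (num_classes, H, W) scores to a class map and an RGB image.

    Returns a tuple (class_map, rgb) with shapes (H, W) and (H, W, 3).
    """
    # Index of the highest-scoring class at each pixel
    class_map = np.argmax(scores, axis=0)
    # Integer-array indexing looks up one RGB triple per pixel
    return class_map, colors[class_map]


# Fake scores for a 1x2 image: pixel 0 favors class 2, pixel 1 favors class 0
scores = np.array([
    [[0.1, 0.9]],
    [[0.2, 0.05]],
    [[0.7, 0.05]],
])

class_map, rgb = scores_to_color(scores, COLORS)
print(class_map)
print(rgb.shape)
```

The same idea applies to the batched tensor in the script, where the class channel is dimension 1 rather than 0.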

davidtvs commented 2 years ago

Looks like a duplicate of #48. I tried to reproduce using the command you posted but it runs fine for me. I'm closing this issue but feel free to provide more details on #48 about your environment and how you got the data.

Regarding the script, you can make a pull request and I'll review it.