davidtvs / PyTorch-ENet

PyTorch implementation of ENet
MIT License

Problems while using different dataset #45

Closed: robsoncsantiago closed this issue 3 years ago

robsoncsantiago commented 4 years ago

Hello, I adapted the code into a .ipynb notebook so I could run cells in isolation and see how the pipeline works. I'm trying to evaluate ENet's performance at learning to segment underwater images (SUIM dataset), but I'm running into some problems.

When I run the training cell, it throws the following error: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4. Have you faced a similar issue before? Any help would be greatly appreciated!

Here is the code, with minor changes:

import torch.nn.functional as F
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
import torchvision
import torchvision.transforms as transforms
import transforms as ext_transforms
import torch.optim.lr_scheduler as lr_scheduler
import os
import numpy as np
import matplotlib.pyplot as plt
import glob
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader

# ENET PYTORCH GITHUB LIBS
import utils
import tools
from PIL import Image
from enet import ENet
from iou import IoU
from train import Train
from test import Test

# Configuring images size
std_size = 256

# Setting device for torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_transform = transforms.Compose([
    transforms.CenterCrop((std_size, std_size)),
    transforms.ToTensor()])

label_transform = transforms.Compose([
    transforms.CenterCrop((std_size, std_size)),
    ext_transforms.PILToLongTensor()])

root_dir = '/content/drive/My Drive/Colab Notebooks'
save_dir = '/content/drive/My Drive/Colab Notebooks/save'

# Training dataset root folders
train_folder = os.path.normpath(root_dir + '/train/images')
train_lbl_folder = os.path.normpath(root_dir + '/train/masks')

# Validation dataset root folders
val_folder = os.path.normpath(root_dir + '/val/images')
val_lbl_folder = os.path.normpath(root_dir + '/val/masks')

# Test dataset root folders
test_folder = os.path.normpath(root_dir + '/test/images')
test_lbl_folder = os.path.normpath(root_dir + '/test/masks')

class CustomDataset(Dataset):
    """Custom Dataset based on CamVid dataset found on:
      https://github.com/davidtvs/PyTorch-ENet.

    Student disclaimer: most parts of this code were used and adapted
    for academic purposes only, with no commercial intents. All rights
    reserved to original author. Please refer to the url cited above.

    Keyword arguments:
    - mode (``string``): The type of dataset: 'train' for training set, 'val'
    for validation set, and 'test' for test set. Image and label folders are
    taken from the module-level ``*_folder`` variables.
    - transform (``callable``, optional): A function/transform that takes in
    a PIL image and returns a transformed version. Default: None.
    - label_transform (``callable``, optional): A function/transform that takes
    in the target and transforms it. Default: None.
    - loader (``callable``, optional): A function that loads an image and its
    label given their paths. Default: ``tools.pil_loader``.

    """

    img_extension = '.jpg'
    label_extension = '.bmp'

    color_encoding = OrderedDict([
        ('Background', (0,0,0)),
        ('Human Divers', (0,0,255)),
        ('Aquatic Plants and Sea-Grass', (0,255,0)),
        ('Wrecks and Ruins', (0,255,255)),
        ('Robots', (255,0,0)),
        ('Reefs and Invertebrates', (255,0,255)),
        ('Fish and Vertebrates', (255,255,0)),
        ('Sea-Floor and Rocks', (255,255,255))
    ])

    def __init__(self, mode='train', transform=None,
                 label_transform=None, loader=tools.pil_loader):
        self.mode = mode
        self.transform = transform
        self.label_transform = label_transform
        self.loader = loader

        if self.mode.lower() == 'train':
            # Get the training data and labels filepaths
            self.train_data = tools.get_files(
                train_folder, extension_filter=self.img_extension)

            self.train_labels = tools.get_files(
                train_lbl_folder, extension_filter=self.label_extension)

        elif self.mode.lower() == 'val':
            # Get the validation data and labels filepaths
            self.val_data = tools.get_files(
                val_folder, extension_filter=self.img_extension)

            self.val_labels = tools.get_files(
                val_lbl_folder, extension_filter=self.label_extension)

        elif self.mode.lower() == 'test':
            # Get the test data and labels filepaths
            self.test_data = tools.get_files(
                test_folder, extension_filter=self.img_extension)

            self.test_labels = tools.get_files(
                test_lbl_folder, extension_filter=self.label_extension)

        else:
            raise RuntimeError("Unexpected dataset mode. "
                               "Supported modes are: train, val and test")

    def __getitem__(self, index):

        """
        Args:
        - index (``int``): index of the item in the dataset

        Returns:
        A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
        of the image.

        """
        if self.mode.lower() == 'train':
            data_path, label_path = self.train_data[index], self.train_labels[
                index]
        elif self.mode.lower() == 'val':
            data_path, label_path = self.val_data[index], self.val_labels[
                index]
        elif self.mode.lower() == 'test':
            data_path, label_path = self.test_data[index], self.test_labels[
                index]
        else:
            raise RuntimeError("Unexpected dataset mode. "
                               "Supported modes are: train, val and test")

        img, label = self.loader(data_path, label_path)

        if self.transform is not None:
            img = self.transform(img)

        if self.label_transform is not None:
            label = self.label_transform(label)

        return img, label

    def __len__(self):
        """Returns the length of the dataset."""
        if self.mode.lower() == 'train':
            return len(self.train_data)
        elif self.mode.lower() == 'val':
            return len(self.val_data)
        elif self.mode.lower() == 'test':
            return len(self.test_data)
        else:
            raise RuntimeError("Unexpected dataset mode. "
                               "Supported modes are: train, val and test")

# Setting Dataloader variables
mode = input('SELECT MODE OF OPERATION: train, val or test: ')
batch_size = 4
num_workers = 0

# Load the training set as tensors
train_set = CustomDataset(
    transform=image_transform,
    label_transform=label_transform)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers)

# Load the validation set as tensors
val_set = CustomDataset(
    mode='val',
    transform=image_transform,
    label_transform=label_transform)
val_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers)

# Load the test set as tensors
test_set = CustomDataset(
    mode='test',
    transform=image_transform,
    label_transform=label_transform)
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers)

# Retrieving color_encoding
class_encoding = train_set.color_encoding

# Get number of classes to predict
num_classes = len(class_encoding)

# Print information for debugging
print("Number of classes to predict:", num_classes)
print("Train dataset size:", len(train_set))
print("Validation dataset size:", len(val_set))

# Get class weights from the selected weighing technique
weighing = 'enet'
ignore_unlabeled = False

print("\nWeighing technique:", weighing)
print("Computing class weights...")
print("(this can take a while depending on the dataset size)")

class_weights = 0

if weighing.lower() == 'enet':
    class_weights = tools.enet_weighing(train_loader, num_classes)
elif weighing.lower() == 'mfb':
    class_weights = tools.median_freq_balancing(train_loader, num_classes)
else:
    class_weights = None

if class_weights is not None:
    class_weights = torch.from_numpy(class_weights).float().to(device)
    # Set the weight of the unlabeled class to 0
    if ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
        class_weights[ignore_index] = 0

print("Class weights:", class_weights)

# NOTE: this discards the weights computed above, so the loss below is unweighted
class_weights = None

learning_rate = 0.05
weight_decay = 0.1
lr_decay_epochs = 10
lr_decay = 0.1

# Initialize ENet
model = ENet(num_classes).to(device)
# Check if the network architecture is correct
# print(model)

# We use the CrossEntropyLoss loss function because it is the most
# frequently used criterion for classification problems with multiple
# classes, which fits this problem. It combines LogSoftmax and NLLLoss.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# ENet authors used Adam as the optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay)

# Learning rate decay scheduler
lr_updater = lr_scheduler.StepLR(optimizer, lr_decay_epochs,
                                  lr_decay)

# Evaluation metric
metric = IoU(num_classes, ignore_index=False)

# Optionally resume from a checkpoint
resume = True
resume = False
name = 'test'

if resume:
    model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
        model, optimizer, save_dir, name)
    print("Resuming from model: Start epoch = {0} "
          "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
else:
    start_epoch = 0
    best_miou = 0

epochs = 10

train = Train(model, train_loader, optimizer, criterion, metric, device)
val = Test(model, val_loader, criterion, metric, device)
for epoch in range(start_epoch, epochs):
    print(">>>> [Epoch: {0:d}] Training".format(epoch))

    epoch_loss, (iou, miou) = train.run_epoch(True)
    # Since PyTorch 1.1, the scheduler should step after the optimizer updates
    lr_updater.step()

    print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
          format(epoch, epoch_loss, miou))

    if (epoch + 1) % 10 == 0 or epoch + 1 == epochs:
        print(">>>> [Epoch: {0:d}] Validation".format(epoch))

        loss, (iou, miou) = val.run_epoch(True)

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, loss, miou))

        # Print per class IoU on last epoch or if best iou
        if epoch + 1 == epochs or miou > best_miou:
            for key, class_iou in zip(class_encoding.keys(), iou):
                print("{0}: {1:.4f}".format(key, class_iou))

        # Save the model if it's the best thus far
        if miou > best_miou:
            print("\nBest model thus far. Saving...\n")
            best_miou = miou
            #utils.save_checkpoint(model, optimizer, epoch + 1, best_miou)
davidtvs commented 3 years ago

I don't see your pil_loader function in the code above, but I'm assuming it's the same as the one in this repository, which simply reads the images with Image.open(). Since your ground-truth images are RGB, each label will have size (3, H, W) and, after adding the batch dimension, (N, 3, H, W).

The loss function CrossEntropyLoss expects a target tensor of size (N, H, W), where each pixel value is a class index in the interval [0, C-1] and C is the number of classes. That is the source of the error: your batch of target images doesn't have the correct number of dimensions.

The reason is that the dataset you are using encodes each class as an RGB color instead of a single value. This is not a problem for CamVid and Cityscapes because they happen to provide the target images in the desired format.
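The shape rule above can be checked with a minimal sketch using random tensors and toy sizes (N=2, C=8 to match the eight SUIM classes, 4x4 pixels). ENet's output has shape (N, C, H, W); CrossEntropyLoss accepts a long target of shape (N, H, W) but rejects an RGB-encoded target of shape (N, 3, H, W). The exact error text may differ between PyTorch versions.

```python
import torch
import torch.nn as nn

# Toy sizes: N images, C classes, H x W pixels.
N, C, H, W = 2, 8, 4, 4
criterion = nn.CrossEntropyLoss()
logits = torch.randn(N, C, H, W)  # network output: one score map per class

# Correct target: one class index in [0, C-1] per pixel, shape (N, H, W), dtype long.
good_target = torch.randint(0, C, (N, H, W))
loss = criterion(logits, good_target)
print(loss.shape)  # torch.Size([]) -> a scalar loss

# RGB-encoded target: shape (N, 3, H, W), the situation in the question.
rgb_target = torch.randint(0, 256, (N, 3, H, W))
try:
    criterion(logits, rgb_target)
except RuntimeError as exc:
    print("RuntimeError:", exc)
```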

TL;DR: in a new custom transform or in your CustomDataset.__getitem__(), you need to convert those RGB values into a single class index per pixel and make sure that it returns label tensors of size (H, W).
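One way to do that conversion, sketched with NumPy under the assumption that each class is identified by exactly one RGB color, as in the color_encoding OrderedDict above. rgb_to_class_index is a hypothetical helper, not part of the repository; rather than silently falling back to class 0, it raises on colors that are not in the encoding, which makes a mismatched palette easy to spot.

```python
from collections import OrderedDict

import numpy as np


def rgb_to_class_index(label_rgb, color_encoding):
    """Convert an (H, W, 3) RGB label image into an (H, W) array of class indices.

    Hypothetical helper. Assumes color_encoding maps class name -> (R, G, B),
    like CustomDataset.color_encoding, and that a class index is the position
    of its color in that dict.
    """
    label_rgb = np.asarray(label_rgb)
    if label_rgb.ndim != 3 or label_rgb.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) array, got shape {label_rgb.shape}")

    # Pack each RGB triple into one integer, R * 256**2 + G * 256 + B, so that
    # each color can be matched with a single comparison.
    rgb = label_rgb.astype(np.int64)
    packed = (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]

    # -1 marks pixels that no class has claimed yet.
    mask = np.full(packed.shape, -1, dtype=np.int64)
    for index, (r, g, b) in enumerate(color_encoding.values()):
        mask[packed == ((r << 16) | (g << 8) | b)] = index

    if (mask < 0).any():
        unknown = np.unique(label_rgb[mask < 0], axis=0)
        raise ValueError(f"label has colors not in color_encoding: {unknown.tolist()}")
    return mask


color_encoding = OrderedDict([
    ('Background', (0, 0, 0)),
    ('Human Divers', (0, 0, 255)),
    ('Aquatic Plants and Sea-Grass', (0, 255, 0)),
    ('Wrecks and Ruins', (0, 255, 255)),
    ('Robots', (255, 0, 0)),
    ('Reefs and Invertebrates', (255, 0, 255)),
    ('Fish and Vertebrates', (255, 255, 0)),
    ('Sea-Floor and Rocks', (255, 255, 255)),
])

label = np.zeros((2, 3, 3), dtype=np.uint8)  # 2x3 image, all Background
label[0, 0] = (0, 0, 255)                    # one Human Divers pixel
label[1, 2] = (255, 255, 255)                # one Sea-Floor and Rocks pixel
print(rgb_to_class_index(label, color_encoding))
# [[1 0 0]
#  [0 0 7]]
```

In a dataset, this could run on np.array(label) of the PIL label image, followed by torch.from_numpy(mask) to get the (H, W) LongTensor that CrossEntropyLoss expects. torch.from_numpy keeps the values as they are, whereas transforms.ToTensor() rescales uint8 arrays to [0, 1], which would turn small class indices into zeros after .long().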

robsoncsantiago commented 3 years ago

Got it, I'll add a custom transform that does this and post the results here. Thanks for the support!

robsoncsantiago commented 3 years ago

I've changed CustomDataset.__getitem__() as follows:

    def __getitem__(self, index):

        """
        Args:
        - index (``int``): index of the item in the dataset

        Returns:
        A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
        of the image.

        """
        if self.mode.lower() == 'train':
            data_path, label_path = self.train_data[index], self.train_labels[
                index]
        elif self.mode.lower() == 'val':
            data_path, label_path = self.val_data[index], self.val_labels[
                index]
        elif self.mode.lower() == 'test':
            data_path, label_path = self.test_data[index], self.test_labels[
                index]
        else:
            raise RuntimeError("Unexpected dataset mode. "
                               "Supported modes are: train, val and test")

        img, label = self.loader(data_path, label_path)

        if self.transform is not None:
            img = self.transform(img)

        if self.label_transform is not None:
            label = self.label_transform(label)
            label = tools.rgb2mask(np.array(label), self.color_encoding)
            label = transforms.ToTensor()(label).long()
            label = label.squeeze(0)

        return img, label

With rgb2mask as:

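def rgb2mask(img, color2index):
    """Convert an (H, W, 3) RGB label into an (H, W) mask of class indices.

    color2index must map (R, G, B) tuples to class indices.
    """
    assert len(img.shape) == 3
    height, width, ch = img.shape
    assert ch == 3

    # Pack each RGB triple into a single integer: R + G * 256 + B * 256**2
    W = np.power(256, [[0], [1], [2]])

    img_id = img.dot(W).squeeze(-1)
    values = np.unique(img_id)

    mask = np.zeros(img_id.shape)

    for i, c in enumerate(values):
        try:
            # Look up the class of the first pixel that has this packed color
            mask[img_id == c] = color2index[tuple(img[img_id == c][0])]
        except KeyError:
            # Colors missing from color2index are silently left as class 0
            pass
    return mask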
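Two things in this rgb2mask are worth checking. First, it looks colors up as color2index[tuple(rgb)], so it needs a color-to-index dict, whereas CustomDataset.color_encoding maps class name to color. If that OrderedDict is what gets passed in, every lookup fails, the try/except around it skips the pixel, and the mask stays all zeros (all Background), yet training could still appear to progress. Second, it asserts a channels-last (H, W, 3) input, while ext_transforms.PILToLongTensor() at the end of label_transform presumably produces channels-first data. A small, hypothetical debugging helper like the one below, assuming the name-to-color encoding, can confirm that a label actually contains the expected classes.

```python
from collections import OrderedDict

import numpy as np


def label_color_report(label_rgb, color_encoding):
    """Count pixels per class in an (H, W, 3) RGB label and collect unknown colors.

    Hypothetical debugging helper. Assumes color_encoding maps
    class name -> (R, G, B), like CustomDataset.color_encoding.
    """
    # Invert the encoding so it can be looked up by color: (R, G, B) -> name.
    color_to_name = {tuple(rgb): name for name, rgb in color_encoding.items()}
    pixels = np.asarray(label_rgb).reshape(-1, 3)
    colors, counts = np.unique(pixels, axis=0, return_counts=True)

    per_class = {name: 0 for name in color_encoding}
    unknown = {}
    for rgb, n in zip(colors.tolist(), counts.tolist()):
        name = color_to_name.get(tuple(rgb))
        if name is None:
            unknown[tuple(rgb)] = n  # a color no class claims
        else:
            per_class[name] = n
    return per_class, unknown


# Toy two-class encoding in the same name -> color direction as CustomDataset.
color_encoding = OrderedDict([('Background', (0, 0, 0)), ('Robots', (255, 0, 0))])
label = np.zeros((2, 2, 3), dtype=np.uint8)
label[0, 0] = (255, 0, 0)
label[1, 1] = (10, 10, 10)

per_class, unknown = label_color_report(label, color_encoding)
print(per_class)  # {'Background': 2, 'Robots': 1}
print(unknown)    # {(10, 10, 10): 1}

# The keys are class names, so a lookup keyed by color misses:
print((255, 0, 0) in color_encoding)  # False
```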

So it now appears to train, with the iteration loss decreasing over time, but when it comes to computing the IoU, it throws the following error:

[Screenshot of the error]

Any thoughts on that? I've been trying to debug it, without success so far. Again, any help would be greatly appreciated!

davidtvs commented 3 years ago

IoU(num_classes, ignore_index=False) is the problem. You are passing False to an argument that's expected to be an int or an iterable. If you don't want to ignore any class, simply use IoU(num_classes): ignore_index defaults to None, in which case no class is ignored.
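A hedged sketch of why False in particular can misbehave: in Python, bool is a subclass of int, so a check like isinstance(ignore_index, int) accepts False and may turn it into the tuple (False,). Assuming the metric then zeroes the ignored rows and columns of a NumPy confusion matrix by indexing with that tuple, NumPy reads (False,) as a length-1 boolean mask rather than class index 0, and raises IndexError because its length does not match the number of classes. normalize_ignore_index and zero_ignored below are hypothetical helpers, not the repository's IoU; they simply reject bools explicitly.

```python
import numpy as np


def normalize_ignore_index(ignore_index):
    """Hypothetical validation: None, an int, or an iterable of ints; never a bool."""
    if ignore_index is None:
        return None
    # bool must be checked before int, because isinstance(False, int) is True.
    if isinstance(ignore_index, bool):
        raise TypeError("ignore_index must be None, an int or an iterable of ints, not a bool")
    if isinstance(ignore_index, int):
        return (ignore_index,)
    return tuple(ignore_index)


def zero_ignored(conf_matrix, ignore_index):
    """Return a copy of conf_matrix with the ignored classes' rows and columns set to 0."""
    conf = conf_matrix.copy()
    idx = normalize_ignore_index(ignore_index)
    if idx is not None:
        idx = list(idx)  # a list of ints means integer (fancy) indexing
        conf[:, idx] = 0
        conf[idx, :] = 0
    return conf


conf = np.ones((8, 8), dtype=np.int64)  # toy 8-class confusion matrix

# What happens if False slips through as the tuple (False,):
try:
    conf.copy()[:, (False,)] = 0
except IndexError as exc:
    print("IndexError:", exc)

print(zero_ignored(conf, None).sum())  # 64: nothing ignored
print(zero_ignored(conf, 0).sum())     # 49: row 0 and column 0 zeroed
```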

robsoncsantiago commented 3 years ago

Gosh, stupid mistake... Now everything seems to be running smoothly. I'll make sure to post any interesting outcomes from this work here.

Thanks for the support!