milesial / Pytorch-UNet

PyTorch implementation of the U-Net for image semantic segmentation with high quality images
GNU General Public License v3.0
9.32k stars 2.52k forks source link

Using per-class IoU to evaluate a trained model #465

Open MjdMahasneh opened 1 year ago

MjdMahasneh commented 1 year ago

First, let me thank you for the amazing repo — many thanks :)

In my project, I needed to evaluate using IoU (for consistency) and to get the per-class scores. Here is my `evaluate_using_IoU.py` (I have tested it, and to the best of my knowledge it works as expected):

import logging
import os
from pathlib import Path

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

from evaluate import evaluate
from unet import UNet
from utils.data_loading import BasicDataset

def iou_score(output, target):
    """Compute the intersection-over-union (Jaccard index) of two binary masks.

    Args:
        output: Prediction. A torch tensor is treated as logits: it is passed
            through a sigmoid and thresholded at 0.5. Anything else is
            converted with ``np.asarray`` and thresholded at 0.5 directly.
        target: Ground-truth mask (torch tensor or array-like). Values are
            truncated to integers and every non-zero label counts as
            foreground.

    Returns:
        float: ``(|P & T| + smooth) / (|P | T| + smooth)`` where P and T are
        the binarized prediction and target. When both masks are empty the
        result is exactly 1.0.
    """
    # Small constant added to numerator and denominator so that an empty
    # prediction paired with an empty target yields 1.0 instead of 0/0.
    smooth = 1e-5

    if torch.is_tensor(output):
        # Binary {0, 1} masks also survive this path unchanged:
        # sigmoid(0) == 0.5 is not > 0.5, while sigmoid(1) ~= 0.73 is.
        output = torch.sigmoid(output).detach().cpu().numpy()
    output = np.asarray(output) > 0.5

    if torch.is_tensor(target):
        target = target.detach().cpu().numpy()
    # Truncate to int (as the tensor path always did), then binarize so the
    # counts below are pixel counts rather than sums of bitwise label values
    # (e.g. 1 | 2 == 3 would otherwise inflate the union).
    target = np.asarray(target).astype(int) != 0

    intersection = np.count_nonzero(output & target)
    union = np.count_nonzero(output | target)

    return (intersection + smooth) / (union + smooth)

def _classwise_batch_iou(mask_pred, mask_true, n_classes, smooth=1e-5):
    """Return a list with one IoU per scored class for a single batch.

    Uses the same formula and smoothing as ``iou_score``, but compares label
    maps directly on the tensors' device. IoU is pooled over every pixel of
    every image in the batch.
    """
    if n_classes == 1:
        # A single-logit model: argmax over one channel would always be 0, so
        # predict foreground where sigmoid(logit) > 0.5 and score label 1 only.
        # Assumes the ground-truth mask uses {0, 1} -- TODO confirm.
        pred_labels = (torch.sigmoid(mask_pred) > 0.5).squeeze(1).long()
        classes = [1]
    else:
        pred_labels = mask_pred.argmax(dim=1)  # computed once, not per class
        classes = range(n_classes)

    scores = []
    for cls in classes:
        pred_cls = pred_labels == cls
        true_cls = mask_true == cls
        intersection = (pred_cls & true_cls).sum().item()
        union = (pred_cls | true_cls).sum().item()
        scores.append((intersection + smooth) / (union + smooth))
    return scores


@torch.inference_mode()
def evaluate_with_iou(net, dataloader, device, amp):
    """Compute the mean and per-class IoU of a segmentation network.

    For every batch the predicted label map is compared with the ground-truth
    mask class by class. The per-class IoUs are averaged to give the batch
    score, and batch scores are then averaged over all batches.

    Args:
        net: Model exposing ``n_classes`` and ``training``; called as
            ``net(image)``. Its output is presumably per-class logits shaped
            (N, n_classes, H, W) -- TODO confirm.
        dataloader: Yields dicts with ``'image'`` and ``'mask'`` entries.
        device: ``torch.device`` to run on.
        amp: Whether to enable ``torch.autocast`` mixed precision.

    Returns:
        tuple: ``(mean_iou, classwise_iou)``. ``classwise_iou`` is a list of
        length ``net.n_classes`` holding each class's IoU averaged over
        batches. For ``n_classes == 1`` the single entry is the foreground IoU.

    Note:
        A class absent from both prediction and ground truth in a batch scores
        1.0 for that batch because of smoothing, which can inflate the scores
        of rare classes.
    """
    was_training = net.training
    net.eval()
    num_val_batches = len(dataloader)
    total_iou = 0.0
    # Running per-class sums over all batches.
    classwise_iou = [0.0] * net.n_classes

    try:
        with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
            for batch in tqdm(dataloader, total=num_val_batches, desc='IoU evaluation', unit='batch', leave=False):
                image = batch['image'].to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                mask_true = batch['mask'].to(device=device, dtype=torch.long)

                batch_scores = _classwise_batch_iou(net(image), mask_true, net.n_classes)

                total_iou += sum(batch_scores) / len(batch_scores)
                for idx, score in enumerate(batch_scores):
                    classwise_iou[idx] += score
    finally:
        # Restore the caller's train/eval mode instead of leaving the model in eval mode.
        net.train(was_training)

    denom = max(num_val_batches, 1)  # avoid division by zero on an empty loader
    return total_iou / denom, [iou / denom for iou in classwise_iou]

class Config:
    """Hard-coded settings for this IoU evaluation script.

    Edit the values below to match your dataset and checkpoint before running.

    Usage:
        args = Config()
        print(vars(args))
        print(args.batch_size)
    """

    def __init__(self):
        self.batch_size = 2
        self.bilinear = False
        self.classes = 3  # passed to UNet(n_classes=...); set to your checkpoint's class count
        self.target_size = (512, 512)  # (height, width)

        self.dir_root = Path('G:/Datasets')
        # The evaluation entry point only reads the val_* directories; the
        # train_* entries are kept for reference.
        self.train_images_dir = self.dir_root / 'train' / 'images'
        self.train_mask_dir = self.dir_root / 'train' / 'masks'
        self.val_images_dir = self.dir_root / 'val' / 'images'
        self.val_mask_dir = self.dir_root / 'val' / 'masks'

        # Checkpoint file passed to torch.load.
        self.model = './checkpoints/checkpoint_epoch5.pth'

if __name__ == '__main__':

    args = Config()
    print('args : ', vars(args))

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # Fail fast, before loading the checkpoint, if the validation data is missing.
    for directory in (args.val_images_dir, args.val_mask_dir):
        if directory is None or not Path(directory).is_dir():
            raise FileNotFoundError(f'Validation directory not found: {directory}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    logging.info(f'Loading model {args.model}')
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(args.model, map_location=device)
    # The checkpoint may carry a 'mask_values' entry next to the weights; it is
    # not a model parameter, so drop it (if present) before load_state_dict.
    state_dict.pop('mask_values', None)
    model.load_state_dict(state_dict)
    model.to(device=device)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    # Create datasets
    val_dataset = BasicDataset(args.val_images_dir, args.val_mask_dir, mask_suffix='', target_size=args.target_size, stage='val')
    n_val = len(val_dataset)
    loader_args = dict(
        batch_size=args.batch_size,
        num_workers=os.cpu_count() or 0,  # os.cpu_count() may return None
        pin_memory=device.type == 'cuda',  # pinned memory only helps host-to-GPU copies
    )

    # drop_last=False: every validation sample must be scored; dropping the
    # final partial batch would silently exclude up to batch_size - 1 images.
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_args)

    logging.info(f'''Starting IoU evaluation:
        Batch size:      {args.batch_size}
        Validation size: {n_val}
        Device:          {device.type}
        ''')

    ## uncomment this to evaluate with Dice score
    # val_score = evaluate(model, val_loader, device, amp=False)
    # logging.info('Validation Dice score: {}'.format(val_score))

    val_score, classwise_scores = evaluate_with_iou(model, val_loader, device, amp=False)
    logging.info('Validation IoU score: {}'.format(val_score))
    for i, cls_iou in enumerate(classwise_scores):
        logging.info(f'Class {i} IoU score: {cls_iou}')

To run it, just modify the `Config` class to match your dataset and checkpoint, then run the script.

You could also include it in your repo if you think it's useful.

Hope this helps.

It would be nice to get per-class Dice scores too, maybe at some point in the future.