skuley / DIS-Multi-Gpus-Training

DISNet training with multi-gpus
Apache License 2.0
7 stars 1 forks source link

Unequal augmentations within load_dataloader for 'disnet' training, in train_dis.py #5

Open Cmonsta6 opened 17 hours ago

Cmonsta6 commented 17 hours ago

I have been using your very helpful script to train GT_Encoder. It's my first time using Lightning, and it's really great.

I was about to start training Disnet and I noticed something.

    if args.train_type == 'disnet':
        from utils.isnet_dataset import Dataset
        from utils.augmentation import RandomBlur
        tr_ds = Dataset(image_path=args.tr_im_path, gt_path=args.tr_gt_path,
                        image_transform=image_transform,
                        gt_transform=mask_transform,
                        random_blur=None,
                        load_on_mem=args.load_data_on_mem)

image_transform and gt_transform (mask_transform) receive different augmentations:

def load_dataloader(args):    

    mask_transform = A.Compose([
        # A.Resize(width=args.input_size, height=args.input_size),
        A.RandomCrop(width=1024, height=1024),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.8),
        A.RandomRotate90(p=0.8)
    ])

    image_transform = A.Compose([
        A.CLAHE(p=0.8),
        A.RandomBrightnessContrast(p=0.8),
        A.RandomGamma(p=0.8)]
    )

This randomly crops and rotates/flips the masks but not the images, so each image and its mask end up spatially misaligned.

Cmonsta6 commented 17 hours ago

Also,

parser.add_argument('--input_size', type=int, default=1280)

Its value is never used in the augmentations: training images and masks are not resized at all, and the validation set is resized to a hard-coded 1024x1024 rather than to the --input_size value.


    vd_transform = A.Compose([
        A.Resize(width=1024, height=1024)
    ])
Cmonsta6 commented 16 hours ago

I don't understand how pull requests work functionally, nor do I think one would be appropriate given how many other modifications I have made beyond this.

The original code used bilinear for validation images and masks.

    vd_transform = A.Compose([
        A.Resize(width=1024, height=1024)
    ])

I switched to nearest-neighbor resampling for masks in both the DISNet and GT_Encoder trainers, to preserve hard edges. I switched to area resampling for RGB images. Area resampling does a slightly better job when the downsampling factor is large; otherwise it looks the same as bilinear.

With that noted, here is my solution to these issues:

class CustomResizeTransform:
    """Resize to a square ``input_size`` x ``input_size``, picking interpolation per target.

    RGB images go through ``cv2.INTER_AREA``; masks go through
    ``cv2.INTER_NEAREST`` so their label values are never blended at edges.

    Args:
        input_size: side length of the square output.
        is_mask: when True, the ``image`` argument is itself a mask (GT_Encoder
            case) and is resized with nearest-neighbour interpolation; any
            ``mask`` argument is then ignored.
    """

    def __init__(self, input_size, is_mask=False):
        # Both resizers are built once here; __call__ only chooses between them.
        self.image_resize = A.Resize(width=input_size, height=input_size, interpolation=cv2.INTER_AREA)
        self.mask_resize = A.Resize(width=input_size, height=input_size, interpolation=cv2.INTER_NEAREST)
        self.is_mask = is_mask

    def __call__(self, image, mask=None):
        """Return ``{'image': ...}``, plus ``'mask'`` when a mask is given and ``is_mask`` is False."""
        if self.is_mask:
            # GT_Encoder: the "image" is a mask, so use nearest-neighbour; ``mask`` is unused.
            return {'image': self.mask_resize(image=image)['image']}

        result = {'image': self.image_resize(image=image)['image']}
        if mask is not None:
            # The mask is fed under the 'image' key so this resizer's INTER_NEAREST applies to it.
            result['mask'] = self.mask_resize(image=mask)['image']
        return result

def load_dataloader(args):
    """Build the training and validation DataLoaders for ``args.train_type``.

    For ``'disnet'`` the datasets come from ``utils.isnet_dataset`` and yield
    image/mask pairs. For any other value the GT_Encoder dataset from
    ``utils.gt_dataset`` is used, where the "image" is itself a ground-truth mask.

    Args:
        args: parsed CLI namespace. Reads ``train_type``, ``input_size``,
            ``batch_size``, ``load_data_on_mem`` and the ``tr_im_path`` /
            ``tr_gt_path`` / ``vd_im_path`` / ``vd_gt_path`` directories.

    Returns:
        ``(tr_dl, vd_dl)``: shuffled training loader and unshuffled validation loader.
    """
    # NOTE(review): the reporter saw a tensor size mismatch in a decoder torch.cat for
    # some --input_size values; presumably input_size must be divisible by the
    # network's overall downsampling factor -- confirm against the model.
    size = args.input_size

    if args.train_type == 'disnet':
        # DISNet dataset (image + mask)
        from utils.isnet_dataset import Dataset

        # A single Compose over image AND mask: albumentations applies the spatial
        # transforms (resize/flip/rotate) to both with identical random parameters,
        # resizes the built-in 'mask' target with nearest-neighbour, and applies the
        # pixel-level transforms (CLAHE/brightness-contrast/gamma) to 'image' only.
        # Do NOT pass additional_targets={'mask': 'image'}: that re-routes the mask
        # through the image functions, so the pixel-level transforms and the image
        # interpolation would be applied to the mask as well.
        duo_transform = A.Compose([
            A.Resize(width=size, height=size, interpolation=cv2.INTER_AREA),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.8),
            A.RandomRotate90(p=0.8),
            A.CLAHE(p=0.8),
            A.RandomBrightnessContrast(p=0.8),
            A.RandomGamma(p=0.8),
        ])

        tr_ds = Dataset(image_path=args.tr_im_path, gt_path=args.tr_gt_path,
                        transform=duo_transform,
                        random_blur=None,
                        load_on_mem=args.load_data_on_mem)
        # Validation: resize only, INTER_AREA for the image and INTER_NEAREST for the mask.
        vd_ds = Dataset(image_path=args.vd_im_path, gt_path=args.vd_gt_path,
                        transform=CustomResizeTransform(input_size=size),
                        load_on_mem=args.load_data_on_mem)
    else:
        # GT_Encoder dataset (the 'image' is a mask)
        from utils.gt_dataset import Dataset

        # Training augmentations for masks; nearest-neighbour keeps hard edges.
        mask_transform = A.Compose([
            A.Resize(width=size, height=size, interpolation=cv2.INTER_NEAREST),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.8),
            A.RandomRotate90(p=0.8),
        ])
        tr_ds = Dataset(image_path=args.tr_gt_path, transform=mask_transform)

        # Validation: nearest-neighbour resize only, since the 'images' are masks.
        vd_ds = Dataset(image_path=args.vd_gt_path,
                        transform=CustomResizeTransform(input_size=size, is_mask=True))

    tr_dl = DataLoader(tr_ds, args.batch_size, shuffle=True, num_workers=8)
    vd_dl = DataLoader(vd_ds, args.batch_size, shuffle=False, num_workers=4)

    return tr_dl, vd_dl

My edit also requires changes to utils/isnet_dataset.py:

import os
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from glob import glob
import cv2
import numpy as np
import albumentations as A
from tqdm import tqdm
from torchvision import transforms as T

class Dataset(Dataset):
    """Image / ground-truth-mask pairs for DISNet training and validation.

    Images are globbed as ``*.jpg`` under ``image_path`` and masks as ``*.png``
    under ``gt_path``. Both lists are sorted independently and paired by
    position, so the two directories must hold the same number of files in
    matching sort order.

    Each item is ``{'image': ..., 'gt': ...}``. Both values are float32 tensors
    produced by ``torchvision.transforms.ToTensor``. The image is scaled by
    1/255; the mask is binarised to 0/1.

    Args:
        image_path: directory holding the ``*.jpg`` input images.
        gt_path: directory holding the ``*.png`` masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``.
            It must return a dict with ``'image'`` and ``'mask'`` keys
            (e.g. an ``albumentations.Compose``).
        load_on_mem: if True, decode every pair once up front and keep it in memory.
        random_blur: optional factory. ``random_blur()`` must return a callable
            used as ``(image=...) -> {'image': ...}``; it is applied to the image only.

    Raises:
        ValueError: if the number of images and masks differ.
    """

    def __init__(self, image_path='../../data/DIS5K/DIS-TR/im', gt_path='../../data/DIS5K/DIS-TR/gt',
                 transform=None, load_on_mem=False, random_blur=None):
        self.images = sorted(glob(os.path.join(image_path, '*.jpg')))
        self.gts = sorted(glob(os.path.join(gt_path, '*.png')))

        self.transform = transform  # Unified transformation for both image and mask
        self.random_blur = random_blur

        print(f'images : {len(self.images)}')
        print(f'gts : {len(self.gts)}')

        # Pairing is by sorted position. With unequal counts, pairs can be silently
        # misaligned and __getitem__ can raise IndexError mid-epoch, so fail fast.
        if len(self.images) != len(self.gts):
            raise ValueError(
                f'found {len(self.images)} images in {image_path!r} but '
                f'{len(self.gts)} masks in {gt_path!r}; images and masks are '
                'paired by sorted order and must match one-to-one')

        self.load_on_mem = load_on_mem
        if self.load_on_mem:
            self.load_data()

    def __len__(self):
        return len(self.gts)

    @staticmethod
    def _read_pair(image_file, gt_file):
        """Read one image (cv2 default colour read) and its mask (grayscale).

        Raises:
            FileNotFoundError: if either file cannot be decoded. ``cv2.imread``
                signals this by returning None, not by raising.
        """
        image = cv2.imread(image_file)
        gt = cv2.imread(gt_file, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f'could not read image: {image_file}')
        if gt is None:
            raise FileNotFoundError(f'could not read mask: {gt_file}')
        return image, gt

    def load_data(self):
        """Decode every image/mask pair into ``self.im_lst`` / ``self.gt_lst``."""
        self.im_lst = []
        self.gt_lst = []
        for im_file, gt_file in tqdm(zip(self.images, self.gts), total=len(self)):
            image, gt = self._read_pair(im_file, gt_file)
            self.im_lst.append(image)
            self.gt_lst.append(gt)

    def _transform(self, image, gt):
        """Augment, binarise and tensorise one pair; returns ``(image, gt)`` tensors."""
        # Joint transform so spatial ops (resize/flip/rotate) stay aligned on image and mask.
        if self.transform is not None:
            transformed = self.transform(image=image, mask=gt)
            image, gt = transformed['image'], transformed['mask']

        # Optional blur, image only.
        if self.random_blur is not None:
            image = self.random_blur()(image=image)['image']

        # Binarise the mask: values above 128 become 1.0, everything else 0.0.
        gt = (gt > 128).astype(np.float32)

        # Scale pixel values by 1/255 (to [0, 1] for uint8 input).
        image = (image / 255.0).astype(np.float32)

        # ToTensor does not rescale float arrays; it only moves channels first
        # (and adds a channel axis to the 2-D mask).
        to_tensor = transforms.ToTensor()
        return to_tensor(image), to_tensor(gt)

    def __getitem__(self, idx):
        if self.load_on_mem:
            image, gt = self.im_lst[idx], self.gt_lst[idx]
        else:
            image, gt = self._read_pair(self.images[idx], self.gts[idx])

        image, gt = self._transform(image, gt)

        return {'image': image, 'gt': gt}

If that doesn't work, my humble apologies, for I am but a humble glue eater slapping code together. Also, at least for me, --input_size has to be a power of 2; otherwise there is a dimension mismatch:

x = torch.cat([out, skip_x], dim=1)  # dim 1 is the channel dimension

RuntimeError: Sizes of tensors must match except in dimension 1.