Open Cmonsta6 opened 17 hours ago
Also,
parser.add_argument('--input_size', type=int, default=1280)
The argument value is not used in the augmentations. Images and masks during training are not resized. Validation set is resized, but not to the --input_size value.
vd_transform = A.Compose([
A.Resize(width=1024, height=1024)
])
I don't understand how pull requests work, functionally, nor do I think it would be appropriate given the amount of mods I have done beyond this.
The original code used bilinear for validation images and masks.
vd_transform = A.Compose([
A.Resize(width=1024, height=1024)
])
I switched to nearest-neighbor resampling for masks in both the disnet and gt_encoder trainers to preserve the hard edges, and to area resampling for RGB images. Area resampling does a slightly better job when the downsampling step is large; when it is not, the result looks the same as bilinear.
That noted, here is my solution to the issues
class CustomResizeTransform:
    """Square resize that picks the interpolation per target.

    Pictures are resized with ``cv2.INTER_AREA`` and masks with
    ``cv2.INTER_NEAREST``. With ``is_mask=True`` the ``image`` argument is
    itself treated as a mask (GT_Encoder case) and any ``mask`` argument is
    ignored.

    Args:
        input_size: Target width and height in pixels.
        is_mask: Whether the ``image`` passed to ``__call__`` is a mask.
    """

    def __init__(self, input_size, is_mask=False):
        side = input_size
        self.is_mask = is_mask
        # Nearest-neighbor for masks, area resampling for pictures.
        self.mask_resize = A.Resize(width=side, height=side, interpolation=cv2.INTER_NEAREST)
        self.image_resize = A.Resize(width=side, height=side, interpolation=cv2.INTER_AREA)

    def __call__(self, image, mask=None):
        """Resize ``image`` (and ``mask`` when given and not in mask mode).

        Returns:
            A dict with key ``'image'``, plus ``'mask'`` when ``is_mask`` is
            false and a mask was supplied.
        """
        # In mask mode the primary input goes through the nearest-neighbor resizer.
        primary_resizer = self.mask_resize if self.is_mask else self.image_resize
        result = {'image': primary_resizer(image=image)['image']}

        # A separate mask is only handled in image+mask (DISNet) mode.
        if not self.is_mask and mask is not None:
            result['mask'] = self.mask_resize(image=mask)['image']
        return result
def load_dataloader(args):
    """Build the training and validation data loaders for the selected trainer.

    Args:
        args: Parsed command-line namespace. Always reads ``input_size``,
            ``train_type``, ``batch_size``, ``tr_gt_path`` and ``vd_gt_path``;
            for ``train_type == 'disnet'`` it also reads ``tr_im_path``,
            ``vd_im_path`` and ``load_data_on_mem``.

    Returns:
        A ``(tr_dl, vd_dl)`` tuple of ``DataLoader`` objects. The training
        loader shuffles and uses 8 workers; the validation loader does not
        shuffle and uses 4 workers.
    """
    size = args.input_size

    if args.train_type == 'disnet':
        # DISNet dataset (image + mask)
        from utils.isnet_dataset import Dataset

        # 'mask' is a built-in albumentations target, so it is deliberately NOT
        # remapped through additional_targets={'mask': 'image'}: that remap
        # makes the mask go through the image code path, which applies the
        # pixel-level transforms (CLAHE, brightness/contrast, gamma) to the
        # mask and resizes it with the image interpolation. With the built-in
        # target, spatial transforms are shared and masks keep hard edges.
        duo_transform = A.Compose([
            # Area resampling for pictures, matching the validation transform.
            A.Resize(width=size, height=size, interpolation=cv2.INTER_AREA),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.8),
            A.RandomRotate90(p=0.8),
            A.CLAHE(p=0.8),  # Pixel-level: image only
            A.RandomBrightnessContrast(p=0.8),  # Pixel-level: image only
            A.RandomGamma(p=0.8),  # Pixel-level: image only
        ])
        # Validation: resize only, with per-target interpolation.
        vd_transform = CustomResizeTransform(input_size=size)

        tr_ds = Dataset(image_path=args.tr_im_path, gt_path=args.tr_gt_path,
                        transform=duo_transform,
                        random_blur=None,
                        load_on_mem=args.load_data_on_mem)
        vd_ds = Dataset(image_path=args.vd_im_path, gt_path=args.vd_gt_path,
                        transform=vd_transform,
                        load_on_mem=args.load_data_on_mem)
    else:
        # GT_Encoder dataset (same image and gt, where 'image' is a mask)
        from utils.gt_dataset import Dataset

        # The "images" are masks here, so resize with nearest-neighbor to keep
        # hard edges; augmentations are applied to the training set only.
        mask_transform = A.Compose([
            A.Resize(width=size, height=size, interpolation=cv2.INTER_NEAREST),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.8),
            A.RandomRotate90(p=0.8),
        ])
        vd_transform = CustomResizeTransform(input_size=size, is_mask=True)

        tr_ds = Dataset(image_path=args.tr_gt_path, transform=mask_transform)
        vd_ds = Dataset(image_path=args.vd_gt_path, transform=vd_transform)

    tr_dl = DataLoader(tr_ds, args.batch_size, shuffle=True, num_workers=8)
    vd_dl = DataLoader(vd_ds, args.batch_size, shuffle=False, num_workers=4)
    return tr_dl, vd_dl
My edit requires a change to /utils/isnet_dataset.py too.
import os
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from glob import glob
import cv2
import numpy as np
import albumentations as A
from tqdm import tqdm
from torchvision import transforms as T
class Dataset(Dataset):
    """Paired image / ground-truth mask dataset read from two directories.

    Images are the ``*.jpg`` files under ``image_path`` and masks the
    ``*.png`` files under ``gt_path``; both lists are sorted and paired by
    position. Each item is a dict with tensors under ``'image'`` and ``'gt'``.

    Args:
        image_path: Directory containing the ``.jpg`` images.
        gt_path: Directory containing the ``.png`` masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` entries.
        load_on_mem: When true, every pair is decoded up front and kept in memory.
        random_blur: Optional factory; ``random_blur()`` must return a callable
            invoked as ``f(image=...)`` that returns a dict with an ``'image'`` entry.
    """

    def __init__(self, image_path='../../data/DIS5K/DIS-TR/im', gt_path='../../data/DIS5K/DIS-TR/gt',
                 transform=None, load_on_mem=False, random_blur=None):
        self.images = sorted(glob(os.path.join(image_path, '*.jpg')))
        self.gts = sorted(glob(os.path.join(gt_path, '*.png')))
        self.transform = transform  # Unified transformation for both image and mask
        self.random_blur = random_blur
        print(f'images : {len(self.images)}')
        print(f'gts : {len(self.gts)}')
        self.load_on_mem = load_on_mem
        if self.load_on_mem:
            self.load_data()

    def __len__(self):
        """Return the number of ground-truth masks found."""
        return len(self.gts)

    @staticmethod
    def _read_pair(im_file, gt_file):
        """Decode one image (cv2 default colour mode) and its mask (grayscale).

        Raises:
            OSError: If either file cannot be decoded. ``cv2.imread`` signals
                failure by returning ``None`` rather than raising, which
                otherwise surfaces later as an unrelated-looking error.
        """
        image = cv2.imread(im_file)
        if image is None:
            raise OSError(f'Could not read image file: {im_file}')
        gt = cv2.imread(gt_file, cv2.IMREAD_GRAYSCALE)
        if gt is None:
            raise OSError(f'Could not read mask file: {gt_file}')
        return image, gt

    def load_data(self):
        """Decode every image/mask pair into ``self.im_lst`` / ``self.gt_lst``."""
        self.im_lst = []
        self.gt_lst = []
        for im_file, gt_file in tqdm(zip(self.images, self.gts), total=len(self)):
            image, gt = self._read_pair(im_file, gt_file)
            self.im_lst.append(image)
            self.gt_lst.append(gt)

    def _transform(self, image, gt):
        """Apply the shared transform and optional blur, then convert to tensors.

        Returns:
            An ``(image, gt)`` tuple of tensors; the mask is binarised to 0/1
            and the image is scaled by 1/255.
        """
        # Apply spatial transformations (flip, rotate, resize) to both image and mask
        if self.transform:
            transformed = self.transform(image=image, mask=gt)
            image, gt = transformed['image'], transformed['mask']

        # Blur the image only; the mask is left untouched
        if self.random_blur:
            image = self.random_blur()(image=image)['image']

        # Threshold the mask to binary values 0 and 1
        gt = (gt > 128).astype(np.float32)
        # Scale image pixel values by 1/255
        image = (image / 255.0).astype(np.float32)

        # Convert both arrays to tensors with a single converter instance
        to_tensor = transforms.ToTensor()
        return to_tensor(image), to_tensor(gt)

    def __getitem__(self, idx):
        """Return ``{'image': tensor, 'gt': tensor}`` for the pair at ``idx``."""
        if self.load_on_mem:
            image, gt = self.im_lst[idx], self.gt_lst[idx]
        else:
            image, gt = self._read_pair(self.images[idx], self.gts[idx])

        # Apply spatial transformations (flip, rotate, resize) and optional blur
        image, gt = self._transform(image, gt)
        return {'image': image, 'gt': gt}
If that doesn't work, my humble apologies, for I am but a humble glue eater slapping code together. Oh, and for me at least, --input_size has to be a power of 2; otherwise there is a dimension mismatch:
x = torch.cat([out, skip_x], dim=1) # dim 1 is the channel dimension
RuntimeError: Sizes of tensors must match except in dimension 1.
I have been using your very helpful script to train GT_Encoder, first time using Lightning and it's really great.
I was about to start training Disnet and I noticed something.
image_transform and gt_transform are using different augmentations.
This would randomly crop and rotate/flip the masks, but not the images.