NVlabs / SegFormer

Official PyTorch implementation of SegFormer
https://arxiv.org/abs/2105.15203

Black masks prediction #156

Open MichaelSchroter opened 2 months ago

MichaelSchroter commented 2 months ago

Hi all, I used the code below to fine-tune a SegFormer model for image segmentation.


!pip install --upgrade transformers

from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import os
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
import pandas as pd
import cv2
import numpy as np
import albumentations as aug

import random
import shutil

def create_subset(source_img_dir, source_mask_dir, target_img_dir, target_mask_dir, num_samples):
    """
    Creates a subset of images and masks, ensuring corresponding pairs.

    Args:
        source_img_dir (str): Path to the source image directory.
        source_mask_dir (str): Path to the source mask directory.
        target_img_dir (str): Path to the target image directory.
        target_mask_dir (str): Path to the target mask directory.
        num_samples (int): Number of samples to include in the subset.
    """

    # Ensure target directories exist
    os.makedirs(target_img_dir, exist_ok=True)
    os.makedirs(target_mask_dir, exist_ok=True)

    # Get all image filenames
    all_images = os.listdir(source_img_dir)

    # Randomly sample image filenames
    sampled_images = random.sample(all_images, num_samples)

    for img_file in sampled_images:
        shutil.copy(os.path.join(source_img_dir, img_file), os.path.join(target_img_dir, img_file))
        shutil.copy(os.path.join(source_mask_dir, img_file), os.path.join(target_mask_dir, img_file))

# Define directories
source_img_dir = '/kaggle/input/das-data-nn/FAS_DataNN/YCrCb3/images'
source_mask_dir = '/kaggle/input/das-data-nn/FAS_DataNN/YCrCb3/masks'
target_img_dir = '/kaggle/working/subset/subsample/images'
target_mask_dir = '/kaggle/working/subset/subsample/masks'

# Create a subset with a specified number of samples
create_subset(source_img_dir, source_mask_dir, target_img_dir, target_mask_dir, num_samples=500)

source_img_dir = '/kaggle/input/das-data-nn/FAS_DataNN/patched3/images'
source_mask_dir = '/kaggle/input/das-data-nn/FAS_DataNN/patched3/masks'
target_img_dir = '/kaggle/working/subset/subsample/images'
target_mask_dir = '/kaggle/working/subset/subsample/masks'

# Create a subset with a specified number of samples
create_subset(source_img_dir, source_mask_dir, target_img_dir, target_mask_dir, num_samples=500)
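
As a quick sanity check on the copied subset (a minimal sketch, assuming the mask files share the image filenames, as in the copy step above):

# Sanity check: every copied image should have a mask with the same filename.
copied_images = set(os.listdir(target_img_dir))
copied_masks = set(os.listdir(target_mask_dir))
missing = copied_images - copied_masks
print(f"{len(copied_images)} images, {len(copied_masks)} masks, {len(missing)} images without a mask")
assert not missing, f"Missing masks for e.g.: {sorted(missing)[:5]}"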

# Constants
IMAGE_SIZE = 224

class CustomImageSegmentationDataset(Dataset):
    """Custom image segmentation dataset."""

    def __init__(self, image_dirs, mask_dirs=None, feature_extractor=None, transforms=None, train=True):
        """
        Args:
            image_dirs (list of string): List of directories containing images.
            mask_dirs (list of string or None): List of directories containing masks or None if not available.
            feature_extractor (SegFormerFeatureExtractor or None): Feature extractor for image and segmentation maps.
            transforms (callable or None): Transformations to apply to images and masks.
            train (bool): Whether to load training or validation images.
        """
        self.image_dirs = image_dirs
        self.mask_dirs = mask_dirs
        self.feature_extractor = feature_extractor
        self.transforms = transforms

        # Collect image and mask file paths
        self.images = []
        self.masks = []

        for img_dir in image_dirs:
            for root, _, files in os.walk(img_dir):
                for file in files:
                    self.images.append(os.path.join(root, file))

        if mask_dirs:
            for mask_dir in mask_dirs:
                for root, _, files in os.walk(mask_dir):
                    for file in files:
                        self.masks.append(os.path.join(root, file))

        # Sort both lists so images[idx] and masks[idx] refer to the same filename
        # (os.walk gives no ordering guarantee).
        self.images.sort()
        self.masks.sort()

        if mask_dirs:
            assert len(self.images) == len(self.masks), "Number of images and masks should be the same."

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = cv2.imread(self.images[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.mask_dirs:
            segmentation_map = cv2.imread(self.masks[idx])
            segmentation_map = cv2.cvtColor(segmentation_map, cv2.COLOR_BGR2GRAY)
        else:
            segmentation_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)  # Dummy mask

        if self.transforms:
            augmented = self.transforms(image=image, mask=segmentation_map)
            encoded_inputs = self.feature_extractor(augmented['image'], augmented['mask'], return_tensors="pt")
        else:
            encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        for k, v in encoded_inputs.items():
            encoded_inputs[k].squeeze_()  # remove batch dimension

        return encoded_inputs

# Augmentations and transformations
transform = aug.Compose([
    aug.Resize(IMAGE_SIZE, IMAGE_SIZE),
    aug.Flip(p=0.5)
])

image_dirs = [target_img_dir]
mask_dirs = [target_mask_dir]
feature_extractor = SegformerFeatureExtractor(size=IMAGE_SIZE, align=False, reduce_zero_label=False)

train_dataset = CustomImageSegmentationDataset(image_dirs=image_dirs, mask_dirs=mask_dirs, feature_extractor=feature_extractor, transforms=transform)
valid_dataset = CustomImageSegmentationDataset(image_dirs=image_dirs, mask_dirs=mask_dirs, feature_extractor=feature_extractor, transforms=transform, train=False)

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(valid_dataset))

from shapely.geometry import Polygon
import numpy as np
import torch

# Define the functions for calculating PQ
def getIOU(polygon1: Polygon, polygon2: Polygon):
    intersection = polygon1.intersection(polygon2).area
    union = polygon1.union(polygon2).area
    if union == 0:
        return 0
    return intersection / union

def compute_pq(gt_polygons: list, pred_polygons: list, iou_threshold=0.5):
    matched_instances = {}
    gt_matched = np.zeros(len(gt_polygons))
    pred_matched = np.zeros(len(pred_polygons))

    for gt_idx, gt_polygon in enumerate(gt_polygons):
        best_iou = iou_threshold
        best_pred_idx = None
        for pred_idx, pred_polygon in enumerate(pred_polygons):
            try:
                iou = getIOU(gt_polygon, pred_polygon)
            except Exception:
                iou = 0
                print('Invalid polygon -> IoU set to 0')

            if iou == 0:
                continue

            if iou > best_iou:
                best_iou = iou
                best_pred_idx = pred_idx
        if best_pred_idx is not None:
            matched_instances[(gt_idx, best_pred_idx)] = best_iou
            gt_matched[gt_idx] = 1
            pred_matched[best_pred_idx] = 1

    sq_sum = sum(matched_instances.values())
    num_matches = len(matched_instances)
    sq = sq_sum / num_matches if num_matches else 0
    rq = num_matches / ((len(gt_polygons) + len(pred_polygons))/2.0) if (gt_polygons or pred_polygons) else 0
    pq = sq * rq

    return pq, sq, rq
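
A small usage example for compute_pq with two toy boxes (illustration only, not data from my set):

from shapely.geometry import box

# Two ground-truth squares and two predictions; only the first pair overlaps well.
gt = [box(0, 0, 10, 10), box(20, 20, 30, 30)]
pred = [box(1, 1, 10, 10), box(50, 50, 60, 60)]
pq, sq, rq = compute_pq(gt, pred)
print(f"PQ={pq:.3f}, SQ={sq:.3f}, RQ={rq:.3f}")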

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=4)

from transformers import SegformerForSemanticSegmentation

# Load the MiT-b5 encoder pre-trained on ImageNet, with a fresh single-label segmentation head
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b5", num_labels=1)

optimizer = AdamW(model.parameters(), lr=0.00006)
scheduler = StepLR(optimizer, step_size=3, gamma=0.7)  # Adjust step_size and gamma as needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Save best model initialization
best_val_accuracy = 0.0
model_save_path = "best_model.pth"

print("Model Initialized!")

# Training loop

epochs = 10

for epoch in range(epochs):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    pbar = tqdm(train_dataloader)
    accuracies = []
    losses = []
    val_accuracies = []
    val_losses = []
    model.train()
    for idx, batch in enumerate(pbar):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        outputs = model(pixel_values=pixel_values, labels=labels)

        upsampled_logits = nn.functional.interpolate(outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
        predicted = (upsampled_logits > 0).float()  # Threshold for single-class segmentation

        accuracy = accuracy_score(labels.detach().cpu().numpy().flatten(), predicted.detach().cpu().numpy().flatten())
        loss = outputs.loss
        accuracies.append(accuracy)
        losses.append(loss.item())
        pbar.set_postfix({'Batch': idx, 'Pixel-wise accuracy': sum(accuracies) / len(accuracies), 'Loss': sum(losses) / len(losses)})

        loss.backward()
        optimizer.step()
    scheduler.step()  # Update learning rate

    # Validation
    model.eval()
    with torch.no_grad():
        for idx, batch in enumerate(valid_dataloader):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            upsampled_logits = nn.functional.interpolate(outputs.logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            predicted = (upsampled_logits > 0).float()  # Threshold for single-class segmentation

            accuracy = accuracy_score(labels.detach().cpu().numpy().flatten(), predicted.detach().cpu().numpy().flatten())
            val_loss = outputs.loss
            val_accuracies.append(accuracy)
            val_losses.append(val_loss.item())

    avg_val_accuracy = sum(val_accuracies) / len(val_accuracies)
    if avg_val_accuracy > best_val_accuracy:
        best_val_accuracy = avg_val_accuracy
        torch.save(model.state_dict(), model_save_path)
        print(f"Saved best model with accuracy: {best_val_accuracy}")

    print(f"Train Pixel-wise accuracy: {sum(accuracies) / len(accuracies)}"
          f" Train Loss: {sum(losses) / len(losses)}"
          f" Val Pixel-wise accuracy: {sum(val_accuracies) / len(val_accuracies)}"
          f" Val Loss: {sum(val_losses) / len(val_losses)}")

import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Load the image and force 3-channel RGB (the PNG may be grayscale or RGBA)
image = Image.open('/kaggle/input/224train/train224/images224/train_0_image_1120_0_0_1.png').convert('RGB')

# Transform the image to tensor and add a batch dimension
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to tensor
])

image_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Load the model and the saved weights
#model = SegformerWithDropout.from_pretrained("nvidia/mit-b5", ignore_mismatched_sizes=True,num_labels=1, reshape_last_stage=True)

#model.load_state_dict(torch.load('/kaggle/input/segformer-pth/best_model.pth'))

# Move model to device (GPU or CPU)
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model.to(device)
#model.eval()

# Move image to device
image_tensor = image_tensor.to(device)

# Perform inference
with torch.no_grad():
    outputs = model(pixel_values=image_tensor)

# Get the predicted mask (the logits are raw predictions before applying softmax/sigmoid)
predicted_mask = outputs.logits.argmax(dim=1).squeeze().cpu().numpy()

# Convert the original image to a NumPy array for visualization
image_np = np.array(image)

# Resize predicted mask to original image size (cast to uint8; cv2.resize does not accept int64)
predicted_mask_resized = cv2.resize(predicted_mask.astype(np.uint8), (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)

# Convert the predicted mask to an RGB image for overlay
predicted_mask_rgb = np.zeros_like(image_np)
predicted_mask_rgb[:, :, 1] = (predicted_mask_resized * 255).astype(np.uint8)  # Green mask

# Overlay the mask on the original image
overlayed_image = cv2.addWeighted(image_np, 0.7, predicted_mask_rgb, 0.3, 0)

# Display the original image, predicted mask, and overlayed image side by side
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.imshow(image_np)
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(predicted_mask_resized, cmap='gray')
plt.title("Predicted Mask")
plt.axis('off')

plt.subplot(1, 3, 3)
plt.imshow(overlayed_image)
plt.title("Overlayed Image")
plt.axis('off')

plt.show()
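
For reference, this is the kind of diagnostic output I look at afterwards (a minimal sketch printing the raw logits statistics and the values actually present in the predicted mask):

# Diagnostics: range of the raw logits and which values the predicted mask contains.
print("Logits shape:", outputs.logits.shape)
print("Logits min/max:", outputs.logits.min().item(), outputs.logits.max().item())
print("Unique values in predicted mask:", np.unique(predicted_mask))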

The model trains, but at inference it produces a completely black image as the predicted mask. The input images and masks are 224×224. Would anyone be able to help me with this, please?

Thanks & Best Regards AMJS