NVlabs / SegFormer

Official PyTorch implementation of SegFormer
https://arxiv.org/abs/2105.15203
Other
2.57k stars 356 forks source link

SegFormer for binary semantic segmentation problem #75

Open loucif01 opened 2 years ago

loucif01 commented 2 years ago

Hello, I would like to know whether SegFormer can be adapted to a binary semantic segmentation problem and, if so, how.

Buling-Knight commented 2 years ago

@loucif01 @tmbdev I'm sorry to bother you. Have you gotten an answer? I ran into the same problem when trying to adapt SegFormer to a binary semantic segmentation problem. When I train on my dataset, the loss value decreases from positive to negative. The log looks like this:

2022-09-11 20:44:17,896 - mmseg - INFO - Iter [800/20000] lr: 3.068e-05, eta: 0:43:05, time: 0.205, data_time: 0.143, memory: 526, decode.loss_seg: 0.0028, decode.acc_seg: 43.4279, loss: 0.0028
2022-09-11 20:44:20,957 - mmseg - INFO - Iter [850/20000] lr: 3.252e-05, eta: 0:41:35, time: 0.061, data_time: 0.001, memory: 526, decode.loss_seg: 0.0008, decode.acc_seg: 38.4912, loss: 0.0008
2022-09-11 20:44:31,272 - mmseg - INFO - Iter [900/20000] lr: 3.434e-05, eta: 0:42:49, time: 0.206, data_time: 0.141, memory: 526, decode.loss_seg: 0.0019, decode.acc_seg: 40.7643, loss: 0.0019
2022-09-11 20:44:34,357 - mmseg - INFO - Iter [950/20000] lr: 3.616e-05, eta: 0:41:30, time: 0.062, data_time: 0.002, memory: 526, decode.loss_seg: 0.0023, decode.acc_seg: 41.2288, loss: 0.0023
2022-09-11 20:44:44,668 - mmseg - INFO - Exp name: mysegformer.b0.1024x1024.city.160k.py
2022-09-11 20:44:44,668 - mmseg - INFO - Iter [1000/20000] lr: 3.796e-05, eta: 0:42:35, time: 0.206, data_time: 0.141, memory: 526, decode.loss_seg: -0.0001, decode.acc_seg: 40.1294, loss: -0.0001
2022-09-11 20:44:47,849 - mmseg - INFO - Iter [1050/20000] lr: 3.976e-05, eta: 0:41:24, time: 0.064, data_time: 0.002, memory: 526, decode.loss_seg: 0.0005, decode.acc_seg: 40.7755, loss: 0.0005
2022-09-11 20:44:57,892 - mmseg - INFO - Iter [1100/20000] lr: 4.154e-05, eta: 0:42:17, time: 0.201, data_time: 0.139, memory: 526, decode.loss_seg: 0.0021, decode.acc_seg: 43.6932, loss: 0.0021
2022-09-11 20:45:00,815 - mmseg - INFO - Iter [1150/20000] lr: 4.332e-05, eta: 0:41:09, time: 0.058, data_time: 0.001, memory: 526, decode.loss_seg: 0.0017, decode.acc_seg: 42.2974, loss: 0.0017

Is this normal? Please advise!

chefkrym commented 1 year ago

Hey @loucif01 @Buling-Knight, could you be so kind as to share your code for binary segmentation, please? Thank you.

mohdsaqibxa commented 1 year ago

I am seeing the same problem: the loss decreases from positive to negative.

mohdsaqibxa commented 1 year ago

@chefkrym Here is the complete code:

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from datasets import load_metric
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import numpy as np
import random

from tqdm import tqdm

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import warnings
warnings.filterwarnings("ignore")

class SemanticSegmentationDataset(Dataset):
    """Binary (background / object) semantic segmentation dataset.

    ``root_dir`` must contain an ``images`` sub-directory and a sub-directory
    named after ``damage_type``; the mask of every image is looked up in the
    latter under the same file name as the image.
    """

    def __init__(self, root_dir, damage_type, feature_extractor, max_samples=150, mask_threshold=0):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            damage_type (string): Name of the sub-directory of ``root_dir`` that holds the masks.
            feature_extractor (SegFormerFeatureExtractor): feature extractor to prepare images + segmentation maps.
            max_samples (int or None): Keep at most this many files (taken in sorted
                file-name order). ``None`` keeps every file.
            mask_threshold (int): Grayscale mask pixels strictly above this value are
                labelled as class 1 (object); all other pixels are class 0 (background).
        """
        self.root_dir = root_dir
        self.damage_type = damage_type
        self.feature_extractor = feature_extractor
        self.mask_threshold = mask_threshold

        # A binary problem needs two explicit classes. With a single entry the
        # model built from this mapping has only one output channel, so an
        # argmax over the class dimension can never predict anything but 0.
        self.id2label = {0: "background", 1: "object"}

        image_dir = os.path.join(self.root_dir, "images")
        mask_dir = os.path.join(self.root_dir, self.damage_type)

        # os.listdir() returns entries in arbitrary order; sort first so the
        # selected subset is the same on every run and every machine.
        image_names = sorted(os.listdir(image_dir))[:max_samples]

        self.images = [os.path.join(image_dir, name) for name in image_names]
        self.masks = [os.path.join(mask_dir, name) for name in image_names]

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return the feature extractor's encoding of it.

        The returned mapping is whatever ``feature_extractor`` produces, with
        the leading batch dimension removed from every tensor.
        """
        # Context managers close the underlying files deterministically.
        with Image.open(self.images[idx]) as image_file:
            # Force three channels so grayscale / RGBA files do not change the
            # channel count of the model input.
            image = image_file.convert("RGB")
        with Image.open(self.masks[idx]) as mask_file:
            gray_mask = np.array(mask_file.convert("L"))

        # Map raw grayscale values (0..255) onto class ids {0, 1}.
        # NOTE(review): raw values such as 255 are presumably what made the loss
        # go negative / be ignored downstream - confirm against the mask files
        # and the installed transformers version.
        class_mask = Image.fromarray((gray_mask > self.mask_threshold).astype(np.uint8))

        # The feature extractor prepares image and mask according to its own
        # configuration and returns batched tensors.
        encoded_inputs = self.feature_extractor(image, class_mask, return_tensors="pt")

        # Drop only the leading batch dimension (a bare squeeze_() would also
        # drop any other dimension that happens to have size 1).
        for tensor in encoded_inputs.values():
            tensor.squeeze_(0)

        return encoded_inputs

class SegformerFinetuner(pl.LightningModule):
    """LightningModule that fine-tunes a pretrained SegFormer checkpoint.

    The segmentation head is re-initialised for ``len(id2label)`` classes and
    a ``mean_iou`` metric is tracked separately for train / val / test.
    """

    def __init__(self, id2label, train_dataloader=None, val_dataloader=None, test_dataloader=None, metrics_interval=100):
        """
        Args:
            id2label (dict): Mapping from class id to class name.
            train_dataloader: Loader returned by ``train_dataloader()``.
            val_dataloader: Loader returned by ``val_dataloader()``.
            test_dataloader: Loader returned by ``test_dataloader()``.
            metrics_interval (int): Training metrics are computed and logged on
                every batch whose index is a multiple of this value.
        """
        super().__init__()
        self.id2label = id2label
        self.metrics_interval = metrics_interval
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.test_dl = test_dataloader

        self.num_classes = len(id2label)
        self.label2id = {v: k for k, v in self.id2label.items()}

        # A single-logit (binary) head still predicts two labels {0, 1}, so the
        # metric must always be evaluated over at least two labels.
        self.num_metric_labels = max(self.num_classes, 2)

        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b0-finetuned-ade-512-512",
            return_dict=False,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            # The checkpoint's head has a different class count; allow re-init.
            ignore_mismatched_sizes=True,
        )

        self.train_mean_iou = load_metric("mean_iou")
        self.val_mean_iou = load_metric("mean_iou")
        self.test_mean_iou = load_metric("mean_iou")

    def forward(self, images, masks):
        """Run the wrapped model; labels are passed so the model also returns its loss."""
        outputs = self.model(pixel_values=images, labels=masks)
        return outputs

    @staticmethod
    def _logits_to_labels(logits):
        """Convert class logits shaped (N, C, H, W) into integer label maps (N, H, W)."""
        if logits.shape[1] == 1:
            # With one channel argmax is always 0; threshold the logit instead
            # (logit > 0 is equivalent to sigmoid(logit) > 0.5).
            return (logits.squeeze(1) > 0).long()
        return logits.argmax(dim=1)

    def _shared_step(self, batch, metric):
        """Forward one batch, add its predictions to ``metric`` and return the loss."""
        images, masks = batch['pixel_values'], batch['labels']

        outputs = self(images, masks)
        loss, logits = outputs[0], outputs[1]

        # Bring the logits to the mask resolution before comparing them.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False
        )
        predicted = self._logits_to_labels(upsampled_logits)

        metric.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy()
        )
        return loss

    def _compute_metric(self, metric):
        """Compute the accumulated ``mean_iou`` metric, ignoring label 255."""
        return metric.compute(
            num_labels=self.num_metric_labels,
            ignore_index=255,
            reduce_labels=False,
        )

    def _epoch_end(self, outputs, metric, prefix):
        """Average the per-step ``<prefix>_loss`` values, log and return epoch metrics."""
        scores = self._compute_metric(metric)
        loss_key = prefix + "_loss"
        avg_loss = torch.stack([x[loss_key] for x in outputs]).mean()

        metrics = {
            loss_key: avg_loss,
            prefix + "_mean_iou": scores["mean_iou"],
            prefix + "_mean_accuracy": scores["mean_accuracy"],
        }
        for k, v in metrics.items():
            self.log(k, v)

        return metrics

    def training_step(self, batch, batch_nb):
        """Return the loss; every ``metrics_interval`` batches also log and return IoU/accuracy."""
        loss = self._shared_step(batch, self.train_mean_iou)

        if batch_nb % self.metrics_interval != 0:
            return {'loss': loss}

        scores = self._compute_metric(self.train_mean_iou)
        metrics = {'loss': loss, "mean_iou": scores["mean_iou"], "mean_accuracy": scores["mean_accuracy"]}

        for k, v in metrics.items():
            self.log(k, v)

        return metrics

    def validation_step(self, batch, batch_nb):
        """Accumulate validation predictions and return the batch loss."""
        return {'val_loss': self._shared_step(batch, self.val_mean_iou)}

    def validation_epoch_end(self, outputs):
        """Log and return ``val_loss``, ``val_mean_iou`` and ``val_mean_accuracy``."""
        return self._epoch_end(outputs, self.val_mean_iou, "val")

    def test_step(self, batch, batch_nb):
        """Accumulate test predictions and return the batch loss."""
        return {'test_loss': self._shared_step(batch, self.test_mean_iou)}

    def test_epoch_end(self, outputs):
        """Log and return ``test_loss``, ``test_mean_iou`` and ``test_mean_accuracy``."""
        return self._epoch_end(outputs, self.test_mean_iou, "test")

    def configure_optimizers(self):
        """Optimise every trainable parameter with Adam."""
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)

    def train_dataloader(self):
        """Return the training loader supplied at construction time."""
        return self.train_dl

    def val_dataloader(self):
        """Return the validation loader supplied at construction time."""
        return self.val_dl

    def test_dataloader(self):
        """Return the test loader supplied at construction time."""
        return self.test_dl

# ---- Data locations ----
damage_type = "dents"
train_path = "cropped_data_7475_dents/train"
# Validation and test deliberately point at the same directory.
test_path = "cropped_data_7475_dents/test"
valid_path = test_path

# ---- Pre-processing ----
feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
feature_extractor.do_reduce_labels = False
feature_extractor.size = 128

# ---- Datasets (one per split, all sharing the same feature extractor) ----
train_dataset, val_dataset, test_dataset = (
    SemanticSegmentationDataset(split_path, damage_type, feature_extractor)
    for split_path in (train_path, valid_path, test_path)
)

# ---- Data loaders (only the training loader shuffles) ----
batch_size = 8
num_workers = 2
loader_options = {"batch_size": batch_size, "num_workers": num_workers}
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

# ---- Model ----
segformer_finetuner = SegformerFinetuner(
    id2label=train_dataset.id2label,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    metrics_interval=5,
)

# ---- Callbacks: both watch the logged "val_loss" ----
early_stop_callback = EarlyStopping(
    mode="min",
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
    verbose=False,
)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1)
callbacks = [early_stop_callback, checkpoint_callback]

# ---- Training ----
# val_check_interval is set to the number of batches in the training loader.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=10,
    val_check_interval=len(train_dataloader),
    callbacks=callbacks,
)
trainer.fit(segformer_finetuner)
aayushshah27894 commented 1 year ago

I am facing the same issue.