MjdMahasneh / Simple-PyTorch-Semantic-Segmentation-CNNs

PyTorch Implementation of Semantic Segmentation CNNs: This repository features key architectures like UNet, DeepLabv3+, SegNet, FCN, and PSPNet. It's crafted to provide a solid foundation for Semantic Segmentation tasks using PyTorch.

IoU score is not Improving #3

bilal6414 closed this issue 3 months ago

bilal6414 commented 3 months ago

I am training the model for boundary extraction from satellite imagery. I have a labelled dataset. I changed the config file to set the number of classes to 2, treating background as the first class and boundary as the second. But after 100 epochs I am not getting usable results: the overall IoU is 0.47, with 0.94 for class 1 and 0.0093 for class 2. Please suggest what I should try. I am using DeepLab.
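The overall IoU of 0.47 appears to be the mean of the two per-class scores ((0.94 + 0.0093) / 2 ≈ 0.47), so it hides the fact that the boundary class is almost never predicted correctly. A minimal NumPy sketch of per-class IoU, assuming integer label maps where 0 is background and 1 is boundary (`per_class_iou` is a hypothetical helper, not part of this repository):

```python
import numpy as np


def per_class_iou(pred, target, num_classes):
    """Return one IoU value per class for two integer label maps of the same shape."""
    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        intersection = np.logical_and(pred_c, target_c).sum()
        union = np.logical_or(pred_c, target_c).sum()
        # A class absent from both maps gets NaN so np.nanmean can skip it
        ious.append(float(intersection / union) if union > 0 else float("nan"))
    return ious


# Toy 4x4 example: class 1 is a thin vertical "boundary" line
target = np.zeros((4, 4), dtype=np.int64)
target[:, 1] = 1  # 4 boundary pixels
pred = np.zeros((4, 4), dtype=np.int64)
pred[0, 1] = 1  # the model recovers only 1 of them

ious = per_class_iou(pred, target, num_classes=2)
print(ious, np.nanmean(ious))
```

On this toy example the background IoU is 0.8 and the boundary IoU is 0.25, so the mean of 0.525 looks moderate even though three of the four boundary pixels are missed.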

MjdMahasneh commented 3 months ago

@bilal6414 Could you post an example showing what the images and labels look like? Is the label a mask of the boundaries alone (i.e., a hollow mask) or the more common filled mask (a solid mask)? I could probably help you better with more context.
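To make the hollow/solid distinction concrete, here is a NumPy-only sketch that turns a filled binary mask into a hollow mask of its one-pixel inner boundary. It assumes binary masks (0 = background, non-zero = object) and 4-connectivity; `solid_to_hollow` is a hypothetical helper, not something this repository provides.

```python
import numpy as np


def solid_to_hollow(mask):
    """Keep only the foreground pixels that touch the background (4-connectivity).

    Pixels outside the image are treated as background.
    """
    fg = mask.astype(bool)
    padded = np.pad(fg, 1, mode="constant", constant_values=False)
    # Binary erosion: a pixel survives only if it and its 4 neighbours are foreground
    eroded = (
        padded[1:-1, 1:-1]
        & padded[:-2, 1:-1]  # neighbour above
        & padded[2:, 1:-1]  # neighbour below
        & padded[1:-1, :-2]  # neighbour to the left
        & padded[1:-1, 2:]  # neighbour to the right
    )
    # Boundary = foreground minus its erosion
    return (fg & ~eroded).astype(np.uint8)


solid = np.zeros((5, 5), dtype=np.uint8)
solid[1:4, 1:4] = 1  # a filled 3x3 square
hollow = solid_to_hollow(solid)
print(hollow)
```

The filled 3x3 square becomes an 8-pixel ring with the centre removed, which is why a hollow-mask task has far fewer foreground pixels than a solid-mask one.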

In general, here are some things you could try:

Class Balance: Check whether your classes are balanced. If boundaries are underrepresented, augment your dataset to balance the class distribution. You can also assign a higher weight to the boundary class in the loss function to give it more importance.

Loss Function: Use a loss function that is robust to class imbalance or emphasizes boundaries, such as Dice Loss or Boundary Loss.

Hyperparameters: Tuning parameters such as the learning rate and the number of training epochs can also improve the boundary extraction performance of your model.
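Dice loss scores overlap per class and then averages over classes, so a rare class such as thin boundaries contributes as much to the loss as the background does. A minimal multi-class soft Dice sketch in PyTorch, assuming the model outputs logits of shape (N, C, H, W) and labels are (N, H, W) class indices; `SoftDiceLoss` is a hypothetical name, not an API of this repository. Boundary Loss needs precomputed distance maps of the ground truth, so it is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftDiceLoss(nn.Module):
    """Multi-class soft Dice loss on raw logits.

    logits: (N, C, H, W) float tensor; target: (N, H, W) long tensor of class indices.
    """

    def __init__(self, smooth=1.0, ignore_background=False):
        super().__init__()
        self.smooth = smooth
        self.ignore_background = ignore_background

    def forward(self, logits, target):
        num_classes = logits.shape[1]
        probs = F.softmax(logits, dim=1)
        # One-hot encode the labels to (N, C, H, W) so they line up with probs
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)  # sum over batch and spatial dims, keep one score per class
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        if self.ignore_background:
            dice = dice[1:]  # assumes class 0 is background
        return 1 - dice.mean()


# Usage on a dummy batch: Dice combined with plain cross-entropy
torch.manual_seed(0)
logits = torch.randn(2, 2, 8, 8, requires_grad=True)
target = torch.randint(0, 2, (2, 8, 8))
ce = nn.CrossEntropyLoss()
dice = SoftDiceLoss()
loss = ce(logits, target) + dice(logits, target)
loss.backward()
print(loss.item())
```

Summing Dice with cross-entropy, as in the usage lines, is a common pattern: Dice counters the imbalance, while cross-entropy keeps a per-pixel training signal.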

Hope this helps.

bilal6414 commented 3 months ago

Here is what a sample image and its label look like:

MjdMahasneh commented 3 months ago

Looking at the data, the first thing I would do is play with the class weights. There is a clear, severe class imbalance between the foreground (boundary) and background classes, so I suspect that upweighting the foreground will improve the network's performance. You can start by setting an arbitrary high value for the foreground weight, e.g., 10, see how that affects performance, and take it from there. A more principled way is to compute the class weights from the training subset as follows:

```python
import os
from glob import glob

import numpy as np
from PIL import Image


def compute_class_weight(mask_folder):
    """
    Compute the class weight for the foreground pixels in a dataset of binary masks.

    Args:
        mask_folder (str): Path to the folder containing binary mask images (*.png).

    Returns:
        float: Computed class weight for the foreground pixels.
    """
    # Initialize counts
    foreground_count = 0
    total_count = 0

    # Loop through all mask images in the specified folder
    for mask_path in glob(os.path.join(mask_folder, '*.png')):
        # Load the mask image as a numpy array
        mask = np.array(Image.open(mask_path))

        # Count foreground pixels: any non-zero value counts, so masks stored
        # as 0/1 or as 0/255 give the same result (np.sum would inflate 0/255 by 255x)
        foreground_count += np.count_nonzero(mask)

        # Count the total number of pixels in the mask
        total_count += mask.size

    if foreground_count == 0:
        raise ValueError(f"No foreground pixels found in {mask_folder}")

    # Balanced weight: total pixels / (number of classes * foreground pixels), with 2 classes
    foreground_weight = total_count / (2 * foreground_count)

    return foreground_weight


# Specify the path to your mask folder
mask_folder = 'path_to_your_mask_folder'

# Compute and print the foreground class weight
foreground_weight = compute_class_weight(mask_folder)
print(f"Foreground class weight: {foreground_weight}")
```
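Once you have the weight, it can be passed to PyTorch's `nn.CrossEntropyLoss` through its `weight` argument (one entry per class). A minimal sketch, assuming class index 0 is background and 1 is boundary, and using a placeholder weight of 10 in place of the computed value; how this repository's config exposes class weights is not shown in this thread, so wire it in wherever the loss is constructed.

```python
import torch
import torch.nn as nn

# Placeholder value: in practice, use the output of compute_class_weight(mask_folder)
foreground_weight = 10.0

# Index 0 = background (weight 1.0), index 1 = boundary/foreground
class_weights = torch.tensor([1.0, foreground_weight])
# On GPU, move the weights to the model's device first: class_weights.to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Dummy batch: logits (N, C, H, W) and integer labels (N, H, W)
torch.manual_seed(0)
logits = torch.randn(2, 2, 16, 16)
labels = torch.randint(0, 2, (2, 16, 16))
loss = criterion(logits, labels)
print(loss.item())
```

With `weight` set and the default `reduction='mean'`, PyTorch divides by the sum of the per-pixel weights rather than the pixel count, so the loss scale changes but each boundary pixel counts ten times as much as a background pixel.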

Another approach I would keep in mind is testing other networks. Several works focus specifically on boundary segmentation, e.g., High-Resolution Network (HRNet, which maintains high-resolution representations throughout the network), Boundary-Aware Networks (BANet, designed specifically for boundary-aware segmentation), and ContourNet (which explicitly focuses on learning contour information).

@bilal6414 let me know if you need any more assistance. I will mark this as resolved, since it isn't really an implementation issue. You could also try asking on the PyTorch forums, Reddit, or Stack Overflow.