XuJiacong / PIDNet

This is the official repository for our recent work: PIDNet
MIT License
601 stars 110 forks source link

segmentation of just the road lane #27

Open yeyan00 opened 2 years ago

yeyan00 commented 2 years ago

I just want to segment lanes, with everything else as background in training: background as class 0, lane as class 1. I trained for 200 epochs, but the result is very bad.

predict image

src image image

ground truth image

XuJiacong commented 2 years ago

Pls check if the generated boundary is correct

MrL-CV commented 1 year ago

Pls check if the generated boundary is correct

do you mean check the generated boundary on ground truth label files?

cuteboyqq commented 1 year ago

I suggest you do not use multi-scale augmentation or random crop, because your dataset focuses only on the drivable area: the random crop may crop just the sky, and multi-scale augmentation will not help your lane-line training. (I had trained lane line, main lane, and alternate lane before; the mIoU was low, but when I set multi_scale=False and city=False and used no random crop, the mIoU rose to about 62% at 50 epochs. Train image size is 512x256, number of training images: 10000, number of validation images: 2000.) The image below is the inference result of 6-class semantic segmentation. (This is not PIDNet — we developed an all-new network! But PIDNet is good too; you can just use PIDNet without modifying the network.) messageImage_1686582991110

The code is in dataset/base_dataset.py; please set multi_scale=False and city=False. image

multi_scale is set in the file config.yaml: MULTI_SCALE: false image

cuteboyqq commented 1 year ago

Sorry, you also need to modify the code if you set MULTI_SCALE to false. Below is the modified base_dataset.py code.

# ------------------------------------------------------------------------------
# Modified based on https://github.com/HRNet/HRNet-Semantic-Segmentation
# ------------------------------------------------------------------------------
import cv2
import numpy as np
import random

from torch.nn import functional as F
from torch.utils import data

y_k_size = 6
x_k_size = 6
class BaseDataset(data.Dataset):
    """Common preprocessing for semantic-segmentation datasets.

    Provides normalization, resizing, padding, random cropping,
    boundary-map generation, and horizontal flipping.  Subclasses are
    expected to populate ``self.files`` with their samples.
    """

    def __init__(self,
                 ignore_label=255,
                 base_size=2048,
                 crop_size=(512, 1024),
                 scale_factor=16,
                 mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225)):
        # mean/std defaults are tuples (immutable) rather than lists to
        # avoid the shared-mutable-default pitfall; the values are the
        # standard ImageNet statistics.
        self.base_size = base_size
        self.crop_size = crop_size
        self.ignore_label = ignore_label
        self.mean = mean
        self.std = std
        self.scale_factor = scale_factor
        self.files = []
        # Fixed network input size as (width, height), consumed by the
        # *_transform helpers below.
        self.size = (512, 256)

    def __len__(self):
        return len(self.files)

    def input_transform(self, image, city=True):
        """Resize *image* to ``self.size`` and normalize to float32.

        ``city`` is accepted for interface compatibility but is unused.
        """
        image = cv2.resize(image, self.size, interpolation=cv2.INTER_LINEAR)
        image = image.astype(np.float32)
        image = image / 255.0
        image -= self.mean
        image /= self.std
        return image

    def label_transform(self, label):
        """Resize *label* to ``self.size`` (nearest neighbor) as uint8."""
        label = cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)
        return np.array(label).astype(np.uint8)

    def edge_transform(self, label):
        """Resize the edge map *label* to ``self.size`` (nearest neighbor).

        Bug fix: the original resized the undefined name ``edge`` instead
        of the ``label`` argument, raising NameError on every call.
        """
        edge = cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)
        return edge

    def pad_image(self, image, h, w, size, padvalue):
        """Pad *image* (current size h x w) on the bottom/right with
        ``padvalue`` so it is at least ``size`` = (height, width).

        Always returns a new array, never a view of *image*.
        """
        pad_h = max(size[0] - h, 0)
        pad_w = max(size[1] - w, 0)
        if pad_h > 0 or pad_w > 0:
            return cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w,
                                      cv2.BORDER_CONSTANT, value=padvalue)
        return image.copy()

    def rand_crop(self, image, label, edge):
        """Random ``self.crop_size`` crop of matching image/label/edge
        windows, padding first when the inputs are smaller than the crop."""
        h, w = image.shape[:-1]
        image = self.pad_image(image, h, w, self.crop_size,
                               (0.0, 0.0, 0.0))
        # Labels are padded with ignore_label so padded pixels do not
        # contribute to the loss; edges are padded with "no edge" (0).
        label = self.pad_image(label, h, w, self.crop_size,
                               (self.ignore_label,))
        edge = self.pad_image(edge, h, w, self.crop_size,
                              (0.0,))

        new_h, new_w = label.shape
        x = random.randint(0, new_w - self.crop_size[1])
        y = random.randint(0, new_h - self.crop_size[0])
        image = image[y:y+self.crop_size[0], x:x+self.crop_size[1]]
        label = label[y:y+self.crop_size[0], x:x+self.crop_size[1]]
        edge = edge[y:y+self.crop_size[0], x:x+self.crop_size[1]]

        return image, label, edge

    def multi_scale_aug(self, image, label=None, edge=None,
                        rand_scale=1, rand_crop=True):
        """Rescale so the longer side becomes ``base_size * rand_scale``
        (aspect ratio preserved), optionally followed by a random crop.

        Returns only the resized image when no label is given.
        """
        # Built-in int() instead of np.int, which was removed in
        # NumPy 1.24 and crashes on modern installs.
        long_size = int(self.base_size * rand_scale + 0.5)
        h, w = image.shape[:2]
        if h > w:
            new_h = long_size
            new_w = int(w * long_size / h + 0.5)
        else:
            new_w = long_size
            new_h = int(h * long_size / w + 0.5)

        image = cv2.resize(image, (new_w, new_h),
                           interpolation=cv2.INTER_LINEAR)
        if label is None:
            return image

        label = cv2.resize(label, (new_w, new_h),
                           interpolation=cv2.INTER_NEAREST)
        if edge is not None:
            edge = cv2.resize(edge, (new_w, new_h),
                              interpolation=cv2.INTER_NEAREST)

        if rand_crop:
            image, label, edge = self.rand_crop(image, label, edge)

        return image, label, edge

    def gen_sample(self, image, label,
                   multi_scale=False, is_flip=True, edge_pad=True, edge_size=4, city=False):
        """Build one training sample: normalized CHW image, resized label,
        and a dilated binary boundary map, with optional horizontal flip.

        ``multi_scale`` and ``city`` are kept for interface compatibility;
        multi-scale augmentation is intentionally disabled in this variant
        (see the issue discussion: it hurt lane-only training).
        """
        # Binary boundary map from the label: Canny edges, outer
        # y_k_size/x_k_size border zeroed to suppress spurious boundary
        # responses, then thickened by dilation and binarized.
        edge = cv2.Canny(label, 0.1, 0.2)
        kernel = np.ones((edge_size, edge_size), np.uint8)
        if edge_pad:
            edge = edge[y_k_size:-y_k_size, x_k_size:-x_k_size]
            edge = np.pad(edge, ((y_k_size, y_k_size), (x_k_size, x_k_size)), mode='constant')
        edge = (cv2.dilate(edge, kernel, iterations=1) > 50) * 1.0

        image = self.input_transform(image, city=city)
        label = self.label_transform(label)
        edge = self.label_transform(edge)

        # HWC -> CHW for PyTorch.
        image = image.transpose((2, 0, 1))

        if is_flip:
            # flip is -1 (mirror) or +1 (identity) with equal probability.
            flip = np.random.choice(2) * 2 - 1
            image = image[:, :, ::flip]
            label = label[:, ::flip]
            edge = edge[:, ::flip]

        return image, label, edge

    def inference(self, config, model, image):
        """Run *model* on *image*, select the configured output head,
        upsample to the input resolution, and return ``pred.exp()``.

        NOTE(review): the exp() presumably converts log-softmax outputs
        to probabilities — confirm against the model's final layer.
        """
        size = image.size()
        pred = model(image)

        if config.MODEL.NUM_OUTPUTS > 1:
            pred = pred[config.TEST.OUTPUT_INDEX]

        pred = F.interpolate(
            input=pred, size=size[-2:],
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
        )

        return pred.exp()