root0yang / BlindNet

BSD 3-Clause "New" or "Revised" License

BlindNet (CVPR 2024) : Official Project Webpage

This repository provides the official PyTorch implementation of the following paper:

Style Blind Domain Generalized Semantic Segmentation via Covariance Alignment and Semantic Consistence Contrastive Learning
Woo-Jin Ahn, Geun-Yeong Yang, Hyun-Duck Choi, Myo-Taeg Lim
Korea University, Chonnam National University

Abstract: Deep learning models for semantic segmentation often experience performance degradation when deployed to unseen target domains unidentified during the training phase. This is mainly due to variations in image texture (i.e. style) from different data sources. To tackle this challenge, existing domain generalized semantic segmentation (DGSS) methods attempt to remove style variations from the feature. However, these approaches struggle with the entanglement of style and content, which may lead to the unintentional removal of crucial content information, causing performance degradation. This study addresses this limitation by proposing BlindNet, a novel DGSS approach that blinds the style without external modules or datasets. The main idea behind our proposed approach is to alleviate the effect of style in the encoder whilst facilitating robust segmentation in the decoder. To achieve this, BlindNet comprises two key components: covariance alignment and semantic consistency contrastive learning. Specifically, the covariance alignment trains the encoder to uniformly recognize various styles and preserve the content information of the feature, rather than removing the style-sensitive factor. Meanwhile, semantic consistency contrastive learning enables the decoder to construct discriminative class embedding space and disentangles features that are vulnerable to misclassification. Through extensive experiments, our approach outperforms existing DGSS methods, exhibiting robustness and superior performance for semantic segmentation on unseen target domains.

PyTorch Implementation

Our PyTorch implementation is heavily derived from RobustNet (CVPR 2021). If you use this code in your research, please also cite their work. [link to license]

Installation

Clone this repository.

git clone https://github.com/root0yang/BlindNet.git
cd BlindNet

Install the following packages.

conda create --name blindnet python=3.6
conda activate blindnet
conda install pytorch==1.2.0 cudatoolkit==10.2
conda install scipy==1.1.0
conda install tqdm==4.46.0
conda install scikit-image==0.16.2
pip install tensorboardX==2.4
pip install thop
imageio_download_bin freeimage
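
After installation, a quick import check can confirm that the environment is usable. The snippet below is a minimal sketch, not part of the repository: the import names (`torch` for pytorch, `skimage` for scikit-image) are the usual ones for these packages, and the helper `check_modules` is a hypothetical name introduced here for illustration.

```python
import importlib

# Import names for the packages listed above (assumed mapping:
# pytorch -> torch, scikit-image -> skimage).
REQUIRED_MODULES = ["torch", "scipy", "tqdm", "skimage", "tensorboardX", "thop"]


def check_modules(module_names):
    """Return {module_name: version string, or None if the import fails}."""
    report = {}
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = None
        else:
            # Not every package exposes __version__, so fall back to "unknown".
            report[name] = str(getattr(module, "__version__", "unknown"))
    return report


if __name__ == "__main__":
    for name, version in check_modules(REQUIRED_MODULES).items():
        print("{:<14} {}".format(name, version or "NOT INSTALLED"))
```

Run it inside the activated `blindnet` environment and compare the printed versions with the ones pinned above.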

How to Run BlindNet

We evaluated the model on Cityscapes, BDD-100K, Synthia (SYNTHIA-RAND-CITYSCAPES), GTAV and Mapillary Vistas.

We adopt the class-uniform sampling proposed in this paper to handle class imbalance.
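
To illustrate the general idea behind class-uniform sampling, the sketch below first picks a class uniformly at random and then picks an image that contains that class, so rare classes are drawn as often as common ones. It is a simplified illustration under stated assumptions, not the sampler used in this repository or in the referenced paper; the function names and the toy data are hypothetical.

```python
import random
from collections import defaultdict


def build_class_index(image_classes):
    """Map each class id to the sorted list of images that contain it.

    image_classes: dict of image name -> set of class ids present in its label.
    """
    index = defaultdict(list)
    for image, classes in sorted(image_classes.items()):
        for class_id in sorted(classes):
            index[class_id].append(image)
    return dict(index)


def class_uniform_sample(image_classes, num_samples, seed=0):
    """Draw (class_id, image) pairs: a class uniformly, then an image with it."""
    rng = random.Random(seed)
    index = build_class_index(image_classes)
    class_ids = sorted(index)
    samples = []
    for _ in range(num_samples):
        class_id = rng.choice(class_ids)
        samples.append((class_id, rng.choice(index[class_id])))
    return samples


if __name__ == "__main__":
    # Hypothetical toy data: class 2 appears in a single image.
    toy = {"a.png": {0, 1}, "b.png": {0}, "c.png": {0, 1}, "d.png": {0, 2}}
    for class_id, image in class_uniform_sample(toy, num_samples=6):
        print(class_id, image)
```

With plain per-image sampling, class 2 in the toy data would be seen in only one of four draws on average; here it is selected as the target class in roughly one of three.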

  1. For the Cityscapes dataset, download "leftImg8bit_trainvaltest.zip" and "gtFine_trainvaltest.zip" from https://www.cityscapes-dataset.com/downloads/
    Unzip the files and arrange the dataset directories as follows.

    cityscapes
     └ leftImg8bit_trainvaltest
       └ leftImg8bit
         └ train
         └ val
         └ test
     └ gtFine_trainvaltest
       └ gtFine
         └ train
         └ val
         └ test
    bdd-100k
     └ images
       └ train
       └ val
       └ test
     └ labels
       └ train
       └ val
    mapillary
     └ training
       └ images
       └ labels
     └ validation
       └ images
       └ labels
     └ test
       └ images
       └ labels
  2. We used GTAV_Split to split the GTAV dataset into training/validation/test sets. Please refer to the txt files in split_data.

GTAV
 └ images
   └ train
     └ folder
   └ valid
     └ folder
   └ test
     └ folder
 └ labels
   └ train
     └ folder
   └ valid
     └ folder
   └ test
     └ folder
  3. We split the Synthia dataset into train/val sets following RobustNet. Please refer to the txt files in split_data.

synthia
 └ RGB
   └ train
   └ val
 └ GT
   └ COLOR
     └ train
     └ val
   └ LABELS
     └ train
     └ val
  4. Modify the paths in "/config.py" to match your dataset locations.

    #Cityscapes Dir Location
    __C.DATASET.CITYSCAPES_DIR = <YOUR_CITYSCAPES_PATH>
    #Mapillary Dataset Dir Location
    __C.DATASET.MAPILLARY_DIR = <YOUR_MAPILLARY_PATH>
    #GTAV Dataset Dir Location
    __C.DATASET.GTAV_DIR = <YOUR_GTAV_PATH>
    #BDD-100K Dataset Dir Location
    __C.DATASET.BDD_DIR = <YOUR_BDD_PATH>
    #Synthia Dataset Dir Location
    __C.DATASET.SYNTHIA_DIR = <YOUR_SYNTHIA_PATH>
  5. You can train BlindNet with the following command.

    <path_to_blindnet>$ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_train_r50os16_gtav.sh
  6. You can download our ResNet-50 model trained on GTAV from Google Drive and validate the pretrained model with the following command.

    <path_to_blindnet>$ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_valid_r50os16_gtav.sh <weight_file_location>
  7. You can infer segmentation results from images using the pretrained model with the following command.

    <path_to_blindnet>$ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_infer_r50os16.sh <weight_file_location> <result_save_location>
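
Before launching training, it can help to confirm that each dataset root matches the directory trees listed in the steps above. The following is a minimal sketch under stated assumptions: the sub-directory lists are transcribed from those trees (the Cityscapes entries assume that `leftImg8bit` and `gtFine` sit inside their `*_trainvaltest` folders), and `missing_dirs` and the example roots are hypothetical names, not part of the repository.

```python
from pathlib import Path

# Sub-directories relative to each dataset root, transcribed from the trees
# above. The Cityscapes nesting is an assumption about how the tree is read.
EXPECTED_LAYOUT = {
    "cityscapes": [
        "leftImg8bit_trainvaltest/leftImg8bit/train",
        "leftImg8bit_trainvaltest/leftImg8bit/val",
        "leftImg8bit_trainvaltest/leftImg8bit/test",
        "gtFine_trainvaltest/gtFine/train",
        "gtFine_trainvaltest/gtFine/val",
        "gtFine_trainvaltest/gtFine/test",
    ],
    "gtav": [
        "images/train", "images/valid", "images/test",
        "labels/train", "labels/valid", "labels/test",
    ],
    "synthia": [
        "RGB/train", "RGB/val",
        "GT/COLOR/train", "GT/COLOR/val",
        "GT/LABELS/train", "GT/LABELS/val",
    ],
}


def missing_dirs(root, relative_dirs):
    """Return the entries of relative_dirs that are not directories under root."""
    root = Path(root)
    return [d for d in relative_dirs if not (root / d).is_dir()]


if __name__ == "__main__":
    # Hypothetical roots: replace with the paths you set in config.py.
    roots = {
        "cityscapes": "/path/to/cityscapes",
        "gtav": "/path/to/GTAV",
        "synthia": "/path/to/synthia",
    }
    for name, root in roots.items():
        missing = missing_dirs(root, EXPECTED_LAYOUT[name])
        print(name, "OK" if not missing else "missing: {}".format(missing))
```

The same pattern extends to BDD-100K and Mapillary by adding their sub-directories to `EXPECTED_LAYOUT`.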

Citation

If you find this work useful in your research, please cite our paper:

@article{ahn2024style,
  title={Style Blind Domain Generalized Semantic Segmentation via Covariance Alignment and Semantic Consistence Contrastive Learning},
  author={Ahn, Woo-Jin and Yang, Geun-Yeong and Choi, Hyun-Duck and Lim, Myo-Taeg},
  journal={arXiv preprint arXiv:2403.06122},
  year={2024}
}

Terms of Use

This software is for non-commercial use only. The source code is released under the Attribution-NonCommercial-ShareAlike (CC BY-NC-SA) Licence (see this for details).