MendelXu / SAN

Open-vocabulary Semantic Segmentation
https://mendelxu.github.io/SAN/
MIT License

Train on private dataset (only one category) #42

Closed: lxr-1204 closed this issue 7 months ago

lxr-1204 commented 7 months ago

Thank you for your work. I am training on a private dataset with only one category (polyp). To register it, I adapted the VOC registration code:

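import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_sem_seg

# The only category in the dataset; in the masks its pixels carry label id 0.
CLASS_NAMES = (
    "polyp",
)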

def _get_voc_meta(cat_list):
    ret = {
        "stuff_classes": cat_list,
    }
    return ret

def register_all_voc_11k(root):
    root = os.path.join(root, "pranet")
    meta = _get_voc_meta(CLASS_NAMES)

    for name, image_dirname, sem_seg_dirname in [
        ("train", "JPEGImages", "annotations_detectron2/train"),
        ("val", "JPEGImages", "annotations_detectron2/val"),
    ]:
        image_dir = os.path.join(root, image_dirname)
        gt_dir = os.path.join(root, sem_seg_dirname)
        all_name = f"pranet_sem_seg_{name}"
        # Bind image_dir/gt_dir as default arguments so each split's lambda keeps its own paths.
        DatasetCatalog.register(
            all_name,
            lambda x=image_dir, y=gt_dir: load_sem_seg(
                y, x, gt_ext="png", image_ext="jpg"
            ),
        )
        MetadataCatalog.get(all_name).set(
            image_root=image_dir,
            sem_seg_root=gt_dir,
            evaluator_type="sem_seg",
            ignore_label=255,
            **meta,
        )

_root = os.getenv("DETECTRON2_DATASETS", "datasets")
register_all_voc_11k(_root)
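Because the registration sets `stuff_classes=("polyp",)` and `ignore_label=255`, every pixel in the ground-truth PNGs should be either 0 (polyp) or 255 (ignored). A minimal sanity check, assuming single-channel PNG masks readable with Pillow; `invalid_label_ids` and `scan_gt_dir` are hypothetical helpers, not part of SAN or detectron2:

```python
import glob
import os

import numpy as np


def invalid_label_ids(mask, num_classes, ignore_label=255):
    """Return label ids in `mask` that are neither a valid class id nor ignore_label."""
    ids = np.unique(mask)
    valid = (ids < num_classes) | (ids == ignore_label)
    return ids[~valid].tolist()


def scan_gt_dir(gt_dir, num_classes, ignore_label=255):
    """Check every PNG under gt_dir; return {filename: [bad ids]} for offending files."""
    from PIL import Image  # local import so invalid_label_ids works without Pillow

    problems = {}
    for path in sorted(glob.glob(os.path.join(gt_dir, "*.png"))):
        bad = invalid_label_ids(np.array(Image.open(path)), num_classes, ignore_label)
        if bad:
            problems[os.path.basename(path)] = bad
    return problems
```

For example, `scan_gt_dir("datasets/pranet/annotations_detectron2/train", num_classes=1)` should return an empty dict. It is also worth checking which id the polyp pixels actually carry: raw binary masks often store the object as 255, which this registration would treat as ignored.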

I then trained with:

python train_net.py --config-file ./configs/san_clip_vit_res4_pranet.yaml --num-gpus 1 OUTPUT_DIR ./OUTPUT/vit_14 MODEL.SAN.NUM_CLASSES 1

The training loss looked normal, but at inference every prediction is pure white (all pixels are predicted as foreground). Could you give me some help, or tell me where the problem might be? I would be very grateful!

MendelXu commented 7 months ago

This is caused by the slicing on this line: https://github.com/MendelXu/SAN/blob/81a9a2bd79d433292d46cfa0597caea5005e0116/san/model/san.py#L270

After that slicing, only the scores of the foreground classes are kept. So if you have only one class, the class with the maximum score at every pixel is always id 0, which in your setup is the foreground (polyp).
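To see why a single class always wins, here is a small NumPy re-implementation, assuming SAN's semantic inference follows the Mask2Former pattern: softmax over K + 1 class logits per query, slice off the last column, then combine with sigmoid mask probabilities. `semantic_inference` here is an illustrative sketch, not SAN's actual code, and the shapes are made up:

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def semantic_inference(mask_cls, mask_pred):
    """Mask2Former-style semantic inference (illustrative re-implementation).

    mask_cls:  (Q, K + 1) class logits per query; the last column is assumed
               to be the extra non-foreground score.
    mask_pred: (Q, H, W) mask logits per query.
    Returns a (K, H, W) score map, one channel per dataset class.
    """
    cls_prob = softmax(mask_cls, axis=-1)[..., :-1]  # slice off the last column -> (Q, K)
    mask_prob = 1.0 / (1.0 + np.exp(-mask_pred))     # sigmoid -> (Q, H, W)
    return np.einsum("qc,qhw->chw", cls_prob, mask_prob)


rng = np.random.default_rng(0)
Q, K, H, W = 100, 1, 4, 4  # K = 1: a single "polyp" class
semseg = semantic_inference(rng.normal(size=(Q, K + 1)), rng.normal(size=(Q, H, W)))
pred = semseg.argmax(axis=0)  # per-pixel class id

print(semseg.shape)     # (1, 4, 4)
print(np.unique(pred))  # [0]: every pixel gets class 0, i.e. "polyp"
```

Whatever the logits are, an argmax over an axis of length 1 can only return 0, so every pixel comes out as foreground. With K ≥ 2 the argmax has more than one channel to choose from.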

Several possible solutions could be: