facebookresearch / ov-seg

This is the official PyTorch implementation of the paper Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP.

Why isn't 'CLIP_ENSEMBLE_WEIGHT > 0' supported with 'R101c + CLIP-ViT-B/16' in demo.py? #18

Closed Pixie8888 closed 1 year ago

Pixie8888 commented 1 year ago

Hi,

I tried to run demo.py with the 'R101c + CLIP-ViT-B/16' model and modified the config file as follows:

```yaml
MODEL:
  META_ARCHITECTURE: "OVSegDEMO"
  BACKBONE:
    NAME: "build_resnet_deeplab_backbone"
  RESNETS:
    DEPTH: 101
    STEM_TYPE: "deeplab"
    STEM_OUT_CHANNELS: 128
    STRIDE_IN_1X1: False
    OUT_FEATURES: [ "res2", "res3", "res4", "res5" ]
    # NORM: "SyncBN"
    RES5_MULTI_GRID: [ 1, 2, 4 ]
  WEIGHTS: "detectron2://DeepLab/R-103.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  SEM_SEG_HEAD:
    NAME: "OpenVocabMaskFormerHead"
    IN_FEATURES: [ "res2", "res3", "res4", "res5" ]
    IGNORE_VALUE: 255
    NUM_CLASSES: 171 # number of categories in training set
    EMBEDDING_DIM: 512
    EMBED_LAYERS: 2
    COMMON_STRIDE: 4 # not used, hard-coded
    LOSS_WEIGHT: 1.0
    CONVS_DIM: 256
    MASK_DIM: 256
    NORM: "GN"
  MASK_FORMER:
    TRANSFORMER_IN_FEATURE: "res5"
    DEEP_SUPERVISION: True
    NO_OBJECT_WEIGHT: 0.1
    DICE_WEIGHT: 1.0
    MASK_WEIGHT: 20.0
    HIDDEN_DIM: 256
    NUM_OBJECT_QUERIES: 100
    NHEADS: 8
    DROPOUT: 0.1
    DIM_FEEDFORWARD: 2048
    ENC_LAYERS: 0
    DEC_LAYERS: 6
    PRE_NORM: False
  CLIP_ADAPTER:
    TEXT_TEMPLATES: "vild"
    CLIP_MODEL_NAME: "ViT-B/16"
    MASK_FILL: "mean"
    MASK_EXPAND_RATIO: 1.0
    MASK_THR: 0.5 # choose the foreground objects
    MASK_MATTING: False # use soft background, default not used
    MASK_PROMPT_DEPTH: 3
    MASK_PROMPT_FWD: True # use mask prompt during forward
    REGION_RESIZED: True # resize to the input of clip, e.g., 224
    CLIP_ENSEMBLE: True # use ensemble of two classification branches
    CLIP_ENSEMBLE_WEIGHT: 0.5
DATASETS:
  TRAIN: ("coco_2017_train_stuff_sem_seg",)
  TEST: ("ade20k_sem_seg_val",)
SOLVER:
  IMS_PER_BATCH: 32
  BASE_LR: 0.00006
  MAX_ITER: 120000
  WARMUP_FACTOR: 1e-6
  WARMUP_ITERS: 1500
  WEIGHT_DECAY: 0.01
  WEIGHT_DECAY_NORM: 0.0
  WEIGHT_DECAY_EMBED: 0.0
  BACKBONE_MULTIPLIER: 1.0
  TEST_IMS_PER_BATCH: 1
  CLIP_GRADIENTS:
    ENABLED: True
    CLIP_TYPE: "full_model"
    CLIP_VALUE: 0.01
    NORM_TYPE: 2.0
INPUT:
  MIN_SIZE_TEST: 512
  MAX_SIZE_TEST: 2048
  CROP:
    ENABLED: True
    TYPE: "absolute"
    SIZE: (512, 512)
    SINGLE_CATEGORY_MAX_AREA: 1.0
  COLOR_AUG_SSD: True
  SIZE_DIVISIBILITY: 512  # used in dataset mapper
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 5000
  AUG:
    ENABLED: False
    MIN_SIZES: [256, 384, 512, 640, 768, 896]
    MAX_SIZE: 3584
    FLIP: True
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: True
  NUM_WORKERS: 4
VERSION: 2
```

I changed CLIP_ENSEMBLE_WEIGHT to 0.5. The command is:

```bash
python demo.py --config-file configs/ovseg_R101c_demo.yaml --class-names 'Oculus' 'Ukulele' --input ./resources/demo_samples/sample_03.jpeg --output ./pred --opts MODEL.WEIGHTS pretrained_model/ovseg_R101c_vitB16_ft_mpt.pth.pt
```
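
As an aside, the --opts pairs at the end of the command are merged into the nested config keys by detectron2's CfgNode machinery, so CLIP_ENSEMBLE_WEIGHT could also be overridden from the command line instead of editing the YAML. Below is a standalone sketch: the keys mirror the config above, but this is not demo.py's actual setup, which also loads the YAML file and registers ov-seg's custom keys first.

```python
from detectron2.config import CfgNode

# Rebuild just the two keys touched here, for illustration.
cfg = CfgNode()
cfg.MODEL = CfgNode()
cfg.MODEL.WEIGHTS = ""
cfg.MODEL.CLIP_ADAPTER = CfgNode()
cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT = 0.0

# Equivalent to passing on the command line:
#   --opts MODEL.WEIGHTS <path> MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT 0.5
cfg.merge_from_list([
    "MODEL.WEIGHTS", "pretrained_model/ovseg_R101c_vitB16_ft_mpt.pth.pt",
    "MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT", "0.5",
])
print(cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT)  # 0.5
```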

But it fails with an error: [screenshot of the error message]

Is there anything wrong with my modification of the config file? How can I fix it?

Jeff-LiangF commented 1 year ago

Hi @Pixie8888 , sorry, there seems to be a bug in the code. In the `self.clip_ensemble_weight > 0:` branch, I forgot to filter `mask_pred` with `valid_flag`. To fix this, could you add `mask_pred = mask_pred[valid_flag]` at L447, mirroring the filtering done at L451?
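
For readers hitting the same error, here is a minimal, self-contained sketch of what the fixed ensemble branch looks like. The tensor shapes, all names other than `mask_pred`/`valid_flag`, and the geometric-mean combination are assumptions for illustration, not the repo's exact code; only the added filtering line is the fix described above.

```python
import torch
import torch.nn.functional as F

def ensemble_semantic_inference(mask_cls, mask_pred, clip_cls, valid_flag,
                                clip_ensemble_weight=0.5):
    """Sketch of the clip_ensemble_weight > 0 branch with the fix applied.

    mask_cls:   (Q, C+1) MaskFormer class logits (incl. the no-object class).
    mask_pred:  (Q, H, W) mask logits for all Q queries.
    clip_cls:   (Q_valid, C) CLIP logits, computed only for the valid regions.
    valid_flag: (Q,) bool tensor marking queries that produced a usable region.
    """
    mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]  # drop no-object column
    clip_cls = F.softmax(clip_cls, dim=-1)

    # The reported bug: mask_cls was reduced to the valid queries but
    # mask_pred was not, so the two tensors went out of sync.
    mask_cls = mask_cls[valid_flag]
    mask_pred = mask_pred[valid_flag]  # <-- the missing line

    # Blend the two classification branches (geometric form assumed here).
    w = clip_ensemble_weight
    cls_scores = torch.pow(mask_cls, 1 - w) * torch.pow(clip_cls, w)

    # Aggregate per-query masks into a (C, H, W) semantic segmentation map.
    return torch.einsum("qc,qhw->chw", cls_scores, mask_pred.sigmoid())
```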

Jeff-LiangF commented 1 year ago

Btw, I initially tried different CLIP_ENSEMBLE_WEIGHT values for the demo, but 0 seemed to work best. The CLIP branch works much better on diverse classes, while the MaskFormer branch may overfit to the 171 pre-defined COCO training classes. In the demo setting people usually try rare classes, so relying on CLIP alone works best for the demo.
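
In other words, CLIP_ENSEMBLE_WEIGHT selects between a blended prediction and a CLIP-only one. A tiny sketch of that selection logic, under the assumption that the blend is geometric (the weight > 0 special-casing follows the branch mentioned above):

```python
def combine_branches(maskformer_probs, clip_probs, weight):
    """Sketch: larger `weight` trusts the open-vocabulary CLIP branch more."""
    if weight > 0:
        # Ensemble of the two classification branches (form assumed).
        return maskformer_probs ** (1 - weight) * clip_probs ** weight
    # weight == 0 (the demo default): rely on CLIP alone, since the
    # MaskFormer classifier tends to overfit to the 171 COCO classes.
    return clip_probs
```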

Pixie8888 commented 1 year ago

> Hi @Pixie8888 , sorry, there seems to be a bug in the code. In the `self.clip_ensemble_weight > 0:` branch, I forgot to filter `mask_pred` with `valid_flag`. To fix this, could you add `mask_pred = mask_pred[valid_flag]` at L447, mirroring the filtering done at L451?

Thanks! I will try it out.