wangsssky / SonarSAM

Segment Anything Model, SAM, Sonar images
Apache License 2.0

When using MobileSAM, Target size (torch.Size([4, 12, 1024, 1024])) must be the same as input size (torch.Size([4, 3, 1024, 1024])) #6

Closed: theneao closed this issue 1 year ago

theneao commented 1 year ago

When using MobileSAM, with only SAM_NAME and SAM_CHECKPOINT changed, training fails with the following error, and I don't understand why:

  File "train_SAM.py", line 238, in <module>
    main()
  File "train_SAM.py", line 142, in main
    loss, outputs = net.forward(image, mask, boxes=boxes)
  File "/root/SonarSAM/model/model_proxy_SAM.py", line 377, in forward
    bce_loss = self.bcewithlogit(input=pred_masks, target=masks)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 720, in forward
    return F.binary_cross_entropy_with_logits(input, target,
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 3163, in binary_cross_entropy_with_logits
    raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
ValueError: Target size (torch.Size([4, 12, 1024, 1024])) must be the same as input size (torch.Size([4, 3, 1024, 1024]))
wangsssky commented 1 year ago

This problem is probably related to the number-of-classes setting. Make sure OUTPUT_CHN in the config file matches the number of classes in your data.
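The traceback says the predicted masks have 3 channels while the target masks have 12, and BCE-with-logits requires the two shapes to match exactly. A minimal sketch of a hypothetical `check_mask_shapes` helper (not part of SonarSAM) that could be called on `pred_masks.shape` and `masks.shape` just before the loss, so this kind of mismatch fails with a message pointing at the channel axis:

```python
def check_mask_shapes(pred_shape, target_shape):
    """Fail early with a readable message if predicted and target masks differ.

    Both shapes are assumed to be 4D: (batch, channels, height, width).
    BCE-with-logits needs them to match exactly; the channel axis is the
    number of classes. torch.Size is a tuple subclass, so it works here too.
    """
    pred_shape, target_shape = tuple(pred_shape), tuple(target_shape)
    if pred_shape == target_shape:
        return
    if pred_shape[1] != target_shape[1]:
        raise ValueError(
            f"model predicts {pred_shape[1]} mask channels but the dataloader "
            f"provides {target_shape[1]}; OUTPUT_CHN, the decoder's "
            "num_multimask_outputs and the number of classes must all agree"
        )
    raise ValueError(f"shape mismatch: prediction {pred_shape} vs target {target_shape}")


# The shapes from the traceback above: 3 predicted channels, 12 target channels.
try:
    check_mask_shapes((4, 3, 1024, 1024), (4, 12, 1024, 1024))
except ValueError as err:
    print(err)
```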

theneao commented 1 year ago

The dataset is your default Watertank segmentation dataset, and the vit_b model runs successfully. The MobileSAM run uses the same configuration as vit_b except for SAM_NAME and SAM_CHECKPOINT; OUTPUT_CHN remains 12. Setting OUTPUT_CHN=3 still raises an error.

theneao commented 1 year ago
---

DATA_PATH: './watertank-segmentation'
IMAGE_LIST_PATH: './dataset/marine_debris'
RANDOM_SEED: 42

MODEL_DIR: './saves/'
MODEL_NAME: 'mobile'
# huge
# SAM_NAME: 'vit_h'
# SAM_CHECKPOINT: '/path/to/sam_vit_h_4b8939.pth'
# large
# SAM_NAME: 'vit_l'
# SAM_CHECKPOINT: '/path/to/sam_vit_l_0b3195.pth'
# big
# SAM_NAME: 'vit_b'
# SAM_CHECKPOINT: './pretrained/sam_vit_b_01ec64.pth'
# mobile 
SAM_NAME: 'mobile'
SAM_CHECKPOINT: './pretrained/mobile_sam.pt'

IS_FINETUNE_IMAGE_ENCODER: False
USE_ADAPTATION: True
ADAPTATION_TYPE: 'LORA' # 'LORA', 'learnable_prompt_layer'
HEAD_TYPE: 'semantic_mask_decoder_LORA' # 'custom' 'semantic_mask_decoder'

EPOCH_NUM: 30
RESUME_FROM: 0

TRAIN_BATCHSIZE: 4
VAL_BATCHSIZE: 1 # fixed

OPTIMIZER: 'ADAM'
WEIGHT_DECAY: 0.00005
MOMENTUM: 0.9
LEARNING_RATE: 0.0003
WARM_LEN: 1

INPUT_SIZE: 1024
OUTPUT_CHN: 12
EVAL_METRIC: 'DICE'

PRT_LOSS: False
VISUALIZE: False
wangsssky commented 1 year ago


I see. It's a bug: I didn't pass num_multimask_outputs to the MobileSAM model. I've updated the repo and hope it works now.

theneao commented 1 year ago

It still fails, but the direction is right:

  File "/root/SonarSAM/model/model_proxy_SAM.py", line 146, in __init__
    self.sam = build_sam_mobile(checkpoint=checkpoint, num_multimask_outputs=num_classes)
TypeError: setup_model() got an unexpected keyword argument 'num_multimask_outputs'

That's because setup_model(checkpoint=None) does not accept this parameter:

            mask_decoder=MaskDecoder(
                num_multimask_outputs=12,  # changed from 3 to match OUTPUT_CHN
                transformer=TwoWayTransformer(
                    depth=2,
                    embedding_dim=prompt_embed_dim,
                    mlp_dim=2048,
                    num_heads=8,
                ),
                # ... remaining MaskDecoder arguments unchanged

I changed this parameter directly in MaskDecoder: num_multimask_outputs=12.
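Instead of hard-coding 12, one option is to give `setup_model` the keyword argument that SonarSAM's call already passes and forward it to the decoder, keeping 3 (the value behind the 3-channel predictions in the first traceback) as the default. A minimal sketch under that assumption; `MaskDecoder` here is a hypothetical stand-in, not the real MobileSAM class, and every other constructor argument is omitted:

```python
class MaskDecoder:
    """Hypothetical stand-in for MobileSAM's MaskDecoder.

    It only records num_multimask_outputs; the real class also builds the
    transformer, mask tokens and prediction heads.
    """

    def __init__(self, num_multimask_outputs=3, **kwargs):
        self.num_multimask_outputs = num_multimask_outputs
        self.extra_kwargs = kwargs


def setup_model(checkpoint=None, num_multimask_outputs=3):
    """Build the (stand-in) decoder with a configurable number of outputs.

    Defaulting to 3 keeps existing callers working, while
    build_sam_mobile(checkpoint=..., num_multimask_outputs=num_classes)
    no longer raises TypeError. Checkpoint loading is left out of this sketch.
    """
    return MaskDecoder(num_multimask_outputs=num_multimask_outputs)


decoder = setup_model(checkpoint=None, num_multimask_outputs=12)
print(decoder.num_multimask_outputs)  # 12
```

In SAM-style decoders the number of mask tokens and prediction heads depends on this value, so pretrained decoder weights saved with the default of 3 may not load one-to-one into a 12-output decoder; how the checkpoint is loaded is worth checking after this change.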

theneao commented 1 year ago

Parts of dataloader.py can also cause this kind of mismatch, because the number of classes is hard-coded:

        for id in range(12):
            m = (mask == id).astype('uint8')
            # masks.append(np.asarray(mask[id]))
            masks.append(m)
        masks = np.stack(masks, axis=0)

This needs a careful look. I changed it to:

        for id in range(len(class_to_id)):
            m = (mask == id).astype('uint8')
            # masks.append(np.asarray(mask[id]))
            masks.append(m)
        masks = np.stack(masks, axis=0)
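The loop above one-hot encodes an integer label map into one binary mask per class. A minimal vectorized sketch of the same idea with NumPy broadcasting, assuming class ids are contiguous integers `0 .. num_classes - 1` (so that `num_classes = len(class_to_id)` covers them all); `label_map_to_masks` is a hypothetical name, not a SonarSAM function:

```python
import numpy as np


def label_map_to_masks(mask, num_classes):
    """Convert an (H, W) integer label map to a (num_classes, H, W) uint8 stack.

    Channel i is 1 where mask == i, like the loop in dataloader.py.
    Pixels whose id is >= num_classes are 0 in every channel.
    """
    class_ids = np.arange(num_classes).reshape(-1, 1, 1)  # shape (C, 1, 1)
    # (1, H, W) compared with (C, 1, 1) broadcasts to (C, H, W).
    return (mask[None, :, :] == class_ids).astype(np.uint8)


# Hypothetical 3-class example; in SonarSAM num_classes would be len(class_to_id).
label = np.array([[0, 1], [2, 1]])
masks = label_map_to_masks(label, num_classes=3)
print(masks.shape)  # (3, 2, 2)
```

Deriving the channel count from the class mapping, rather than writing 12 in two places, keeps the dataloader's target channels in step with OUTPUT_CHN as long as the config is set to the same number.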