WZH0120 / SAM2-UNet

SAM2-UNet: Segment Anything 2 Makes Strong Encoder for Natural and Medical Image Segmentation

how to train with tiny checkpoints? #17

FeiYull opened this issue 1 week ago

FeiYull commented 1 week ago

I have made the following changes:

  1. changed https://github.com/WZH0120/SAM2-UNet/blob/eb1c38d870358cbdd769c9721062f7bb888ef9b5/train.py#L15
  2. edited the yaml at https://github.com/WZH0120/SAM2-UNet/blob/eb1c38d870358cbdd769c9721062f7bb888ef9b5/SAM2UNet.py#L127

but an error like the following occurs:

    python3.9/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for SAM2Base:
        Missing key(s) in state_dict: "image_encoder.trunk.blocks.2.proj.weight", "image_encoder.trunk.blocks.2.proj.bias", "image_encoder.trunk.blocks.8.proj.weight", "image_encoder.trunk.blocks.8.proj.bias", "image_encoder.trunk.blocks.12.norm1.weight", "image_encoder.trunk.blocks.12.norm1.bias", "image_encoder.trunk.blocks.12.attn.qkv.weight", "image_encoder.trunk.blocks.12.attn.qkv.bias", "image_encoder.trunk.blocks.12.attn.proj.weight", "image_encoder.trunk.blocks.12.attn.proj.bias", "image_encoder.trunk.blocks.12.norm2.weight", "image_encoder.trunk.blocks.12.norm2.bias", "image_encoder.trunk.blocks.12.mlp.layers.0.weight", "

...........

    Unexpected key(s) in state_dict: "image_encoder.trunk.blocks.1.proj.weight", "image_encoder.trunk.blocks.1.proj.bias", "image_encoder.trunk.blocks.3.proj.weight", "image_encoder.trunk.blocks.3.proj.bias", "image_encoder.trunk.blocks.10.proj.weight", "image_encoder.trunk.blocks.10.proj.bias".
    size mismatch for image_encoder.trunk.pos_embed: copying a param with shape torch.Size([1, 96, 7, 7]) from checkpoint, the shape in current model is torch.Size([1, 144, 7, 7]).
    size mismatch for image_encoder.trunk.pos_embed_window: copying a param with shape torch.Size([1, 96, 8, 8]) from checkpoint, the shape in current model is torch.Size([1, 144, 8, 8]).
    size mismatch for image_encoder.trunk.patch_embed.proj.weight: copying a param with shape torch.Size([96, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([144, 3, 7, 7]).
    size mismatch for image_encoder.trunk.patch_embed.proj.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([144]).

.......
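The size mismatches (96 channels in the checkpoint vs. 144 in the model) suggest the model is still being built from the large config while the tiny weights are loaded. A quick way to confirm which Hiera variant a checkpoint file belongs to is to inspect its tensor shapes; a minimal sketch, assuming the official SAM 2 checkpoint layout (weights nested under a "model" key) and a hypothetical local filename:

    import torch

    # "sam2_hiera_tiny.pt" is a hypothetical local path; point it at the downloaded file.
    ckpt = torch.load("sam2_hiera_tiny.pt", map_location="cpu")
    state_dict = ckpt.get("model", ckpt)  # official SAM 2 checkpoints nest weights under "model"

    # The patch-embed output width identifies the trunk:
    # 96 -> Hiera-T/S, 112 -> Hiera-B+, 144 -> Hiera-L.
    print(state_dict["image_encoder.trunk.patch_embed.proj.weight"].shape)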

xiongxyowo commented 1 week ago

Hi, you need to make the following changes to migrate to the tiny version of SAM2-UNet (a sanity-check sketch follows the list):

  1. Download the tiny version of pre-trained Segment Anything 2 (sam2_hiera_tiny.pt) from the official repository.
  2. Change the yaml config from "sam2_hiera_l.yaml" to "sam2_hiera_t.yaml":
    super(SAM2UNet, self).__init__()    
    # model_cfg = "sam2_hiera_l.yaml"
    model_cfg = "sam2_hiera_t.yaml"
  3. Change the input channels of the RFB blocks to match the stage outputs of Hiera-Tiny (96/192/384/768 instead of Hiera-Large's 144/288/576/1152):
    self.rfb1 = RFB_modified(96, 64)
    self.rfb2 = RFB_modified(192, 64)
    self.rfb3 = RFB_modified(384, 64)
    self.rfb4 = RFB_modified(768, 64)
  4. (Optional) Disable parameter-efficient fine-tuning for possibly better performance, at the cost of training the full encoder:
    # for param in self.encoder.parameters():
    #     param.requires_grad = False
    # blocks = []
    # for block in self.encoder.blocks:
    #     blocks.append(
    #         Adapter(block)
    #     )
    # self.encoder.blocks = nn.Sequential(
    #     *blocks
    # )
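Putting steps 1-3 together, here is a minimal sanity-check sketch (not part of the repo). It assumes the sam2 package is installed, the tiny checkpoint is saved locally as sam2_hiera_tiny.pt, and it uses this repo's 352x352 training resolution. It builds the tiny model and prints each trunk stage's output shape, whose channel counts must match the RFB_modified inputs from step 3:

    import torch
    from sam2.build_sam import build_sam2

    model_cfg = "sam2_hiera_t.yaml"    # tiny config from the official sam2 repo
    checkpoint = "sam2_hiera_tiny.pt"  # hypothetical local path to the tiny checkpoint

    # Builds SAM2Base and loads the weights; raises if config and checkpoint disagree.
    model = build_sam2(model_cfg, checkpoint, device="cpu")
    trunk = model.image_encoder.trunk  # the Hiera backbone reused by SAM2-UNet

    # Feed a dummy image through the trunk; each stage's channel count must match
    # the RFB_modified input channels set in step 3.
    with torch.no_grad():
        feats = trunk(torch.zeros(1, 3, 352, 352))
    for f in feats:
        print(f.shape)  # expect channel dims 96, 192, 384, 768

If build_sam2 raises the same RuntimeError as above, the yaml change from step 2 has not taken effect.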
FeiYull commented 5 days ago

@xiongxyowo thanks, it works