FeiYull opened 1 week ago
Hi, you need to make the following changes to migrate to the tiny version of SAM2-UNet:
```python
super(SAM2UNet, self).__init__()
# switch the backbone config from Hiera-Large to Hiera-Tiny
# model_cfg = "sam2_hiera_l.yaml"
model_cfg = "sam2_hiera_t.yaml"

# Hiera-Tiny stage widths are 96/192/384/768
# (Hiera-Large uses 144/288/576/1152)
self.rfb1 = RFB_modified(96, 64)
self.rfb2 = RFB_modified(192, 64)
self.rfb3 = RFB_modified(384, 64)
self.rfb4 = RFB_modified(768, 64)

# for param in self.encoder.parameters():
#     param.requires_grad = False
# blocks = []
# for block in self.encoder.blocks:
#     blocks.append(Adapter(block))
# self.encoder.blocks = nn.Sequential(*blocks)
```
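For reference, the tiny config has to be paired with the matching tiny checkpoint (`sam2_hiera_tiny.pt` from the official SAM2 release). A minimal usage sketch, assuming the constructor still takes the checkpoint path as in the original SAM2-UNet repository:

```python
from SAM2UNet import SAM2UNet  # class/module layout as in the SAM2-UNet repo

# "sam2_hiera_t.yaml" (set inside __init__ above) must be paired with the
# official tiny weights; mixing sizes raises state_dict loading errors.
model = SAM2UNet("sam2_hiera_tiny.pt")
```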
@xiongxyowo thanks, it works!
I have made the following changes, but errors like this occur:
```
  File ".../python3.9/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SAM2Base:
	Missing key(s) in state_dict: "image_encoder.trunk.blocks.2.proj.weight", "image_encoder.trunk.blocks.2.proj.bias", "image_encoder.trunk.blocks.8.proj.weight", "image_encoder.trunk.blocks.8.proj.bias", "image_encoder.trunk.blocks.12.norm1.weight", "image_encoder.trunk.blocks.12.norm1.bias", "image_encoder.trunk.blocks.12.attn.qkv.weight", "image_encoder.trunk.blocks.12.attn.qkv.bias", "image_encoder.trunk.blocks.12.attn.proj.weight", "image_encoder.trunk.blocks.12.attn.proj.bias", "image_encoder.trunk.blocks.12.norm2.weight", "image_encoder.trunk.blocks.12.norm2.bias", "image_encoder.trunk.blocks.12.mlp.layers.0.weight", ...
	Unexpected key(s) in state_dict: "image_encoder.trunk.blocks.1.proj.weight", "image_encoder.trunk.blocks.1.proj.bias", "image_encoder.trunk.blocks.3.proj.weight", "image_encoder.trunk.blocks.3.proj.bias", "image_encoder.trunk.blocks.10.proj.weight", "image_encoder.trunk.blocks.10.proj.bias".
	size mismatch for image_encoder.trunk.pos_embed: copying a param with shape torch.Size([1, 96, 7, 7]) from checkpoint, the shape in current model is torch.Size([1, 144, 7, 7]).
	size mismatch for image_encoder.trunk.pos_embed_window: copying a param with shape torch.Size([1, 96, 8, 8]) from checkpoint, the shape in current model is torch.Size([1, 144, 8, 8]).
	size mismatch for image_encoder.trunk.patch_embed.proj.weight: copying a param with shape torch.Size([96, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([144, 3, 7, 7]).
	size mismatch for image_encoder.trunk.patch_embed.proj.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([144]).
	...
```
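The shapes in the log suggest a config/checkpoint mismatch rather than a problem with the RFB changes: 96 is the Hiera-Tiny embedding width and 144 is Hiera-Large, so a tiny checkpoint is being loaded into a model still built from `sam2_hiera_l.yaml`. A small sanity check along these lines, using `build_sam2` from the official `sam2` package (file paths are placeholders):

```python
import torch
from sam2.build_sam import build_sam2

# SAM2 checkpoints store the weights under the "model" key;
# a patch-embed width of 96 means tiny, 144 means large.
ckpt = torch.load("sam2_hiera_tiny.pt", map_location="cpu")
print(ckpt["model"]["image_encoder.trunk.patch_embed.proj.weight"].shape)

# config and checkpoint must come from the same model size
sam2 = build_sam2("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt", device="cpu")
```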