AlexanderZeilmann opened this issue 3 months ago.
You could use infer.sh and modify the parameters as needed. -vp is the path where the visualizations will be written, and -tdp is the path to the images you want to segment.
python validation.py --seed 2023 \
  -vp ./results/vis_sam_med3d \
  -tdp data/initial_test_dataset/total_segment -nc 1
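If you drive runs from Python (for example, to loop over several datasets), the command above can be assembled as an argument list and sanity-checked before launching. This is a hypothetical helper, not part of the repository: it only forwards the four flags shown above, and since the meaning of -nc is not explained here, its value is passed through unchanged.

```python
import shlex
from pathlib import Path
from typing import List


def build_validation_cmd(
    data_path: str,
    vis_path: str = "./results/vis_sam_med3d",
    seed: int = 2023,
    nc: int = 1,
    check_paths: bool = True,
) -> List[str]:
    """Hypothetical helper: assemble the validation.py call shown above."""
    data_dir = Path(data_path)
    # Fail early if the -tdp directory is missing, instead of deep inside the script.
    if check_paths and not data_dir.is_dir():
        raise FileNotFoundError(f"-tdp directory not found: {data_dir}")
    return [
        "python", "validation.py",
        "--seed", str(seed),
        "-vp", str(vis_path),   # where visualizations are written
        "-tdp", str(data_dir),  # images to segment
        "-nc", str(nc),         # forwarded unchanged
    ]


if __name__ == "__main__":
    cmd = build_validation_cmd("data/initial_test_dataset/total_segment", check_paths=False)
    print(shlex.join(cmd))
```

From the repository root, the list can then be run with subprocess.run(cmd, check=True).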
Regarding this issue: I have my dataset in the specified format and have saved the FastSAM3D checkpoint locally. When I run infer.sh, loading the model fails with errors about missing encoder blocks. Also, looking through validation.py, I noticed that many functions refer to tuning the model. I do not want to tune the model; I only want to test its segmentation on my data. Any help would be greatly appreciated.
Here is the error message:
RuntimeError: Error(s) in loading state_dict for Sam3D:
    Missing key(s) in state_dict: "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_d", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encoder.blocks.0.attn.qkv.bias", "image_encoder.blocks.0.attn.proj.weight", "image_encoder.blocks.0.attn.proj.bias", "image_encoder.blocks.1.norm1.weight", "image_encoder.blocks.1.norm1.bias", "image_encoder.blocks.1.attn.rel_pos_d", "image_encoder.blocks.1.attn.rel_pos_h", "image_encoder.blocks.1.attn.rel_pos_w", "image_encoder.blocks.1.attn.qkv.weight", "image_encoder.blocks.1.attn.qkv.bias", "image_encoder.blocks.1.attn.proj.weight", "image_encoder.blocks.1.attn.proj.bias", "image_encoder.blocks.6.norm1.weight", "image_encoder.blocks.6.norm1.bias", "image_encoder.blocks.6.attn.rel_pos_d", "image_encoder.blocks.6.attn.rel_pos_h", "image_encoder.blocks.6.attn.rel_pos_w", "image_encoder.blocks.6.attn.qkv.weight", "image_encoder.blocks.6.attn.qkv.bias", "image_encoder.blocks.6.attn.proj.weight", "image_encoder.blocks.6.attn.proj.bias", "image_encoder.blocks.6.norm2.weight", "image_encoder.blocks.6.norm2.bias", "image_encoder.blocks.6.mlp.lin1.weight", "image_encoder.blocks.6.mlp.lin1.bias", "image_encoder.blocks.6.mlp.lin2.weight", "image_encoder.blocks.6.mlp.lin2.bias", "image_encoder.blocks.7.norm1.weight", "image_encoder.blocks.7.norm1.bias", "image_encoder.blocks.7.attn.rel_pos_d", "image_encoder.blocks.7.attn.rel_pos_h", "image_encoder.blocks.7.attn.rel_pos_w", "image_encoder.blocks.7.attn.qkv.weight", "image_encoder.blocks.7.attn.qkv.bias", "image_encoder.blocks.7.attn.proj.weight", "image_encoder.blocks.7.attn.proj.bias", "image_encoder.blocks.7.norm2.weight", "image_encoder.blocks.7.norm2.bias", "image_encoder.blocks.7.mlp.lin1.weight", "image_encoder.blocks.7.mlp.lin1.bias", "image_encoder.blocks.7.mlp.lin2.weight", "image_encoder.blocks.7.mlp.lin2.bias", "image_encoder.blocks.8.norm1.weight", "image_encoder.blocks.8.norm1.bias", "image_encoder.blocks.8.attn.rel_pos_d", "image_encoder.blocks.8.attn.rel_pos_h", "image_encoder.blocks.8.attn.rel_pos_w", "image_encoder.blocks.8.attn.qkv.weight", "image_encoder.blocks.8.attn.qkv.bias", "image_encoder.blocks.8.attn.proj.weight", "image_encoder.blocks.8.attn.proj.bias", "image_encoder.blocks.8.norm2.weight", "image_encoder.blocks.8.norm2.bias", "image_encoder.blocks.8.mlp.lin1.weight", "image_encoder.blocks.8.mlp.lin1.bias", "image_encoder.blocks.8.mlp.lin2.weight", "image_encoder.blocks.8.mlp.lin2.bias", "image_encoder.blocks.9.norm1.weight", "image_encoder.blocks.9.norm1.bias", "image_encoder.blocks.9.attn.rel_pos_d", "image_encoder.blocks.9.attn.rel_pos_h", "image_encoder.blocks.9.attn.rel_pos_w", "image_encoder.blocks.9.attn.qkv.weight", "image_encoder.blocks.9.attn.qkv.bias", "image_encoder.blocks.9.attn.proj.weight", "image_encoder.blocks.9.attn.proj.bias", "image_encoder.blocks.9.norm2.weight", "image_encoder.blocks.9.norm2.bias", "image_encoder.blocks.9.mlp.lin1.weight", "image_encoder.blocks.9.mlp.lin1.bias", "image_encoder.blocks.9.mlp.lin2.weight", "image_encoder.blocks.9.mlp.lin2.bias", "image_encoder.blocks.10.norm1.weight", "image_encoder.blocks.10.norm1.bias", "image_encoder.blocks.10.attn.rel_pos_d", "image_encoder.blocks.10.attn.rel_pos_h", "image_encoder.blocks.10.attn.rel_pos_w", "image_encoder.blocks.10.attn.qkv.weight", "image_encoder.blocks.10.attn.qkv.bias", "image_encoder.blocks.10.attn.proj.weight", "image_encoder.blocks.10.attn.proj.bias", "image_encoder.blocks.10.norm2.weight", "image_encoder.blocks.10.norm2.bias", "image_encoder.blocks.10.mlp.lin1.weight", "image_encoder.blocks.10.mlp.lin1.bias", "image_encoder.blocks.10.mlp.lin2.weight", "image_encoder.blocks.10.mlp.lin2.bias", "image_encoder.blocks.11.norm1.weight", "image_encoder.blocks.11.norm1.bias", "image_encoder.blocks.11.attn.rel_pos_d", "image_encoder.blocks.11.attn.rel_pos_h", "image_encoder.blocks.11.attn.rel_pos_w", "image_encoder.blocks.11.attn.qkv.weight", "image_encoder.blocks.11.attn.qkv.bias", "image_encoder.blocks.11.attn.proj.weight", "image_encoder.blocks.11.attn.proj.bias", "image_encoder.blocks.11.norm2.weight", "image_encoder.blocks.11.norm2.bias", "image_encoder.blocks.11.mlp.lin1.weight", "image_encoder.blocks.11.mlp.lin1.bias", "image_encoder.blocks.11.mlp.lin2.weight", "image_encoder.blocks.11.mlp.lin2.bias".
    size mismatch for image_encoder.blocks.2.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
    size mismatch for image_encoder.blocks.2.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
    size mismatch for image_encoder.blocks.2.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
    size mismatch for image_encoder.blocks.3.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
    size mismatch for image_encoder.blocks.3.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
    size mismatch for image_encoder.blocks.3.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
    size mismatch for image_encoder.blocks.4.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
    size mismatch for image_encoder.blocks.4.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
    size mismatch for image_encoder.blocks.4.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
    size mismatch for image_encoder.blocks.5.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
    size mismatch for image_encoder.blocks.5.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
    size mismatch for image_encoder.blocks.5.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
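One way to see what the checkpoint actually contains, before changing any loading code, is to compare its keys and tensor shapes against the model's own state_dict. The sketch below works on any mapping from names to objects with a .shape attribute (such as PyTorch tensors); the tuple of wrapper keys it tries to unwrap is an assumption about how the checkpoint was saved, and the function names are hypothetical. In the error above, for example, the missing keys cover all of blocks 6 to 11 and the size mismatches stop at block 5, which suggests the checkpoint holds a shallower image encoder than the one the script instantiates.

```python
import re
from collections.abc import Mapping
from typing import Any, Dict, List

# Assumption: training scripts often save {"model_state_dict": ...} or similar wrappers.
WRAPPER_KEYS = ("model_state_dict", "state_dict", "model")


def unwrap_checkpoint(ckpt: Mapping) -> Mapping:
    """Return the inner state dict if the checkpoint wraps it under a known key."""
    for key in WRAPPER_KEYS:
        inner = ckpt.get(key)
        if isinstance(inner, Mapping):
            return inner
    return ckpt


def diff_state_dicts(ckpt_sd: Mapping, model_sd: Mapping) -> Dict[str, List[Any]]:
    """Report the same three problem categories that load_state_dict complains about."""
    ckpt_keys, model_keys = set(ckpt_sd), set(model_sd)
    mismatched = []
    for key in sorted(ckpt_keys & model_keys):
        c_shape, m_shape = tuple(ckpt_sd[key].shape), tuple(model_sd[key].shape)
        if c_shape != m_shape:
            mismatched.append((key, c_shape, m_shape))
    return {
        "missing": sorted(model_keys - ckpt_keys),     # in model, absent from checkpoint
        "unexpected": sorted(ckpt_keys - model_keys),  # in checkpoint, unknown to model
        "mismatched": mismatched,                      # (key, checkpoint shape, model shape)
    }


def encoder_block_ids(sd: Mapping) -> List[int]:
    """List the image_encoder block indices that have at least one parameter."""
    ids = set()
    for key in sd:
        m = re.match(r"image_encoder\.blocks\.(\d+)\.", key)
        if m:
            ids.add(int(m.group(1)))
    return sorted(ids)
```

Typical use would be to load the file with torch.load(path, map_location="cpu"), pass unwrap_checkpoint() of the result together with model.state_dict() to diff_state_dicts, and print encoder_block_ids for both sides. Passing strict=False to load_state_dict would silence the error, but the unmatched blocks would keep their random initialization, so it is not a fix for inference.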
With the original Segment Anything, all I have to do to segment a single image is download the checkpoint and run a few lines of code.
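For comparison, the point-prompt workflow documented for the original Segment Anything package (segment_anything) looks roughly like the following. The wrapper function is hypothetical; the checkpoint path, model type, and click are placeholders you would supply.

```python
def segment_2d_with_sam(image_rgb, checkpoint_path, point_xy, model_type="vit_b"):
    """Hypothetical wrapper around the original Segment Anything predictor API."""
    # Imported lazily so this file can be loaded without segment_anything installed.
    import numpy as np
    from segment_anything import SamPredictor, sam_model_registry

    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
    predictor = SamPredictor(sam)
    predictor.set_image(image_rgb)  # HxWx3 uint8 RGB array; embedded once here
    masks, scores, _ = predictor.predict(
        point_coords=np.array([point_xy]),  # one (x, y) click
        point_labels=np.array([1]),         # 1 = foreground, 0 = background
        multimask_output=True,
    )
    return masks[int(np.argmax(scores))]  # keep the highest-scoring candidate mask
```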
How can I do something similar with FastSAM3D? I have downloaded the FastSAM3D checkpoint and have a 3D image with prompts ready. How do I use FastSAM3D to segment my image with these prompts?
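A 3D analogue of that flow can be sketched by calling the model's three components directly, the way the original SAM predictor does internally. This is a sketch under several assumptions, not FastSAM3D's documented API: that the loaded Sam3D object exposes image_encoder (the key names in the error above show this attribute exists), plus prompt_encoder and mask_decoder accepting the same keyword arguments as the original SAM classes; that its weights were loaded from a checkpoint matching its encoder; that volume is a tensor already preprocessed the way the repository's data loader does it, shaped (1, 1, D, H, W); and that click coordinates follow the tensor's own axis order (the original 2D SAM uses (x, y) instead). Both function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def format_point_prompts(
    points: Sequence[Sequence[float]], labels: Sequence[int]
) -> Tuple[List[List[List[float]]], List[List[int]]]:
    """Validate 3D clicks and wrap them as a batch of size 1: (1, N, 3) and (1, N)."""
    if len(points) != len(labels):
        raise ValueError("need exactly one label per point")
    if any(len(p) != 3 for p in points):
        raise ValueError("each point needs three voxel coordinates")
    if any(lab not in (0, 1) for lab in labels):
        raise ValueError("labels must be 1 (foreground) or 0 (background)")
    return [[[float(c) for c in p] for p in points]], [[int(lab) for lab in labels]]


def segment_volume(model, volume, points, labels):
    """Hypothetical SAM-style inference: image encoder, prompt encoder, mask decoder."""
    import torch
    import torch.nn.functional as F

    coords_list, labels_list = format_point_prompts(points, labels)
    coords = torch.tensor(coords_list, dtype=torch.float32, device=volume.device)
    point_labels = torch.tensor(labels_list, dtype=torch.int64, device=volume.device)

    model.eval()  # evaluation mode (affects layers such as dropout); nothing is trained
    with torch.no_grad():
        image_embeddings = model.image_encoder(volume)
        sparse, dense = model.prompt_encoder(
            points=(coords, point_labels), boxes=None, masks=None
        )
        low_res_logits, iou_pred = model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )
        # Upsample the low-resolution logits to the input grid, then threshold at 0.
        logits = F.interpolate(
            low_res_logits, size=tuple(volume.shape[-3:]), mode="trilinear", align_corners=False
        )
    return logits > 0, iou_pred
```

The repository may post-process masks differently (for example, mapping them back to the original image spacing), so treat the final resize and threshold as placeholders.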