Elsword016 closed this issue 1 year ago.
Did you check that you're using the same model type for computing the embeddings and for training?
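For reference, a minimal sanity check (the checkpoint path below is just an assumed location, not from this thread): both the embedding pre-computation script and the training script must build the model with the same registry key.

```python
from segment_anything import sam_model_registry

# Assumed path to the downloaded ViT-B weights; adjust to your setup.
model_type = "vit_b"  # must be identical in the embedding script and the training script
sam = sam_model_registry[model_type](checkpoint="work_dir/SAM/sam_vit_b_01ec64.pth")
```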
This is the print I am getting:

```
self.img_embeddings.shape=(406, 256, 64, 64), self.ori_gts.shape=(406, 256, 256)
img_embed.shape=torch.Size([8, 256, 64, 64]), gt2D.shape=torch.Size([8, 1, 256, 256]), bboxes.shape=torch.Size([8, 4])
```
406 on mine and 456 on yours. Check if both model types are vit_b.
Yeah, I am using the vit_b checkpoint. Also, 456 is the number of images in the training folder, I guess.
This happened to me too. Make sure you're using the MedSAM version of segment_anything and not the MetaAI version. As of writing this, MetaAI's version of SAM doesn't allow for batch training, but MedSAM's does.
But MedSAM is using the same model checkpoint directly from MetaAI
Yes, the model checkpoint is the same, but the code for the mask decoder in MedSAM is different. The error "The size of tensor a (4096) must match the size of tensor b (64) at non-singleton dimension 0" is thrown by the following code in MetaAI's SAM when you attempt batch training:
```python
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
```
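A minimal sketch of why that expansion breaks with batched inputs (shapes taken from the training print above; which exact sizes end up in the error message depends on where downstream the mismatch first surfaces):

```python
import torch

# Shapes from the training print: a batch of 8 images with one box prompt each.
image_embeddings = torch.randn(8, 256, 64, 64)
dense_prompt_embeddings = torch.randn(8, 256, 64, 64)
tokens = torch.randn(8, 7, 256)  # IoU/mask output tokens + sparse prompt tokens

# MetaAI's decoder assumes one image embedding shared by all prompts, so it
# repeats the batch dimension: 8 embeddings * 8 token rows = batch of 64.
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
print(src.shape)  # torch.Size([64, 256, 64, 64])

# The dense prompt embeddings still have batch size 8, so the addition
# (and everything after it) no longer lines up.
try:
    src = src + dense_prompt_embeddings
except RuntimeError as e:
    print(e)  # "The size of tensor a (64) must match the size of tensor b (8) ..."
```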
By comparison, this is the same bit of code in MedSAM, which has been changed to handle this case:
```python
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

# Expand per-image data in batch direction to be per-mask
if image_embeddings.shape[0] != tokens.shape[0]:
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
    src = image_embeddings
src = src + dense_prompt_embeddings
```
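Continuing the sketch above: with per-image batched embeddings the guard skips the expansion, so everything stays at batch size 8.

```python
# Batched case: one image embedding per prompt, so no expansion is needed.
if image_embeddings.shape[0] != tokens.shape[0]:
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
    src = image_embeddings
print((src + dense_prompt_embeddings).shape)  # torch.Size([8, 256, 64, 64])
```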
This problem was also addressed here.
Make sure you're importing the segment_anything package from the MedSAM repo.
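A quick way to check which copy Python actually picks up:

```python
import segment_anything

# A site-packages path means the pip-installed MetaAI package is shadowing
# the MedSAM repo; a path inside your MedSAM clone is what you want.
print(segment_anything.__file__)
```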
I'm trying out the code with a 2D dataset, as suggested in the documentation. But I'm getting a runtime error: "The size of tensor a (4096) must match the size of tensor b (64) at non-singleton dimension 0".

```
self.img_embeddings.shape=(456, 256, 64, 64), self.ori_gts.shape=(456, 256, 256)
img_embed.shape=torch.Size([8, 256, 64, 64]), gt2D.shape=torch.Size([8, 1, 256, 256]), bboxes.shape=torch.Size([8, 4])
```
The tensor shapes seem okay.