bowang-lab / MedSAM

Segment Anything in Medical Images
https://www.nature.com/articles/s41467-024-44824-z
Apache License 2.0

Error while training #64

Closed Elsword016 closed 1 year ago

Elsword016 commented 1 year ago

I'm trying out the code with a 2D dataset, as suggested in the documentation, but I'm getting a runtime error: "The size of tensor a (4096) must match the size of tensor b (64) at non-singleton dimension 0".

self.img_embeddings.shape=(456, 256, 64, 64), self.ori_gts.shape=(456, 256, 256)
img_embed.shape=torch.Size([8, 256, 64, 64]), gt2D.shape=torch.Size([8, 1, 256, 256]), bboxes.shape=torch.Size([8, 4])

The tensor shapes seem okay.
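
For reference, here is a minimal sketch with dummy tensors using the per-batch shapes from the printout above; it just makes the expected layout explicit, which suggests the mismatch happens inside the model's forward pass rather than in the data pipeline.

import torch

# Dummy tensors with the per-batch shapes reported above (variable names taken
# from the printout; the values are placeholders).
img_embed = torch.zeros(8, 256, 64, 64)   # precomputed image embeddings
gt2D = torch.zeros(8, 1, 256, 256)        # ground-truth masks
bboxes = torch.zeros(8, 4)                # one box prompt per slice

for name, t in [("img_embed", img_embed), ("gt2D", gt2D), ("bboxes", bboxes)]:
    print(f"{name}.shape={t.shape}")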

FrexG commented 1 year ago

Did you check that you are using the same model type for computing the embeddings and for training? This is the printout I am getting:

self.img_embeddings.shape=(406, 256, 64, 64), self.ori_gts.shape=(406, 256, 256)
img_embed.shape=torch.Size([8, 256, 64, 64]), gt2D.shape=torch.Size([8, 1, 256, 256]), bboxes.shape=torch.Size([8, 4])

406 on mine and 456 on yours; check whether both model types are vit_b!
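
One way to rule out a model-type mismatch is to build the model with an explicit registry key in both scripts; a minimal sketch (the checkpoint path is an assumption):

from segment_anything import sam_model_registry

# Use the same "vit_b" key in the embedding-precomputation script and the
# training script, so the two cannot silently use different model types.
sam_model = sam_model_registry["vit_b"](checkpoint="work_dir/SAM/sam_vit_b_01ec64.pth")

# The ViT variants differ a lot in parameter count, so this is an easy sanity check.
print(sum(p.numel() for p in sam_model.parameters()) / 1e6, "M parameters")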

Elsword016 commented 1 year ago

Yeah, I am using the vit_b checkpoint. Also, 456 is the number of images in the training folder, I guess.

MatousE commented 1 year ago

This happened to me too. Make sure you're using the MedSAM version of segment_anything and not the MetaAI version. As of writing this, MetaAI's version of SAM doesn't allow batched training, but MedSAM's does.

Elsword016 commented 1 year ago

But MedSAM is using the same model checkpoint directly from MetaAI.

MatousE commented 1 year ago

Yes, the model checkpoint is the same, but the code for the mask decoder in MedSAM is different. The error "The size of tensor a (4096) must match the size of tensor b (64) at non-singleton dimension 0" is thrown by the following code in MetaAI's SAM when you attempt batched training:

output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
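
For concreteness, here is a minimal plain-PyTorch sketch of the shape blow-up with the batch size from this thread (the exact error text can differ depending on where the mismatch is first hit):

import torch

B = 8
image_embeddings = torch.zeros(B, 256, 64, 64)         # one embedding per image
dense_prompt_embeddings = torch.zeros(B, 256, 64, 64)  # one dense prompt per image
tokens = torch.zeros(B, 7, 256)                        # one set of prompt tokens per image

# MetaAI's code assumes a single image with B prompts, so it repeats the image
# embeddings B times: (8, 256, 64, 64) -> (64, 256, 64, 64).
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
print(src.shape)  # torch.Size([64, 256, 64, 64])

# Adding the (8, 256, 64, 64) dense embeddings to the (64, 256, 64, 64) tensor
# now raises a size-mismatch RuntimeError.
src = src + dense_prompt_embeddings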

By comparison, this is the same bit of code in MedSAM, which has been changed to deal with this error:

output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

# Expand per-image data in batch direction to be per-mask
if image_embeddings.shape[0] != tokens.shape[0]:
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
    src = image_embeddings
src = src + dense_prompt_embeddings
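
With that guard in place, a batched forward pass through the prompt encoder and mask decoder goes through; a rough sketch following the shapes in this thread (the checkpoint path and variable names are assumptions, and segment_anything must resolve to the MedSAM fork):

import torch
from segment_anything import sam_model_registry  # must resolve to the MedSAM fork

sam_model = sam_model_registry["vit_b"](checkpoint="work_dir/SAM/sam_vit_b_01ec64.pth")

img_embed = torch.zeros(8, 256, 64, 64)  # precomputed image embeddings, one per slice
bboxes = torch.zeros(8, 4)               # one box prompt per slice, in 1024x1024 input coordinates

with torch.no_grad():
    sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
        points=None,
        boxes=bboxes,
        masks=None,
    )

low_res_masks, iou_predictions = sam_model.mask_decoder(
    image_embeddings=img_embed,                        # (8, 256, 64, 64)
    image_pe=sam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
    sparse_prompt_embeddings=sparse_embeddings,        # (8, 2, 256)
    dense_prompt_embeddings=dense_embeddings,          # (8, 256, 64, 64)
    multimask_output=False,
)
print(low_res_masks.shape)  # torch.Size([8, 1, 256, 256])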

This problem was also addressed here.

Make sure you're importing the segment_anything package from the MedSAM repo.
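
A quick way to verify which copy is being picked up (a minimal check):

import segment_anything

# This should point into your local MedSAM clone, not into site-packages where
# a pip-installed MetaAI segment-anything would live.
print(segment_anything.__file__)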