med-air / 3DSAM-adapter

Holistic Adaptation of SAM from 2D to 3D for Promptable Medical Image Segmentation

Problems with grid_sample #17

Open wefwefWEF2 opened 1 year ago

wefwefWEF2 commented 1 year ago

Thanks a lot for your work.

point_embedding = F.grid_sample(image_embedding, point_coord, align_corners=False).squeeze(2).squeeze(2)

I noticed that if the batch sizes of image_embedding (torch.Size([4, 512, 1, 20, 20])) and point_coord (torch.Size([1, 1, 1, 40, 3])) differ, grid_sample does not work.
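For reference, F.grid_sample expects the input and the sampling grid to share the same batch dimension: for 5D tensors the shapes are (N, C, D, H, W) and (N, D_out, H_out, W_out, 3). A minimal reproduction using the shapes from this issue (random values):

import torch
import torch.nn.functional as F

image_embedding = torch.randn(4, 512, 1, 20, 20)   # (N=4, C, D, H, W)
point_coord = torch.rand(1, 1, 1, 40, 3) * 2 - 1   # (N=1, D_out, H_out, W_out, 3), coords in [-1, 1]

# The batch sizes differ (4 vs 1), so this call raises a RuntimeError.
try:
    F.grid_sample(image_embedding, point_coord, align_corners=False)
except RuntimeError as e:
    print(e)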

I added a copy operation so that the batch size of point_embeddings becomes 4. Is that correct?

if image_embedding.shape[0] != point_coord.shape[0]:
    b, c, d, h, w = image_embedding.shape
    # Repeat the prompt coordinates so their batch size matches image_embedding.
    point_coord = torch.repeat_interleave(point_coord, image_embedding.shape[0], dim=0)

And in the forward of the transformer, I reshape point_embed back to batch size 1 so that it can be concatenated with global_query:

b, n, c = point_embed.shape
point_embed = point_embed.reshape(1, -1, c)
# Self attention block
q = torch.cat([self.global_query, point_embed], dim=1)
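For context, a shape-only sketch of this workaround; the 512-channel width and the 10 global queries are placeholders rather than values taken from the repository:

import torch

point_embed = torch.randn(4, 40, 512)               # (b, n, c) after grid_sample + repeat
global_query = torch.randn(1, 10, 512)               # hypothetical learned query tokens

b, n, c = point_embed.shape
point_embed = point_embed.reshape(1, -1, c)           # (1, 160, 512): batches folded into one sequence
q = torch.cat([global_query, point_embed], dim=1)     # (1, 170, 512)
print(q.shape)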

peterant330 commented 1 year ago

point_coord = torch.repeat_interleave(point_coord, image_embedding.shape[0], dim=0)

Hi,

I would suggest doing the following:

b, c, d, h, w = image_embedding.shape
point_coord = point_coord.repeat(b, 1, 1, 1, 1)

b, n, c = point_embed.shape
q = torch.cat([self.global_query.repeat(b, 1, 1), point_embed], dim=1)
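Putting the two suggested changes together, a minimal end-to-end sketch of the resulting shapes; the permute to a (b, n, c) layout and the 10 global queries are assumptions for illustration, not code from the repository:

import torch
import torch.nn.functional as F

b, c, d, h, w = 4, 512, 1, 20, 20
image_embedding = torch.randn(b, c, d, h, w)
point_coord = torch.rand(1, 1, 1, 40, 3) * 2 - 1     # normalized coords in [-1, 1]

# Repeat the prompt coordinates along the batch dimension.
point_coord = point_coord.repeat(b, 1, 1, 1, 1)       # (4, 1, 1, 40, 3)

point_embed = F.grid_sample(image_embedding, point_coord,
                            align_corners=False).squeeze(2).squeeze(2)   # (4, 512, 40)
point_embed = point_embed.permute(0, 2, 1)            # (4, 40, 512), assumed (b, n, c) layout

global_query = torch.randn(1, 10, c)                  # hypothetical learned query; 10 is a placeholder

b, n, c = point_embed.shape
q = torch.cat([global_query.repeat(b, 1, 1), point_embed], dim=1)         # (4, 50, 512)
print(q.shape)

This keeps one query sequence per batch element instead of folding all batches into a single sequence, which is what the suggested global_query.repeat(b, 1, 1) achieves.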