wefwefWEF2 opened this issue 1 year ago
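Hi,

I would suggest doing the following. When sampling the points, tile point_coord to the batch size of image_embedding:

    b, c, d, h, w = image_embedding.shape
    point_coord = point_coord.repeat(b, 1, 1, 1, 1)

In the transformer's forward, keep the batch dimension of point_embed and tile global_query instead:

    b, n, c = point_embed.shape
    q = torch.cat([self.global_query.repeat(b, 1, 1), point_embed], dim=1)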
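A minimal sketch of the two suggestions above, chained end to end with the shapes from this issue. The function names sample_point_embeddings and build_query, the 8 global tokens, and the transpose of the sampled features to (b, n, c) are illustrative assumptions; the thread does not show how point_embedding becomes point_embed.

```python
import torch
import torch.nn.functional as F


def sample_point_embeddings(image_embedding, point_coord):
    """Sample one feature vector per point from a 5-D feature volume.

    image_embedding: (b, c, d, h, w) feature volume
    point_coord:     (1, 1, 1, n, 3) shared grid, (x, y, z) normalized to [-1, 1]
    returns:         (b, n, c)
    """
    b, c, d, h, w = image_embedding.shape
    # Tile the shared grid so its batch size matches the image batch.
    point_coord = point_coord.repeat(b, 1, 1, 1, 1)                        # (b, 1, 1, n, 3)
    out = F.grid_sample(image_embedding, point_coord, align_corners=False)  # (b, c, 1, 1, n)
    out = out.squeeze(2).squeeze(2)                                         # (b, c, n)
    # Assumption: the transformer consumes tokens as (b, n, c).
    return out.transpose(1, 2)                                              # (b, n, c)


def build_query(global_query, point_embed):
    """Prepend the shared global tokens to each sample's point tokens.

    global_query: (1, g, c), stand-in for self.global_query
    point_embed:  (b, n, c)
    returns:      (b, g + n, c)
    """
    b, n, c = point_embed.shape
    return torch.cat([global_query.repeat(b, 1, 1), point_embed], dim=1)


image_embedding = torch.randn(4, 512, 1, 20, 20)
point_coord = torch.rand(1, 1, 1, 40, 3) * 2 - 1
global_query = torch.randn(1, 8, 512)  # hypothetical: 8 global tokens

point_embed = sample_point_embeddings(image_embedding, point_coord)
q = build_query(global_query, point_embed)
print(point_embed.shape)  # torch.Size([4, 40, 512])
print(q.shape)            # torch.Size([4, 48, 512])
```

Every sample keeps its own sequence of 8 + 40 tokens, and the global tokens are identical across the batch.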
Thanks a lot for your work.
    point_embedding = F.grid_sample(image_embedding, point_coord, align_corners=False).squeeze(2).squeeze(2)
I noticed that when the batch sizes of image_embedding (torch.Size([4, 512, 1, 20, 20])) and point_coord (torch.Size([1, 1, 1, 40, 3])) differ, grid_sample fails, since it requires the input and the grid to have the same batch size.
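For a 5-D input, F.grid_sample takes input of shape (N, C, D_in, H_in, W_in) and a grid of shape (N, D_out, H_out, W_out, 3), both with the same N, and returns (N, C, D_out, H_out, W_out). A quick check with the shapes above (random values; coordinates drawn in [-1, 1], the normalized range grid_sample uses):

```python
import torch
import torch.nn.functional as F

image_embedding = torch.randn(4, 512, 1, 20, 20)  # input: (N=4, C, D_in, H_in, W_in)
point_coord = torch.rand(1, 1, 1, 40, 3) * 2 - 1  # grid:  (N=1, D_out, H_out, W_out, 3)

# Batch sizes 4 and 1 differ, so grid_sample rejects the call.
try:
    F.grid_sample(image_embedding, point_coord, align_corners=False)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)
print(mismatch_error is not None)  # True

# With a grid whose batch matches the input, the output is (N, C, D_out, H_out, W_out).
matched = F.grid_sample(image_embedding, point_coord.repeat(4, 1, 1, 1, 1), align_corners=False)
print(matched.shape)                        # torch.Size([4, 512, 1, 1, 40])
print(matched.squeeze(2).squeeze(2).shape)  # torch.Size([4, 512, 40])
```

This also explains the two squeeze(2) calls in the sampling line: they drop the singleton D_out and H_out dimensions, leaving (N, C, 40).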
I worked around this by copying point_coord along the batch dimension, so that point_embedding has a batch size of 4. Is that correct?

    if image_embedding.shape[0] != point_coord.shape[0]:
        b, c, d, h, w = image_embedding.shape
        point_coord = torch.repeat_interleave(point_coord, b, dim=0)
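For a batch-1 point_coord like the one in this issue, torch.repeat_interleave (from the question) and Tensor.repeat (from the reply) produce the same tensor, and Tensor.expand gives the same values as a view without allocating copies. The two copying ops only diverge when the tensor already holds more than one sample, because they order the copies differently. A small sketch with random coordinates and a toy two-row tensor:

```python
import torch

b = 4
point_coord = torch.rand(1, 1, 1, 40, 3) * 2 - 1  # shared batch-1 grid

via_interleave = torch.repeat_interleave(point_coord, b, dim=0)  # copy
via_repeat = point_coord.repeat(b, 1, 1, 1, 1)                   # copy
via_expand = point_coord.expand(b, -1, -1, -1, -1)               # view, no copy

print(via_interleave.shape)                     # torch.Size([4, 1, 1, 40, 3])
print(torch.equal(via_interleave, via_repeat))  # True
print(torch.equal(via_interleave, via_expand))  # True

# With more than one sample, the copying ops order the results differently.
two = torch.tensor([[0], [1]])
interleaved = torch.repeat_interleave(two, 2, dim=0).flatten().tolist()
tiled = two.repeat(2, 1).flatten().tolist()
print(interleaved)  # [0, 0, 1, 1]
print(tiled)        # [0, 1, 0, 1]
```

Note that expand only works on dimensions of size 1, which matches the shared-grid case here.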
In the transformer's forward, I also reshape point_embed to a batch size of 1 so that it can be concatenated with global_query:

    b, n, c = point_embed.shape
    point_embed = point_embed.reshape(1, -1, c)
    # Self attention block
    q = torch.cat([self.global_query, point_embed], dim=1)
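The reshape(1, -1, c) makes the shapes compatible, but it also changes what is computed: all b samples become one long sequence, so in self-attention the point tokens of one image can attend to those of another. Repeating global_query with repeat(b, 1, 1) instead keeps b separate sequences. A small sketch, assuming the self-attention block behaves like a standard batch-first nn.MultiheadAttention (the actual block is not shown in this thread); the sizes are toy values and global_query is a stand-in for self.global_query:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, n, c, g = 2, 5, 16, 3  # batch, points per sample, channels, global tokens (toy sizes)
attn = nn.MultiheadAttention(embed_dim=c, num_heads=4, batch_first=True)  # stand-in block
global_query = torch.randn(1, g, c)  # stand-in for self.global_query
point_embed = torch.randn(b, n, c)


def merged(points):
    # Layout from the question: fold the batch into one sequence of b * n tokens.
    x = torch.cat([global_query, points.reshape(1, -1, c)], dim=1)  # (1, g + b*n, c)
    return attn(x, x, x)[0]


def batched(points):
    # Layout from the reply: each sample gets its own copy of the global tokens.
    x = torch.cat([global_query.repeat(points.shape[0], 1, 1), points], dim=1)  # (b, g + n, c)
    return attn(x, x, x)[0]


# Change only sample 1, then check whether sample 0's outputs move.
perturbed = point_embed.clone()
perturbed[1] += 1.0

with torch.no_grad():
    # In the merged layout, sample 0's point tokens sit at positions g .. g + n - 1.
    merged_changed = not torch.allclose(merged(point_embed)[0, g:g + n],
                                        merged(perturbed)[0, g:g + n])
    batched_changed = not torch.allclose(batched(point_embed)[0],
                                         batched(perturbed)[0])

print(merged_changed)   # True: sample 1 leaks into sample 0
print(batched_changed)  # False: samples stay independent
```

The global tokens in the merged layout are also shared by every image in the batch within a single attention pass, whereas the repeated layout gives each image its own copy.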