rmokady opened this issue 1 year ago
After running a lot of debugging, I think the repo currently doesn't allow inference on multiple (batched) images at once; it only allows inference on one image with multiple prompts at once. The repeat_interleave must be there to propagate different prompts (as tokens) through one image (as image_embeddings).
For the code section above: if you run multiple images with no prompt, the code actually runs fine. But if I feed a batch of prompts together with a batch of images, things break.
In your case, do you feed data into the mask_input? If not, the sparse_embedding will have a first dim of 1, which works fine; otherwise it will break.
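To make the shape problem concrete, here is a minimal, self-contained sketch (all shapes are illustrative, not pulled from the repo) of why repeat_interleave works for one image with several prompt sets but breaks for a true image batch:

```python
import torch

# Case 1: one image, three prompt sets -- the situation the decoder was written for.
image_embeddings = torch.randn(1, 256, 64, 64)          # one image embedding
dense_prompt_embeddings = torch.randn(3, 256, 64, 64)   # one dense embedding per prompt set
tokens = torch.randn(3, 7, 256)                         # (num_prompt_sets, num_tokens, dim), illustrative

src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)  # (3, 256, 64, 64)
src = src + dense_prompt_embeddings                                      # shapes line up, OK

# Case 2: two images, one prompt set each -- repeat_interleave inflates dim 0 to
# 2 * 2 = 4 while dense_prompt_embeddings stays at 2, which is the kind of
# "(2) must match (4) at dimension 0" mismatch reported in this thread.
image_embeddings = torch.randn(2, 256, 64, 64)
dense_prompt_embeddings = torch.randn(2, 256, 64, 64)
tokens = torch.randn(2, 7, 256)

src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)  # (4, 256, 64, 64)
try:
    src = src + dense_prompt_embeddings
except RuntimeError as e:
    print(e)  # prints the size-mismatch error at non-singleton dimension 0
```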
Hi @rmokady, @nhhung1810,
I was wondering if either of you has found a solution or workaround for this issue. I am experiencing the same issue with batch size 2.
My sparse embedding size is [2, 2, 156] and my dense embedding size is [2, 256, 64, 64].
And I get the same error traceback:
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0
I'm using bbox coords as the prompts.
So my suggestion is not to use batched images as input, since the way they designed it prevents that. You can use a loop instead, @apoorvjain56 (a rough sketch below).
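For what it's worth, a minimal per-image loop with the stock SamPredictor looks roughly like this; the checkpoint path, dummy images, and boxes are placeholders for your own data:

```python
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Checkpoint path and model type are placeholders -- adapt to your setup.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

images = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(2)]        # dummy RGB images
boxes_per_image = [np.array([10, 10, 200, 200]), np.array([50, 80, 300, 400])]  # one box each

all_masks = []
for image, box in zip(images, boxes_per_image):
    predictor.set_image(image)                         # image encoder runs once per image
    masks, scores, _ = predictor.predict(box=box, multimask_output=False)
    all_masks.append(masks)
```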
What if one does want to use batching? This seems like a bug.
Hi @tommiekerssies, it's just my personal experience, but I think it depends on what you want to "batch", i.e. batching multiple images at once versus batching multiple prompts for a single image. It seems like the Meta team supports the latter rather than the former (see the sketch below). If you want to do batching with images, I think you can either loop over them or modify the mask decoder as discussed below.
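For the prompt-batching direction that the repo does support, the predictor already exposes a batched interface. A rough sketch with placeholder values (the checkpoint path and zero image are stand-ins):

```python
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # placeholder path
predictor = SamPredictor(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real RGB image
predictor.set_image(image)

# Several box prompts for the *same* image, batched along dim 0.
boxes = torch.tensor([[10, 10, 200, 200], [50, 80, 300, 400]],
                     dtype=torch.float, device=predictor.device)
boxes = predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
masks, scores, low_res = predictor.predict_torch(
    point_coords=None, point_labels=None, boxes=boxes, multimask_output=False
)
print(masks.shape)  # (num_boxes, 1, H, W)
```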
Thanks for your response! I made batching work by simply removing the repeat_interleave operations. Please note that I am using only one prompt per image.
Just curious, can you share your training task :D
> Thanks for your response! I made batching work by simply removing the repeat_interleave operations. Please note that I am using only one prompt per image.

Same here: simply setting src = image_embeddings works for me. I pass a batch of images to the image encoder using predictor.set_torch() and prompt the model with a batch of points and labels using predictor.predict_torch().
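A rough sketch of the point-prompt side of such a workflow, using the predictor's batched interface on a single image (checkpoint path, image, and point coordinates are placeholders):

```python
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # placeholder path
predictor = SamPredictor(sam)
image = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real RGB image
predictor.set_image(image)

# A batch of two point prompts, one point each: coords (B, N, 2), labels (B, N).
coords = torch.tensor([[[320., 240.]], [[100., 50.]]], device=predictor.device)
labels = torch.ones((2, 1), dtype=torch.int, device=predictor.device)
coords = predictor.transform.apply_coords_torch(coords, image.shape[:2])
masks, scores, _ = predictor.predict_torch(
    point_coords=coords, point_labels=labels, multimask_output=False
)
```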
Hi guys. I did what you wrote about here. It worked at first, but now that I've set up some real training with more data, I get a memory error after about 1000 iterations. After investigating the leak, I'm leaning toward the idea that this is the problem. Don't you have memory leaks?
Nope I don't have any @Bard2803
@nhhung1810 thank you for the answer.
Is this how the predict_masks method looks for you after the modification?
```python
def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predicts masks. See 'forward' for more details."""
    # Concatenate output tokens
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

    # Expand per-image data in batch direction to be per-mask
    # src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    # src = src + dense_prompt_embeddings
    src = image_embeddings
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape
    # print("!!!!src_embedding shape", src.shape)

    # Run the transformer
    hs, src = self.transformer(src, pos_src, tokens)
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src)
    hyper_in_list: List[torch.Tensor] = []
    for i in range(self.num_mask_tokens):
        hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    hyper_in = torch.stack(hyper_in_list, dim=1)
    b, c, h, w = upscaled_embedding.shape
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out)

    return masks, iou_pred
```
@Bard2803 it would be easier to debug if you simply explained your task and the input dimensions, though. The way I modify it depends heavily on the use case.
I found out that the memory leak had a different cause. I can confirm now that this approach works. Sorry for the mess :o
While trying to fine-tune the model, I found that simply setting src = image_embeddings yields a very poor training loss.
This alternative from the MedSAM authors works better:
```python
if image_embeddings.shape[0] != tokens.shape[0]:
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
    src = image_embeddings
src = src + dense_prompt_embeddings
```
After some digging, I think the reason behind this is that, during training, they initially get a batch of image embeddings with shape (N, 256, 64, 64), iterate over each one (1, 256, 64, 64), and then generate different sets of prompts (number of sets = tokens.shape[0]) for the current image embedding. So in the mask decoder, the image embedding needs to be repeated tokens.shape[0] times.
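To tie the pieces together, here is a rough sketch of how a batched forward pass might be driven once the decoder is patched as in the snippets above. The checkpoint path and box values are placeholders, torch.randn stands in for preprocessed images, and the boxes are assumed to already be in the model's 1024x1024 input coordinate frame:

```python
import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # placeholder path

# A batch of two "preprocessed" images; randn is only a stand-in for real images
# resized/padded to 1024x1024 and normalized.
batched_images = torch.randn(2, 3, 1024, 1024)

with torch.no_grad():  # e.g. keep the image encoder frozen during fine-tuning
    image_embeddings = sam.image_encoder(batched_images)        # (2, 256, 64, 64)

# One box prompt per image, assumed to already be in the 1024x1024 input frame.
boxes = torch.tensor([[100., 100., 400., 400.],
                      [200., 150., 600., 500.]])
sparse_emb, dense_emb = sam.prompt_encoder(points=None, boxes=boxes, masks=None)
# sparse_emb: (2, 2, 256), dense_emb: (2, 256, 64, 64) -- the batch dims already
# match image_embeddings, so the patched decoder needs no repeat_interleave.

low_res_masks, iou_pred = sam.mask_decoder(
    image_embeddings=image_embeddings,
    image_pe=sam.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_emb,
    dense_prompt_embeddings=dense_emb,
    multimask_output=False,
)
print(low_res_masks.shape, iou_pred.shape)  # (2, 1, 256, 256), (2, 1)
```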
I get this error when the batch size is larger than 1:
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0
My sparse embedding size is [2, 0, 156] (empty, with batch size 2) and my dense embedding size is [2, 256, 64, 64] (batch size 2).
So the repeat_interleave extends the image_embeddings to 4, which is actually larger than the batch size.
Am I missing something, or is the repeat_interleave redundant?