facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

RuntimeError when using batch size > 1 #277

Open rmokady opened 1 year ago

rmokady commented 1 year ago

I get this error when the batch size is larger than 1:

RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0

My sparse embedding size is [2, 0, 156] (empty, with batch size 2).
My dense embedding size is [2, 256, 64, 64] (batch size 2).

output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings

So the repeat_interleave expands image_embeddings to 4, which is larger than the batch size.

Am I missing something, or is the repeat_interleave redundant?
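
For reference, a minimal shape-only sketch of the mismatch (shapes follow the report above, with SAM's usual 256-dim embeddings; this is an illustration, not the repo's code):

# With a batch of 2 images and 2 prompt/token sets, repeat_interleave blows the
# image embeddings up to 4 copies, which no longer matches the batch-2 dense
# prompt embeddings, hence the size-mismatch RuntimeError.
import torch

image_embeddings = torch.zeros(2, 256, 64, 64)          # batch of 2 images
dense_prompt_embeddings = torch.zeros(2, 256, 64, 64)   # batch of 2 dense prompts
tokens = torch.zeros(2, 5, 256)                         # 2 token sets (output tokens only, empty sparse prompts)

src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
print(src.shape)                     # torch.Size([4, 256, 64, 64])
src = src + dense_prompt_embeddings  # RuntimeError: 4 vs 2 at dimension 0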

nhhung1810 commented 1 year ago

After a lot of debugging, I think the repo currently doesn't support running inference on multiple (batched) images at once; it only supports one image with multiple prompts at once. The repeat_interleave must be there to propagate the different prompts (as tokens) through the single image (as image_embeddings).

For the code section above: if you run multiple images with no prompts, the code will actually run fine. But if I feed a batch of prompts together with a batch of images, things break.

In your case, do you feed data into mask_input? If not, the sparse_embedding will have a first dim of 1, which works fine; otherwise it will break.
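
For comparison, a shape-only sketch of the usage the decoder appears designed for, one image with several prompt sets, where the repeat_interleave is exactly what makes the shapes line up (illustrative shapes only):

import torch

image_embeddings = torch.zeros(1, 256, 64, 64)          # a single image embedding
dense_prompt_embeddings = torch.zeros(3, 256, 64, 64)   # 3 prompt sets (e.g. 3 boxes)
tokens = torch.zeros(3, 7, 256)                         # 3 token sets (output + box tokens)

# Copy the single image embedding once per prompt set
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
print(src.shape)                     # torch.Size([3, 256, 64, 64]) -- matches the dense prompts
src = src + dense_prompt_embeddings  # works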

apoorvjain56 commented 1 year ago

Hi @rmokady, @nhhung1810 ,

I was wondering if either of you has found a solution or workaround for this issue. I am experiencing the same problem with batch size 2.

My sparse embedding size is [2, 2, 156].
My dense embedding size is [2, 256, 64, 64].

And the same error traceback:

RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0

I'm using bbox coords as the prompts.

nhhung1810 commented 1 year ago

So my suggestion is not to use batched images as input, since the way they designed it prevents that. You can use a loop instead, @apoorvjain56.
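
A minimal sketch of that loop-based workaround with the unmodified repo, using SamPredictor's set_image/predict calls; the checkpoint path, images, and boxes below are placeholders:

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Placeholder model/checkpoint and data -- adjust to your setup
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

images = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(4)]  # HxWx3 uint8 images
boxes = [np.array([10, 10, 200, 200]) for _ in range(4)]              # one box prompt per image

all_masks = []
for image, box in zip(images, boxes):
    predictor.set_image(image)  # encode one image at a time
    masks, scores, _ = predictor.predict(box=box, multimask_output=False)
    all_masks.append(masks)

If each image has several prompts, you can still batch over prompts within the per-image call via predictor.predict_torch, which is the kind of batching the repo does support.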

tommiekerssies commented 1 year ago

What if one does want to use batching? This seems like a bug.

nhhung1810 commented 1 year ago

Hi @tommiekerssies, it's just my personal experience, but I think it depends on what you want to batch, i.e. a batch of images at once, versus a batch of prompts for a single image.

And it seems like the Meta team supports the latter rather than the former. If you want to do batching with images, I think you can loop over them, as I suggested above.

tommiekerssies commented 1 year ago

Thanks for your response! I made batching work by simply removing the repeat_interleave operations. Note that I am using only one prompt per image.

nhhung1810 commented 1 year ago

Just curious, can you share your training task? :D

SalmaG98 commented 1 year ago

Same here: simply setting src = image_embeddings works for me. I pass a batch of images to the image encoder using predictor.set_torch_image() and prompt the model with a batch of points and labels using predictor.predict_torch().
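
For anyone following along, here is a hedged sketch of the batched call path described in this comment. It assumes the mask decoder has already been modified as discussed above (repeat_interleave removed or made conditional) and uses one box prompt per image; the checkpoint path, image sizes, and box coordinates are placeholders.

import torch
from segment_anything import sam_model_registry, SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(device)
predictor = SamPredictor(sam)

# Batch of 2 placeholder images (BxCxHxW, float for apply_image_torch) and one box per image
images = torch.zeros(2, 3, 480, 640, device=device)
boxes = torch.tensor([[10., 10., 200., 200.], [50., 50., 300., 300.]], device=device)

transformed = predictor.transform.apply_image_torch(images)
predictor.set_torch_image(transformed, original_image_size=(480, 640))

box_torch = predictor.transform.apply_boxes_torch(boxes, (480, 640))
masks, iou_preds, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=box_torch,
    multimask_output=False,
)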

DraBard commented 1 year ago

Hi guys. I did what you described here. It worked at first, but now that I've set up some real training with more data, I get a memory error after about 1000 iterations. After investigating the leak, I'm leaning toward the idea that this change is the problem. Don't you have memory leaks?

nhhung1810 commented 1 year ago

Nope, I don't have any, @Bard2803.

DraBard commented 1 year ago

@nhhung1810 thank you for the answer.

Is this how the predict_masks method looks for you after the modification?

def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predicts masks. See 'forward' for more details."""
    # Concatenate output tokens
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

    # Expand per-image data in batch direction to be per-mask
    # src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    # src = src + dense_prompt_embeddings

    src = image_embeddings
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # print("!!!!src_embedding shape", src.shape)

    # Run the transformer
    hs, src = self.transformer(src, pos_src, tokens)
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src)
    hyper_in_list: List[torch.Tensor] = []
    for i in range(self.num_mask_tokens):
        hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    hyper_in = torch.stack(hyper_in_list, dim=1)
    b, c, h, w = upscaled_embedding.shape
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out)

    return masks, iou_pred

nhhung1810 commented 1 year ago

@Bard2803 it would be easier to debug if you could explain your task and include your input dimensions, though. The way I modified it depends heavily on the use case.

DraBard commented 1 year ago

I found out that the memory leak had a different cause. I can now confirm that this approach works. Sorry for the mess :o

emmanuelCN commented 7 months ago

(quoting DraBard's modified predict_masks above)

While trying to fine-tune the model, I found that simply setting src = image_embeddings yields a very poor training loss. This alternative from the MedSAM authors works better:

# Repeat the image embeddings only when there are more prompt/token sets than images
if image_embeddings.shape[0] != tokens.shape[0]:
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
    src = image_embeddings
src = src + dense_prompt_embeddings

DarrenZhaoFR commented 4 months ago

After some digging, I think the reason behind this is that, during training, they first get a batch of image embeddings with shape (N, 256, 64, 64), then iterate over each one (1, 256, 64, 64) and generate a different set of prompts (number of sets = tokens.shape[0]) for the current image embedding. So in the mask decoder, the image embedding needs to be repeated tokens.shape[0] times.
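
As a shape-only illustration of that training pattern (the number of images and prompt sets here are made up):

import torch

batch_embeddings = torch.zeros(8, 256, 64, 64)       # image encoder output for a batch of 8 images
for image_embedding in batch_embeddings.split(1):    # one (1, 256, 64, 64) embedding at a time
    num_prompt_sets = 4                              # e.g. 4 boxes sampled for this image
    tokens = torch.zeros(num_prompt_sets, 7, 256)    # output tokens + prompt tokens per set
    # The decoder repeats the single image embedding once per prompt set
    src = torch.repeat_interleave(image_embedding, tokens.shape[0], dim=0)
    assert src.shape == (4, 256, 64, 64)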