ByungKwanLee / Full-Segment-Anything

This is a PyTorch implementation that adds new features to the Segment-Anything code. The features support batched input for the full-grid prompt (automatic mask generation), with post-processing that removes duplicated or small regions and holes, under flexible input image sizes.
MIT License
140 stars, 9 forks
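The "full-grid prompt" in the description refers to SAM-style automatic mask generation, where a regular grid of point prompts covers the image. The plain-Python sketch below mirrors the point layout of upstream SAM's `build_point_grid` (cell-centred points in normalised [0, 1] coordinates), then scales the points to pixels and splits them into batches. The helper names `scale_to_image` and `batched`, and the batch size of 64 (upstream SAM's default `points_per_batch`), are assumptions for illustration, not Full-Segment-Anything's API.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def build_point_grid(n_per_side: int) -> List[Point]:
    """Return n_per_side**2 (x, y) points at cell centres in [0, 1] x [0, 1]."""
    offset = 1.0 / (2 * n_per_side)
    step = (1.0 - 2 * offset) / (n_per_side - 1) if n_per_side > 1 else 0.0
    side = [offset + i * step for i in range(n_per_side)]
    # Row-major order: y is the outer loop, x varies fastest.
    return [(x, y) for y in side for x in side]


def scale_to_image(points: List[Point], width: int, height: int) -> List[Point]:
    """Map normalised points to pixel coordinates of a width x height image."""
    return [(x * width, y * height) for x, y in points]


def batched(points: List[Point], batch_size: int) -> List[List[Point]]:
    """Split the prompt list into chunks that would be fed to the decoder one at a time."""
    return [points[i:i + batch_size] for i in range(0, len(points), batch_size)]


grid = scale_to_image(build_point_grid(32), width=128, height=128)
chunks = batched(grid, batch_size=64)
print(len(grid), len(chunks), len(chunks[0]))  # 1024 16 64
```

Batching the prompts keeps decoder memory bounded when a dense grid produces hundreds of prompts per image.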

Shape mismatch in the mask decoder #8

Open. saimohan16 opened this issue 3 months ago

saimohan16 commented 3 months ago

Hey,

I passed a random tensor of size (128, 128), along with a few prompt tensors, to see what output shape the model would produce. When I passed two images (random values), I got an error about a tensor with batch size 4, even though my input only had batch size 2.

These are the shapes I printed, for reference:

- Image embedding size: torch.Size([8, 8])
- Image embedding dimension: 256
- Points shape: torch.Size([2, 2, 2]) torch.Size([2, 2])
- Boxes shape: torch.Size([2, 1, 4])
- Masks shape: torch.Size([2, 1, 128, 128])
- Sparse Embeddings shape: torch.Size([2, 4, 256])
- Dense Embeddings shape: torch.Size([2, 256, 32, 32])
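These shapes can be predicted from the configuration alone. The plain-Python sketch below does that bookkeeping; the rules it encodes are assumptions taken from upstream SAM (ViT patch grid = image size / patch size; one sparse token per point, two per box, plus one padding point when no box is given; mask prompts downscaled 4× by two stride-2 convolutions). `expected_shapes` is a hypothetical helper, not part of segment_anything or this fork.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def expected_shapes(
    batch: int,
    img_size: int,
    patch_size: int,
    embed_dim: int,
    points_per_prompt: int,
    has_box: bool,
    mask_hw: Optional[Tuple[int, int]],
) -> Dict[str, Shape]:
    """Shape bookkeeping that mirrors upstream SAM's ImageEncoderViT and PromptEncoder (assumed)."""
    grid = img_size // patch_size  # ViT patch grid, e.g. 128 // 16 = 8
    # Each point is one sparse token; a box adds two tokens (its two corners).
    n_tokens = points_per_prompt + (2 if has_box else 0)
    # Without a box, upstream SAM appends one padding point to the point prompts.
    if points_per_prompt and not has_box:
        n_tokens += 1
    if mask_hw is None:
        dense_hw = (grid, grid)  # learned "no mask" embedding expanded to the image grid
    else:
        dense_hw = (mask_hw[0] // 4, mask_hw[1] // 4)  # two stride-2 convolutions
    return {
        "image_embeddings": (batch, embed_dim, grid, grid),
        "sparse": (batch, n_tokens, embed_dim),
        "dense": (batch, embed_dim) + dense_hw,
    }


shapes = expected_shapes(batch=2, img_size=128, patch_size=16, embed_dim=256,
                         points_per_prompt=2, has_box=True, mask_hw=(128, 128))
for name, shape in shapes.items():
    print(name, shape)
```

With the settings from this report it prints `(2, 256, 8, 8)`, `(2, 4, 256)` and `(2, 256, 32, 32)`, matching the printed shapes. The dense embeddings are 32 × 32 rather than 8 × 8 because upstream SAM expects the mask prompt at 4× the embedding resolution (32 × 32 here, not 128 × 128). The `self.interpolate(dense_prompt_embeddings, *src.shape[2:])` call in the trace appears to resize them to the `src` grid, and the error concerns dimension 0, so the spatial size is not what fails.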

However, the shape of `src` in the mask decoder is torch.Size([4, 256, 8, 8]).

The error is raised at line 148 of mask_decoder.py: `src = src + self.interpolate(dense_prompt_embeddings, *src.shape[2:])`

RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0
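Where the 4 likely comes from, assuming this fork keeps upstream SAM's `MaskDecoder.predict_masks` logic: upstream repeats the image embeddings once per prompt with `torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)`, because it expects a single image per call (upstream `Sam.forward` loops over images and passes `curr_embedding.unsqueeze(0)`). With 2 images and 2 prompt sets, `src` becomes 2 × 2 = 4 along dim 0, while the dense prompt embeddings keep batch 2. The plain-Python sketch below reproduces only that shape arithmetic and PyTorch's broadcasting rule; `repeat_interleave_shape` and `broadcast_shape` are hypothetical helpers.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def repeat_interleave_shape(shape: Shape, repeats: int, dim: int = 0) -> Shape:
    """Shape of torch.repeat_interleave(x, repeats, dim) for an integer `repeats`."""
    out = list(shape)
    out[dim] *= repeats
    return tuple(out)


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """PyTorch broadcasting: align trailing dims; each pair must be equal or contain a 1."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y and 1 not in (x, y):
            raise ValueError(
                f"The size of tensor a ({x}) must match the size of tensor b ({y}) "
                f"at non-singleton dimension {i}"
            )
        out.append(y if x == 1 else x)
    return tuple(out)


image_embeddings = (2, 256, 8, 8)  # two images passed in one decoder call
num_prompts = 2                     # tokens.shape[0] == sparse_embeddings.shape[0]
dense_resized = (2, 256, 8, 8)      # dense embeddings after interpolation to 8 x 8

src = repeat_interleave_shape(image_embeddings, num_prompts)
print(src)  # (4, 256, 8, 8), the shape printed inside the decoder
try:
    broadcast_shape(src, dense_resized)
except ValueError as err:
    print("mismatch:", err)
```

With one image per call the same arithmetic gives `(1, 256, 8, 8)` repeated twice, i.e. `(2, 256, 8, 8)`, which adds cleanly to batch-2 dense embeddings; in upstream SAM that case means two prompts on the same image.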

Can you please help me solve this issue, or explain how I can check what the model takes as input and produces as output?
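One way to check the decoder inputs before calling it is to compare their shapes against the contract implied by upstream SAM's decoder: one image per call, an `image_pe` of batch 1 with the same C, H, W as the embeddings, sparse and dense prompt embeddings sharing a batch size, and matching channel counts. `check_decoder_inputs` below is a hypothetical helper (not part of segment_anything or this fork) that works on plain shape tuples, so you would pass it `tuple(t.shape)` for each tensor. For inspecting tensors inside the model at run time, PyTorch's `nn.Module.register_forward_hook` is the usual tool.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def check_decoder_inputs(image_embeddings: Shape, image_pe: Shape,
                         sparse: Shape, dense: Shape) -> List[str]:
    """Return a list of problems; an empty list means the shapes look consistent."""
    problems = []
    b_img, c_img = image_embeddings[0], image_embeddings[1]
    if b_img != 1:
        problems.append(
            f"image_embeddings has batch {b_img}; the decoder repeats it once per "
            f"prompt, so pass one image at a time (e.g. image_embeddings[i:i+1])"
        )
    if image_pe[0] != 1 or tuple(image_pe[1:]) != tuple(image_embeddings[1:]):
        problems.append(f"image_pe {image_pe} should be (1, C, H, W) with C, H, W of image_embeddings")
    if sparse[0] != dense[0]:
        problems.append(f"sparse batch {sparse[0]} != dense batch {dense[0]}")
    if sparse[-1] != c_img or dense[1] != c_img:
        problems.append(f"prompt channels {sparse[-1]}/{dense[1]} != image channels {c_img}")
    return problems


# Shapes from this report: two images in one call and a batch-2 zero positional encoding.
for problem in check_decoder_inputs((2, 256, 8, 8), (2, 256, 8, 8), (2, 4, 256), (2, 256, 32, 32)):
    print("-", problem)
```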

I have tried reshaping the tensors, but that produces new errors that don't make sense to me.

I appreciate your help!!

ByungKwanLee commented 3 months ago

Can you write a toy example that reproduces this error? I could not replicate this error message.

saimohan16 commented 3 months ago

Thank you for following up. This is the code I have been using:

```python
import importlib
import inspect

import torch

import segment_anything.modeling.mask_decoder as mask_decoder_module

importlib.reload(mask_decoder_module)
from segment_anything.modeling.mask_decoder import MaskDecoder
from segment_anything.modeling import ImageEncoderViT, PromptEncoder, TwoWayTransformer

# Verify the correct version of mask_decoder is loaded.
# (The original called a helper, print_file_contents(mask_decoder_path), that is not shown;
# inspect.getsource prints the reloaded module's source instead.)
print("Reloaded mask_decoder.py contents:")
print(inspect.getsource(mask_decoder_module))

batch_size = 2
input_image_size = (128, 128)  # Input image size (H, W)
mask_in_chans = 16  # Number of channels for mask embedding processing
patch_size = 16

# Initialize the image encoder
image_encoder = ImageEncoderViT(img_size=input_image_size[0], patch_size=patch_size)

# Create a dummy input image and get the image embeddings
dummy_image = torch.randn(batch_size, 3, *input_image_size)
image_embeddings = image_encoder(dummy_image)

# Calculate image embedding size and dimension
image_embedding_size = image_embeddings.shape[-2:]  # (H, W) of the embeddings
embed_dim = image_embeddings.shape[1]  # Embedding dimension

# Print the image embedding size and dimension
print("Image embedding size:", image_embedding_size)
print("Image embedding dimension:", embed_dim)

# Points: tuple of (coordinates, labels)
points = (
    torch.tensor([[[30.0, 40.0], [50.0, 60.0]], [[20.0, 80.0], [70.0, 90.0]]]),  # Coordinates
    torch.tensor([[1, 0], [0, 1]]),  # Labels: 1 for positive, 0 for negative
)

# Boxes: (batch_size, num_boxes, 4)
boxes = torch.tensor([[[25.0, 35.0, 55.0, 65.0]], [[20.0, 30.0, 50.0, 60.0]]])

# Masks: (batch_size, 1, height, width)
masks = torch.randn(batch_size, 1, *input_image_size)

# Initialize the PromptEncoder
prompt_encoder = PromptEncoder(
    embed_dim=embed_dim,
    image_embedding_size=image_embedding_size,
    input_image_size=input_image_size,
    mask_in_chans=mask_in_chans,
)

# Pass the dummy data through the encoder
sparse_embeddings, dense_embeddings = prompt_encoder(points=points, boxes=boxes, masks=masks)

# Print the shapes of the inputs and outputs
print("Points shape:", points[0].shape, points[1].shape)
print("Boxes shape:", boxes.shape)
print("Masks shape:", masks.shape)
print("Sparse Embeddings shape:", sparse_embeddings.shape)
print("Dense Embeddings shape:", dense_embeddings.shape)

# Initialize the MaskDecoder
transformer = TwoWayTransformer(depth=2, embedding_dim=embed_dim, mlp_dim=2048, num_heads=8)
mask_decoder = MaskDecoder(
    num_multimask_outputs=3,
    transformer=transformer,
    transformer_dim=embed_dim,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)

# Forward pass through the MaskDecoder (the RuntimeError above is raised here)
masks_output, iou_pred = mask_decoder(
    image_embeddings=image_embeddings,
    image_pe=torch.zeros_like(image_embeddings),  # Assuming positional encodings are zeros
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=True,
)

# Print the shapes of the outputs
print("Masks output shape:", masks_output.shape)
print("IoU predictions shape:", iou_pred.shape)
```
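A minimal sketch of a workaround, assuming the fork's decoder keeps upstream SAM's one-image-per-call contract: loop over the images and give the decoder the i-th slice of every input, the way upstream `Sam.forward` does. In the real code each iteration would call `mask_decoder(image_embeddings=image_embeddings[i:i+1], image_pe=prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings[i:i+1], dense_prompt_embeddings=dense_embeddings[i:i+1], multimask_output=True)` and concatenate the results with `torch.cat(..., dim=0)`. The plain-Python stand-in below (`fake_decoder_shapes` and `decode_per_image` are hypothetical) only tracks shapes so it runs without the model; it applies the repeat-per-prompt rule and, as assumed from upstream SAM, upsamples the 8 × 8 grid by 4 and keeps 3 masks when `multimask_output=True`.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def fake_decoder_shapes(image_embeddings: Shape, sparse: Shape, dense: Shape,
                        multimask_output: bool) -> Tuple[Shape, Shape]:
    """Stand-in for the decoder that only tracks shapes (upstream-SAM rules assumed)."""
    num_prompts = sparse[0]
    b, c, h, w = image_embeddings
    src_batch = b * num_prompts  # torch.repeat_interleave(image_embeddings, num_prompts, dim=0)
    if src_batch != dense[0]:
        raise RuntimeError(f"batch mismatch: src {src_batch} vs dense {dense[0]}")
    n_masks = 3 if multimask_output else 1
    return (src_batch, n_masks, 4 * h, 4 * w), (src_batch, n_masks)


def decode_per_image(image_embeddings: Shape, sparse: Shape, dense: Shape) -> Tuple[Shape, Shape]:
    """Call the decoder once per image on i:i+1 slices, then concatenate along dim 0."""
    mask_shapes: List[Shape] = []
    iou_shapes: List[Shape] = []
    for _ in range(image_embeddings[0]):
        one_image = (1,) + image_embeddings[1:]  # shape of image_embeddings[i:i+1]
        masks, iou = fake_decoder_shapes(one_image, (1,) + sparse[1:], (1,) + dense[1:], True)
        mask_shapes.append(masks)
        iou_shapes.append(iou)
    total = sum(s[0] for s in mask_shapes)  # what torch.cat(..., dim=0) would produce
    return (total,) + mask_shapes[0][1:], (total,) + iou_shapes[0][1:]


print(decode_per_image((2, 256, 8, 8), (2, 4, 256), (2, 256, 32, 32)))
# ((2, 3, 32, 32), (2, 3))
```

Using `prompt_encoder.get_dense_pe()` instead of `torch.zeros_like(image_embeddings)` also gives `image_pe` the batch-1 shape that upstream SAM's decoder expects.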