TRI-ML / prismatic-vlms

A flexible and efficient codebase for training visually-conditioned language models (VLMs)
MIT License

Do you have plans to add SAM as a visual encoder? #10

Open StarCycle opened 6 months ago

StarCycle commented 6 months ago

SAM can be combined with SigLIP or CLIP.

For example, Vary uses SAM+CLIP, and DeepSeek-VL uses SigLIP+SAM.

Would you like to try them with this codebase?
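For concreteness, here is a rough sketch of what a fused featurizer could look like with timm backbones. The SigLIP model name and the pooling-based token alignment are just assumptions for illustration, not the actual Vary / DeepSeek-VL recipe or this codebase's API:

from torch import nn
import timm
import torch
import torch.nn.functional as F

class FusedSigLIPSAM(nn.Module):
    """Concatenate SigLIP and SAM patch features channel-wise into one token sequence."""

    def __init__(self,
                 siglip_name='vit_so400m_patch14_siglip_224',  # assumed timm id
                 sam_name='samvit_base_patch16.sa1b'):
        super().__init__()
        self.siglip = timm.create_model(siglip_name, pretrained=True, num_classes=0)
        self.sam = timm.create_model(sam_name, pretrained=True, num_classes=0)
        # Each backbone keeps its own preprocessing (different resolutions / norms).
        self.siglip_tf = timm.data.create_transform(
            **timm.data.resolve_model_data_config(self.siglip), is_training=False)
        self.sam_tf = timm.data.create_transform(
            **timm.data.resolve_model_data_config(self.sam), is_training=False)

    def forward(self, siglip_px, sam_px):
        # SigLIP ViTs have no CLS token, so this is a (B, N, C_siglip) patch-token grid.
        siglip_tokens = self.siglip.forward_features(siglip_px)
        b, n, _ = siglip_tokens.shape
        grid = int(round(n ** 0.5))
        # SAM's forward_features returns a (B, C_sam, H, W) map from its neck;
        # pool it down to SigLIP's patch grid so the token counts match (an assumption).
        sam_map = F.adaptive_avg_pool2d(self.sam.forward_features(sam_px), grid)
        sam_tokens = sam_map.flatten(2).transpose(1, 2)  # (B, N, C_sam)
        # Channel-wise concat -> (B, N, C_siglip + C_sam), ready for the LLM projector.
        return torch.cat([siglip_tokens, sam_tokens], dim=-1)

With these defaults this should give a 16x16 token grid whose channel dimension is the sum of the two backbones' feature dims; other alignment choices (e.g. interpolating SigLIP features up to SAM's grid) would work too.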

siddk commented 6 months ago

We absolutely can; just to confirm, is this the model you'd want us to try adding: https://huggingface.co/timm/samvit_base_patch16.sa1b?

CC @ashwin-balakrishna96 to add to our internal run list!

StarCycle commented 6 months ago

Hello @siddk @ashwin-balakrishna96

Yes! Please try the SAM-base! Here is some experience shared by my colleague:

Best, StarCycle

StarCycle commented 6 months ago

I guess the training pipeline is the same as for DINOv2 (a rough illustration follows below).

Please let me know if you find anything interesting!
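As a quick illustration of that point (the DINOv2 checkpoint id below is just an example timm name, not necessarily what this codebase uses):

import timm

# Swapping the backbone is mostly a matter of changing the timm model name;
# the rest of the training pipeline can stay the same.
dinov2 = timm.create_model(
    'vit_large_patch14_dinov2.lvd142m',  # example DINOv2 ViT-L id in timm
    pretrained=True,
    num_classes=0,
)
sam = timm.create_model('samvit_base_patch16.sa1b', pretrained=True, num_classes=0)
# Both expose forward_features(); only the output layout differs:
# DINOv2 returns a token sequence, while the SAM ViT returns a 2D feature map
# from its neck, so the projector may need a small reshape in the SAM case.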

Best, StarCycle

StarCycle commented 6 months ago

You can just start with this:

from torch import nn
from urllib.request import urlopen
from PIL import Image
import timm

class DownSampledSAMVit(nn.Module):
    """SAM ViT backbone from timm, followed by stride-2 convs that shrink the
    spatial feature map (and hence the visual token count) while growing channels."""

    def __init__(self, name, downsample_channels=(512, 1024)):
        super().__init__()
        self.SAMViT = timm.create_model(
            name,
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
        )
        # Preprocessing (resize / normalization) matching the pretrained checkpoint.
        data_config = timm.data.resolve_model_data_config(self.SAMViT)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)

        # Infer the SAM neck's output channel count from its final layer's weight.
        in_channels = self.SAMViT.neck[-1].weight.shape[0]
        downsamples = []
        for out_channels in downsample_channels:
            # Each stride-2 conv halves the spatial resolution of the feature map.
            downsamples.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False,
                )
            )
            in_channels = out_channels
        self.downsamples = nn.Sequential(*downsamples)

    def forward(self, rgb):
        out = self.SAMViT.forward_features(rgb)  # (B, C, H/16, W/16) map from the SAM neck
        out = self.downsamples(out)              # spatial size halved by each conv
        return out

# Sample image from the timm documentation, used here as a quick smoke test.
img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = DownSampledSAMVit('samvit_base_patch16.sa1b').cuda().eval()
transforms = model.transforms  # preprocessing resolved from the timm data config
output = model(transforms(img).unsqueeze(0).cuda())
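For a quick sanity check before wiring this into a projector: assuming the default 1024x1024 SAM input (a 64x64 neck feature map), each of the two stride-2 convs halves the spatial size, so the output should be roughly:

print(output.shape)  # expected: torch.Size([1, 1024, 16, 16]) under the default settings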

ashwin-balakrishna96 commented 5 months ago

@StarCycle thanks a bunch for the suggestion. We can try integrating the SAM baseline in a week or so, but if you have cycles and would be interested in opening up a PR to integrate it in the meanwhile (especially because it seems like you've already been thinking about how the code should look), we would also be very happy to review it and integrate it into Prismatic :)