StarCycle opened 6 months ago
We absolutely can; just to confirm, is this the model you'd want us to try adding: https://huggingface.co/timm/samvit_base_patch16.sa1b?
CC @ashwin-balakrishna96 to add to our internal run list!
Hello @siddk @ashwin-balakrishna96
Yes! Please try the SAM-base! Here are some notes from my colleague's experience:
If you want to concatenate the SAM output with the SigLIP output, you may need to add two strided convolution layers after SAM-base to change the output shape from [64, 64, 256] to [256, 1024] (i.e., 256 tokens of dimension 1024). You can check this figure or the Vary paper.
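A minimal sketch of that concatenation, assuming the SAM branch has already been downsampled to 256 tokens of dimension 1024 (as the module further below does) and that the SigLIP branch yields a matching number of tokens; all shapes here are illustrative assumptions, not fixed by either paper:

```python
import torch

# Illustrative shapes only: 256 tokens from each branch.
sam_tokens = torch.randn(1, 256, 1024)     # downsampled SAM-base features
siglip_tokens = torch.randn(1, 256, 1024)  # SigLIP patch features (assumed dim)
fused = torch.cat([sam_tokens, siglip_tokens], dim=-1)  # [1, 256, 2048]
```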
It's possible to use SAM-base as the only visual encoder, but then you need a pretraining stage that aligns SAM-base to the LLM embedding space, using a small language model (e.g., OPT-125M). You may need multiple epochs in this phase.
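A minimal sketch of such an alignment projector, assuming 1024-dim vision tokens and OPT-125M's hidden size of 768; the module name and the two-layer MLP shape are my assumptions, not a prescribed design:

```python
import torch
from torch import nn

class VisionProjector(nn.Module):
    """Maps vision tokens into the LM embedding space (hypothetical sketch)."""
    def __init__(self, vision_dim=1024, lm_dim=768):  # 768 = OPT-125M hidden size
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, lm_dim),
            nn.GELU(),
            nn.Linear(lm_dim, lm_dim),
        )

    def forward(self, vision_tokens):   # [B, N, vision_dim]
        return self.proj(vision_tokens)  # [B, N, lm_dim]

# During the alignment phase, only this projector is trained; the vision
# encoder and the small LM stay frozen.
aligned = VisionProjector()(torch.randn(1, 256, 1024))  # [1, 256, 768]
```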
Best, StarCycle
I guess the training pipeline is the same for DINOv2.
Please let me know if you find anything interesting!
Best, StarCycle
You can just start with this:
```python
from urllib.request import urlopen

import timm
import torch
from PIL import Image
from torch import nn


class DownSampledSAMVit(nn.Module):
    """SAM ViT backbone plus strided convolutions that downsample the
    64x64 feature map while increasing the channel count."""

    def __init__(self, name, downsample_channels=(512, 1024)):
        super().__init__()
        self.SAMViT = timm.create_model(
            name,
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
        )
        data_config = timm.data.resolve_model_data_config(self.SAMViT)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)

        # The neck ends in a LayerNorm2d, so its weight length gives the
        # channel count of the SAM feature map (256 for SAM-base).
        in_channels = self.SAMViT.neck[-1].weight.shape[0]
        downsamples = []
        for out_channels in downsample_channels:
            # Each stride-2 conv halves the spatial resolution:
            # 64x64 -> 32x32 -> 16x16.
            downsamples.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False,
                )
            )
            in_channels = out_channels
        self.downsamples = nn.Sequential(*downsamples)

    def forward(self, rgb):
        out = self.SAMViT.forward_features(rgb)  # [B, 256, 64, 64] for SAM-base
        out = self.downsamples(out)              # [B, 1024, 16, 16]
        return out


img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
model = DownSampledSAMVit('samvit_base_patch16.sa1b').cuda().eval()
transforms = model.transforms
with torch.no_grad():
    output = model(transforms(img).unsqueeze(0).cuda())
```
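For reference, with the default 1024x1024 input resolution of `samvit_base_patch16.sa1b`, `output` should come out as `[1, 1024, 16, 16]`; flattening the spatial grid, e.g. `output.flatten(2).transpose(1, 2)`, gives the `[1, 256, 1024]` token layout mentioned above.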
@StarCycle thanks a bunch for the suggestion. We can try integrating the SAM baseline in a week or so, but if you have cycles and would be interested in opening up a PR to integrate it in the meanwhile (especially because it seems like you've already been thinking about how the code should look), we would also be very happy to review it and integrate it into Prismatic :)
SAM can be used together with SigLIP/CLIP.
For example, Vary uses SAM+CLIP, and DeepSeek-VL uses SigLIP+SAM.
Would you like to try them with this codebase?