It looks like you're missing the input transform. B-cos models expect a 6-channel input (RGB plus inverse RGB).
The pretrained models have the transform attached to them, so you can use them like so:
model = torch.hub.load("B-cos/B-cos-v2", "simple_vit_b_patch16_224", pretrained=True).to(device)
inputs = model.transform(inputs)
expl_out = model.explain(inputs)
LMK if it works!
Best, Navdeep
EDIT: the above transform expects a PIL image. For an RGB tensor input, you can append the inverse channels as follows:
from bcos.data.transforms import AddInverse
model = torch.hub.load("B-cos/B-cos-v2", "simple_vit_b_patch16_224", pretrained=True).to(device)
in_tensor = AddInverse()(in_tensor)
expl_out = model.explain(in_tensor)
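For anyone wondering what "RGB + inverse RGB" means in terms of shapes, here is a minimal sketch. It assumes the tensor holds values scaled to [0, 1] (not mean/std normalized) and that the inverse channels are simply `1 - x` concatenated along the channel axis. `add_inverse_channels` is a hypothetical helper written for illustration, not the library's own code; in practice use `AddInverse` as shown above.

```python
import torch


def add_inverse_channels(x: torch.Tensor, dim: int = -3) -> torch.Tensor:
    """Hypothetical helper: append inverse channels (1 - x) along the channel axis.

    Assumes x is scaled to [0, 1]; dim=-3 is the channel axis for both
    (C, H, W) and (N, C, H, W) tensors.
    """
    return torch.cat([x, 1 - x], dim=dim)


rgb = torch.rand(1, 3, 224, 224)   # same shape as the input in the question
six = add_inverse_channels(rgb)
print(six.shape)                   # torch.Size([1, 6, 224, 224])
```

Under these assumptions, a (1, 3, 224, 224) input becomes (1, 6, 224, 224), which is the channel count the model expects.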
Thank you so much for the prompt response. Both work!
Hi! Thank you for the contribution! I'm trying to use the code for ViT but ran into this bug:
Here is my code:
The input has shape (1, 3, 224, 224).
And I got this error:
Is this the correct way to use ViT or is there something I missed?
Thank you in advance!!