B-cos / B-cos-v2

Official PyTorch implementation of improved B-cos models
Apache License 2.0

matrix size doesn't match in vit #9

fallcat closed this issue 2 months ago

fallcat commented 2 months ago

Hi! Thank you for the contribution! I'm trying to use the code for ViT but ran into this bug:

Here is my code:

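import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("B-cos/B-cos-v2", "simple_vit_b_patch16_224", pretrained=True).to(device)
expl_out = model.explain(inputs)  # inputs: RGB tensor of shape (1, 3, 224, 224)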

`inputs` has shape (1, 3, 224, 224).

And I got this bug:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_16052/590328325.py in <module>
----> 1 expl_out = model.explain(inputs)

~/.cache/torch/hub/B-cos_B-cos-v2_main/bcos/common.py in explain(self, in_tensor, idx, **grad2img_kwargs)
    163         with torch.enable_grad(), self.explanation_mode():
    164             # fwd + prediction
--> 165             out = self(in_tensor)  # noqa
    166             pred_out = out.max(1)
    167             result["prediction"] = pred_out.indices.item()

/opt/conda/envs/rapids/lib/python3.10/site-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/envs/rapids/lib/python3.10/site-packages/torch/nn/modules/container.py in forward(self, input)
    215     def forward(self, input):
    216         for module in self:
--> 217             input = module(input)
    218         return input
    219 

/opt/conda/envs/rapids/lib/python3.10/site-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

~/.cache/torch/hub/B-cos_B-cos-v2_main/bcos/models/vit.py in forward(self, img)
    318 
    319     def forward(self, img):
--> 320         x = self.to_patch_embedding(img)
    321         pe = self.positional_embedding(x)
    322         x = rearrange(x, "b ... d -> b (...) d") + pe

/opt/conda/envs/rapids/lib/python3.10/site-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/envs/rapids/lib/python3.10/site-packages/torch/nn/modules/container.py in forward(self, input)
    215     def forward(self, input):
    216         for module in self:
--> 217             input = module(input)
    218         return input
    219 

/opt/conda/envs/rapids/lib/python3.10/site-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

~/.cache/torch/hub/B-cos_B-cos-v2_main/bcos/modules/bcoslinear.py in forward(self, in_tensor)
     95         """
     96         # Simple linear layer
---> 97         out = self.linear(in_tensor)
     98 
     99         # MaxOut computation

/opt/conda/envs/rapids/lib/python3.10/site-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

~/.cache/torch/hub/B-cos_B-cos-v2_main/bcos/modules/bcoslinear.py in forward(self, input)
     25     def forward(self, input: Tensor) -> Tensor:
     26         w = self.weight / LA.vector_norm(self.weight, dim=1, keepdim=True)
---> 27         return F.linear(input, w, self.bias)
     28 
     29 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (196x768 and 1536x768)

Is this the correct way to use ViT or is there something I missed?

Thank you in advance!!

nps1ngh commented 2 months ago

It seems you're missing the input transform: B-cos models expect a 6-channel input (RGB plus inverted RGB).
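The numbers in the error follow directly from this. Below is a small pure-Python sketch of the arithmetic, assuming the 16x16 patches and 224x224 resolution implied by the model name `simple_vit_b_patch16_224`; `patch_features` is a hypothetical helper, not part of the repo. A 224x224 image yields 196 patches (the 196 rows of mat1). Each flattened 3-channel patch has 768 features, while the patch-embedding layer expects 6-channel patches with 1536 features. Because `F.linear` multiplies the input by the transposed weight, that layer's weight shows up as the 1536x768 mat2 in the message.

```python
# Hypothetical helper: number of features in one flattened square ViT patch.
def patch_features(patch_size: int, channels: int) -> int:
    return patch_size * patch_size * channels

image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2  # 14 x 14 patches = rows of mat1

rgb_features = patch_features(patch_size, channels=3)   # plain RGB input (what was passed)
bcos_features = patch_features(patch_size, channels=6)  # RGB + inverse RGB (what the layer expects)

print(num_patches, rgb_features, bcos_features)  # 196 768 1536
```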

The pretrained models have the transform attached, so you can use them like this:

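import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("B-cos/B-cos-v2", "simple_vit_b_patch16_224", pretrained=True).to(device)
inputs = model.transform(inputs).to(device)  # inputs: a PIL image -> 6-channel tensor
expl_out = model.explain(inputs)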

LMK if it works!

Best, Navdeep

EDIT: The above transform expects a PIL image. For an RGB tensor input, you can attach the inverse channels as follows:

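import torch
from bcos.data.transforms import AddInverse

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("B-cos/B-cos-v2", "simple_vit_b_patch16_224", pretrained=True).to(device)
in_tensor = AddInverse()(inputs).to(device)  # RGB (1, 3, H, W), values assumed in [0, 1] -> (1, 6, H, W)
expl_out = model.explain(in_tensor)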
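
For reference, here is a minimal stand-in for the inverse-channel step, written with plain `torch` so it runs without the repo installed. It assumes, based on the "RGB + inverted RGB" description above, that `AddInverse` concatenates `x` and `1 - x` along the channel dimension of an input scaled to [0, 1]; `add_inverse` is a hypothetical name. The patch split only illustrates the resulting shapes; the repo's own ordering of features within a patch may differ.

```python
import torch

def add_inverse(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Hypothetical stand-in: append inverted channels 1 - x to x (assumes x in [0, 1])."""
    return torch.cat([x, 1 - x], dim=dim)

# Hypothetical input: a batch of one RGB image with values in [0, 1].
rgb = torch.rand(1, 3, 224, 224)
six = add_inverse(rgb)
print(six.shape)  # torch.Size([1, 6, 224, 224])

# Cut into non-overlapping 16x16 patches: (1, 6, 14, 14, 16, 16).
patches = six.unfold(2, 16, 16).unfold(3, 16, 16)
# Move channels last and flatten each patch: (1, 196, 16 * 16 * 6) = (1, 196, 1536).
patches = patches.permute(0, 2, 3, 4, 5, 1).reshape(1, 196, -1)
print(patches.shape)  # torch.Size([1, 196, 1536])
```

With 3-channel input the same split would give 768 features per patch, which is exactly the mismatch reported in the traceback.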

fallcat commented 2 months ago

Thank you so much for the prompt response. Both work!