```
$ python initialize_with_sam.py
Traceback (most recent call last):
  File "/home/pape/Work/my_projects/torch-em/experiments/vision-transformer/unetr/initialize_with_sam.py", line 8, in <module>
    y = model(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/torch-em/torch_em/model/unetr.py", line 209, in forward
    z12, from_encoder = self.encoder(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/torch-em/torch_em/model/unetr.py", line 54, in forward
    x = blk(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/SegmentAnythingModel/segment_anything/modeling/image_encoder.py", line 174, in forward
    x = self.attn(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/SegmentAnythingModel/segment_anything/modeling/image_encoder.py", line 227, in forward
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x768 and 64x27)
```
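How to read the last line of the traceback: `F.linear(input, weight)` multiplies the input, flattened to 2D (`mat1`), with `weight.T` (`mat2`), and an `nn.Linear` weight is stored as `(out_features, in_features)`. The stdlib-only sketch below decodes the message; the helper `decode_linear_error` and the `LinearMismatch` tuple are hypothetical illustrations, not part of torch-em or PyTorch. Read this way, the tokens carry 768 features, while the `qkv` layer they reach has `in_features=64` and `out_features=27`. Assuming a ViT-B encoder (embedding dimension 768), one would instead expect `qkv = nn.Linear(768, 3 * 768)`.

```python
import re
from typing import NamedTuple


class LinearMismatch(NamedTuple):
    """Shapes decoded from a PyTorch F.linear shape-mismatch error."""
    rows: int                # flattened input rows (e.g. batch * tokens)
    input_features: int      # last dimension of the input tensor
    layer_in_features: int   # in_features of the nn.Linear layer
    layer_out_features: int  # out_features of the nn.Linear layer


_PATTERN = re.compile(
    r"mat1 and mat2 shapes cannot be multiplied \((\d+)x(\d+) and (\d+)x(\d+)\)"
)


def decode_linear_error(message: str) -> LinearMismatch:
    """Interpret the two shapes in the error message.

    mat1 is the input flattened to 2D; mat2 is weight.T, and since
    weight has shape (out_features, in_features), mat2 is
    (in_features x out_features).
    """
    match = _PATTERN.search(message)
    if match is None:
        raise ValueError("not a linear shape-mismatch error")
    rows, feats, in_f, out_f = (int(g) for g in match.groups())
    return LinearMismatch(rows, feats, in_f, out_f)


msg = "RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x768 and 64x27)"
info = decode_linear_error(msg)
print(info)

# Assumption: a ViT-B image encoder with embed_dim=768, where the attention
# block builds qkv as nn.Linear(embed_dim, 3 * embed_dim).
embed_dim = 768
print(info.input_features == embed_dim)     # True: the tokens have 768 features
print(info.layer_in_features == embed_dim)  # False: the qkv layer expects 64
```

This hints that the attention blocks were built with different dimensions than the tensor passed through them. One possible cause, unverified here, is constructor arguments reaching the SAM encoder in a different order or with different values than intended; this should be checked against `torch_em/model/unetr.py`.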
I created https://github.com/constantinpape/torch-em/blob/main/experiments/vision-transformer/unetr/initialize_with_sam.py to check the initialization with SAM, but it fails as soon as a tensor is passed through the model.
cc @anwai98