constantinpape / torch-em

Deep-learning based semantic and instance segmentation for 3D Electron Microscopy and other bioimage analysis problems based on pytorch.
MIT License

UNETR with SAM initialization is not working yet #147

Closed by constantinpape 1 year ago

constantinpape commented 1 year ago

I created https://github.com/constantinpape/torch-em/blob/main/experiments/vision-transformer/unetr/initialize_with_sam.py to check UNETR with SAM initialization, but it fails as soon as a tensor is fed into the model.

cc @anwai98

$ python initialize_with_sam.py 
Traceback (most recent call last):
  File "/home/pape/Work/my_projects/torch-em/experiments/vision-transformer/unetr/initialize_with_sam.py", line 8, in <module>
    y = model(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/torch-em/torch_em/model/unetr.py", line 209, in forward
    z12, from_encoder = self.encoder(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/torch-em/torch_em/model/unetr.py", line 54, in forward
    x = blk(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/SegmentAnythingModel/segment_anything/modeling/image_encoder.py", line 174, in forward
    x = self.attn(x)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/Work/my_projects/SegmentAnythingModel/segment_anything/modeling/image_encoder.py", line 227, in forward
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pape/software/conda/miniconda3/envs/sam/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x768 and 64x27)
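
For anyone reading the error message: `F.linear(input, weight)` computes `input @ weight.T`, with `weight` stored as `(out_features, in_features)`. So `4096x768 and 64x27` means the flattened input has 768 features per token, while the `qkv` weight has shape `(27, 64)` and expects only 64. The traceback does not say why the weight ended up with that shape. Below is a minimal pure-Python sketch of the shape check. The helper name `check_linear_shapes`, the input shape `(1, 64, 64, 768)` and the "healthy" weight shape `(2304, 768)` are assumptions for illustration, not taken from torch-em.

```python
from math import prod


def check_linear_shapes(input_shape, weight_shape):
    """Mimic the shape check behind a linear layer (hypothetical helper).

    A linear layer computes input @ weight.T, where weight has shape
    (out_features, in_features). Returns the output shape, or raises
    ValueError reporting the flattened shapes if they are incompatible.
    """
    out_features, in_features = weight_shape
    if input_shape[-1] != in_features:
        # All leading dimensions are flattened into rows for the matmul.
        rows = prod(input_shape[:-1])
        raise ValueError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{input_shape[-1]} and {in_features}x{out_features})"
        )
    return (*input_shape[:-1], out_features)


# Assumed healthy case: 768-dim tokens and a qkv weight of shape (3 * 768, 768).
print(check_linear_shapes((1, 64, 64, 768), (2304, 768)))

# Shapes implied by the traceback: the weight is (27, 64), so the check fails.
try:
    check_linear_shapes((1, 64, 64, 768), (27, 64))
except ValueError as err:
    print(err)
```

With a real model, the shapes to feed into such a check could be read from `model.state_dict()`, for example by comparing the `qkv` weight shape against the embedding dimension of the encoder.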

constantinpape commented 1 year ago

This is working now.