BAAI-DCAI / M3D

M3D: Advancing 3D Medical Image Analysis with Multi-Modal Large Language Models
MIT License

Failed to run M3DClip following the Hugging Face example #25

Open Thedatababbler opened 5 days ago

Thedatababbler commented 5 days ago

Hi, thank you for your awesome work. However, when I try to run the M3DClip model with the example code from Hugging Face, I get errors related to the einops library. I noticed that you use the MONAI ViT layers to implement M3DClip; that might be the problem.

Code I ran:

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import SimpleITK as sikt

device = torch.device("cuda") # or cpu

tokenizer = AutoTokenizer.from_pretrained(
    "GoodBaiBai88/M3D-CLIP",
    model_max_length=512,
    padding_side="right",
    use_fast=False
)
model = AutoModel.from_pretrained(
    "GoodBaiBai88/M3D-CLIP",
    trust_remote_code=True
)
model = model.to(device=device)

# Prepare your 3D medical image:
# 1. The image must have shape 1*32*256*256 (e.g. via resizing or similar methods).
# 2. The image must be normalized to [0, 1], e.g. with min-max normalization.
# 3. The image must be saved in .npy format.
# 4. Although we did not train on 2D images, in theory a 2D image can be interpolated to shape 1*32*256*256 and used as input.

image_path = "./Data/data/examples/example_00.npy"
input_txt = "Liver"

text_tensor = tokenizer(input_txt, max_length=512, truncation=True, padding="max_length", return_tensors="pt")
input_id = text_tensor["input_ids"].to(device=device)
attention_mask = text_tensor["attention_mask"].to(device=device)
image = np.load(image_path)
# image = torch.from_numpy(image).to(device=device)
import pdb;pdb.set_trace()
with torch.inference_mode():
    image_features = model.encode_image(image)[:, 0]
    text_features = model.encode_text(input_id, attention_mask)[:, 0]
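The preprocessing steps listed in the model-card comments in the script above can be sketched in plain NumPy. This is a minimal sketch under stated assumptions, not the M3D authors' pipeline: the names `TARGET_DHW`, `min_max_normalize`, `resize_nearest` and `preprocess` are hypothetical, intensities are scaled with a single global min-max over the whole volume, and nearest-neighbour index sampling stands in for whatever resizing the authors actually used.

```python
import numpy as np

# Hypothetical constant: depth x height x width taken from the model card comments.
TARGET_DHW = (32, 256, 256)


def min_max_normalize(volume: np.ndarray) -> np.ndarray:
    """Scale intensities to [0, 1] using the global min and max of the volume."""
    volume = volume.astype(np.float32)
    vmin, vmax = volume.min(), volume.max()
    if vmax > vmin:
        return (volume - vmin) / (vmax - vmin)
    # A constant volume has no range to scale; return zeros instead of dividing by 0.
    return np.zeros_like(volume)


def resize_nearest(volume: np.ndarray, shape=TARGET_DHW) -> np.ndarray:
    """Nearest-neighbour resize of a (D, H, W) volume to `shape` by index sampling."""
    indices = [
        # For each output position, pick the floor-mapped input index along this axis.
        np.minimum((np.arange(n_out) * n_in) // n_out, n_in - 1)
        for n_in, n_out in zip(volume.shape, shape)
    ]
    return volume[np.ix_(*indices)]


def preprocess(volume: np.ndarray) -> np.ndarray:
    """Return a float32 array of shape (1, 32, 256, 256) with values in [0, 1]."""
    resized = resize_nearest(min_max_normalize(volume))
    return resized[np.newaxis]  # leading channel axis: (1, D, H, W)
```

Usage: `np.save("example_00.npy", preprocess(volume))`, where `volume` is a (D, H, W) array, for example one obtained with `sitk.GetArrayFromImage(sitk.ReadImage(path))`. Nearest-neighbour sampling keeps the sketch dependency-free; for CT or MRI intensities a trilinear resize (e.g. `torch.nn.functional.interpolate(..., mode="trilinear")` on a 5-D tensor) is usually the smoother choice.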

Error message:

Traceback (most recent call last):
  File "/local/scratch/zqin30/projects/M3D/m3dclip.py", line 32, in <module>
    image = np.load(image_path).to(devie=device)
            ^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'to'
(test) (base) zqin30@h100server:/local/scratch/zqin30/projects/M3D$ python m3dclip.py 
/local/scratch/zqin30/envs/test/lib/python3.11/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
/local/scratch/zqin30/envs/test/lib/python3.11/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.networks.blocks.patchembedding PatchEmbeddingBlock.__init__:pos_embed: Argument `pos_embed` has been deprecated since version 1.2. It will be removed in version 1.4. please use `proj_type` instead.
  warn_deprecated(argname, msg, warning_category)
/local/scratch/zqin30/envs/test/lib/python3.11/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
> /local/scratch/zqin30/projects/M3D/m3dclip.py(35)<module>()
-> with torch.inference_mode():
(Pdb) p model.encode_image(image
baifanxxx commented 1 day ago

Hi,

I think you should convert the NumPy array to a torch tensor, e.g. with torch.from_numpy(image).
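Following that suggestion, the loaded `.npy` array can be tidied up before it is handed to `torch.from_numpy`. `torch.from_numpy` shares memory with the array and keeps its dtype, so a float64 array becomes a double tensor, which is unlikely to match model weights loaded in float32 (an assumption about this checkpoint); it also rejects arrays with negative strides. The helper `to_model_array` is hypothetical, a minimal sketch rather than part of M3D:

```python
import numpy as np


def to_model_array(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: prepare a loaded volume for torch.from_numpy.

    - casts to float32, because torch.from_numpy keeps the NumPy dtype
      (a float64 array would become a double tensor)
    - returns a C-contiguous copy when needed, which also removes negative
      strides (e.g. from image[::-1]) that torch.from_numpy does not accept
    """
    return np.ascontiguousarray(image, dtype=np.float32)
```

Usage, assuming `torch`, `image_path` and `device` as in the issue's script: `image = torch.from_numpy(to_model_array(np.load(image_path))).to(device=device)`.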

AndyQin92 commented 1 day ago

Now I have this message:

Traceback (most recent call last):
  File "/root/projects/test/M3D/medclip.py", line 36, in <module>
    image_features = model.encode_image(image)[:, 0]
                     ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/GoodBaiBai88/M3D-CLIP/ae091d89a0ef38b533ecc4ed21426f7658853963/modeling_m3d_clip.py", line 185, in encode_image
    image_feats, _ = self.vision_encoder(image)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/modules/transformers_modules/GoodBaiBai88/M3D-CLIP/ae091d89a0ef38b533ecc4ed21426f7658853963/modeling_m3d_clip.py", line 140, in forward
    x = self.patch_embedding(x)
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/monai/networks/blocks/patchembedding.py", line 141, in forward
    x = self.patch_embeddings(x)
        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda/envs/m3d/lib/python3.11/site-packages/einops/layers/torch.py", line 14, in forward
    recipe = self._multirecipe[input.ndim]
KeyError: 4
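The traceback ends inside the einops `Rearrange` layer that MONAI's `PatchEmbeddingBlock` uses to cut the volume into patches. That layer looks up its recipe by `input.ndim`, and for a 3-D patch embedding the pattern covers five axes (batch, channel, depth, height, width), so `KeyError: 4` most likely means a 4-D tensor arrived. The model card's 1*32*256*256 shape has a channel axis but no batch axis. A plausible fix, not confirmed in this thread, is to add a leading batch axis before calling `model.encode_image`; the helper `add_batch_axis` and the constant `EXPECTED_CDHW` are hypothetical:

```python
import numpy as np

# Hypothetical constant: (channels, depth, height, width) from the model card.
EXPECTED_CDHW = (1, 32, 256, 256)


def add_batch_axis(image: np.ndarray) -> np.ndarray:
    """Turn a (1, 32, 256, 256) volume into a (1, 1, 32, 256, 256) batch.

    A 3-D patch embedding consumes 5-D input (B, C, D, H, W); passing the
    4-D array directly is a likely cause of the einops `KeyError: 4`.
    """
    if image.ndim == 5:
        return image  # already batched, leave unchanged
    if image.shape != EXPECTED_CDHW:
        raise ValueError(f"expected shape {EXPECTED_CDHW}, got {image.shape}")
    return image[np.newaxis]  # prepend the batch axis
```

With `torch` as in the issue's script, `image = torch.from_numpy(add_batch_axis(np.load(image_path))).to(device=device)`; `torch.from_numpy(image).unsqueeze(0)` adds the same axis on the tensor side.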