BAAI-DCAI / M3D

M3D: Advancing 3D Medical Image Analysis with Multi-Modal Large Language Models

Issue while using demo file. (Tensor on device cuda:0 is not on the expected device meta!) #19

Closed: hari3100 closed this 3 weeks ago

hari3100 commented 2 months ago

I ran demo.py with a few modifications, since the original was throwing some errors. Here is my modified demo.py:

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import numpy as np
import torch
from dataclasses import dataclass, field
import simple_slice_viewer as ssv
import SimpleITK as sikt
# from LaMed.src.model.language_model import *
import matplotlib.pyplot as plt

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)

@dataclass
class AllArguments:
    model_name_or_path: str = field(default="GoodBaiBai88/M3D-LaMed-Llama-2-7B")

    proj_out_num: int = field(default=256, metadata={"help": "Number of output tokens in Projector."})
    image_path: str = field(default="A:\\AL DL\\CA project\\M3D\\Data\\data\\examples\\example_06.npy")

def main():
    seed_everything(42)
    dtype = torch.float16  # or torch.bfloat16 / torch.float32

    parser = transformers.HfArgumentParser(AllArguments)
    args = parser.parse_args_into_dataclasses()[0]

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=512,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True
    )

    # Load the model with accelerate managing device mapping
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
        device_map="auto",  # Let accelerate handle the device assignment
        trust_remote_code=True
    )
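    # NOTE: if the model does not fit in GPU memory, device_map="auto" can
    # offload some submodules to CPU or disk; offloaded parameters are kept
    # on the "meta" device until accelerate's hooks load them.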

    # Rest of the code for inference
    question = "What is liver in this image? Please output the segmentation mask."

    image_tokens = "<im_patch>" * args.proj_out_num
    input_txt = image_tokens + question
    input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to('cuda:0')  # Explicitly move input to cuda:0

    image_np = np.load(args.image_path)
    image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device='cuda:0')  # Explicitly move image to cuda:0

    # Use model.generate with inputs moved to the correct device
    generation, seg_logit = model.generate(image_pt, input_id, seg_enable=True, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)

    generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
    seg_mask = (torch.sigmoid(seg_logit) > 0.5) * 1.0

    print('question', question)
    print('generated_texts', generated_texts[0])

    # Display the image and segmentation mask
    image = sikt.GetImageFromArray(image_np)
    ssv.display(image)

    seg = sikt.GetImageFromArray(seg_mask.cpu().numpy()[0])  # Move mask back to CPU for display
    ssv.display(seg)

if __name__ == "__main__":
    main()

But I'm still getting this error no matter what I've tried:

(m3d) A:\AL DL\CA project\M3D\LaMed\src\demo>python mydemo.py
C:\Users\asus\.conda\envs\m3d\lib\site-packages\huggingface_hub\file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
C:\Users\asus\.conda\envs\m3d\lib\site-packages\monai\utils\deprecate_utils.py:221: FutureWarning: monai.networks.blocks.patchembedding PatchEmbeddingBlock.__init__:pos_embed: Argument `pos_embed` has been deprecated since version 1.2. It will be removed in version 1.4. please use `proj_type` instead.
  warn_deprecated(argname, msg, warning_category)
build_sam_vit_3d...
C:\Users\asus\.conda\envs\m3d\lib\site-packages\monai\utils\deprecate_utils.py:221: FutureWarning: monai.networks.nets.vit ViT.__init__:pos_embed: Argument `pos_embed` has been deprecated since version 1.2. It will be removed in version 1.4. please use `proj_type` instead.
  warn_deprecated(argname, msg, warning_category)
Loading checkpoint shards: 100%|██████████| 6/6 [01:38<00:00, 16.50s/it]
C:\Users\asus\.conda\envs\m3d\lib\site-packages\huggingface_hub\file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Some parameters are on the meta device because they were offloaded to the disk and cpu.
C:\Users\asus\.conda\envs\m3d\lib\site-packages\transformers\models\llama\modeling_llama.py:670: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:263.)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Traceback (most recent call last):
  File "A:\AL DL\CA project\M3D\LaMed\src\demo\mydemo.py", line 80, in <module>
    main()
  File "A:\AL DL\CA project\M3D\LaMed\src\demo\mydemo.py", line 63, in main
    generation, seg_logit = model.generate(image_pt, input_id, seg_enable=True, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\asus\.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 2085, in generate
    logits = self.get_model().seg_module(images, seg_prompts)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\asus\.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 1385, in forward
    logits = self.forward_decoder(image_embedding, img_shape, text_emb=text_emb, text=text, boxes=boxes, points=points)
  File "C:\Users\asus\.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 1391, in forward_decoder
    sparse_embeddings, dense_embeddings = self.prompt_encoder(
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\accelerate\hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "C:\Users\asus\.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 839, in forward
    sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\_prims_common\wrappers.py", line 250, in _fn
    result = fn(*args, **kwargs)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\_prims_common\wrappers.py", line 137, in _fn
    result = fn(**bound.arguments)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\_refs\__init__.py", line 2702, in cat
    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
  File "C:\Users\asus\.conda\envs\m3d\lib\site-packages\torch\_prims_common\__init__.py", line 696, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!

I have Python 3.9.9 and CUDA 11.8, and the rest of the requirements match requirements.txt. My GPU is a GTX 1650.
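The log line "Some parameters are on the meta device because they were offloaded to the disk and cpu" suggests the model does not fit in the GTX 1650's VRAM, so accelerate spilled part of it to CPU/disk. A minimal diagnostic sketch (assuming the model was loaded with device_map="auto" as in the script above) to see which submodules were offloaded:

    # After from_pretrained(..., device_map="auto"), transformers exposes the
    # placement accelerate chose for each submodule. Any entry mapped to
    # "cpu" or "disk" is offloaded and held as a meta tensor inside the
    # model, which is what the cat() in the seg_module trips over when it
    # receives cuda:0 inputs.
    for name, device in model.hf_device_map.items():
        print(f"{name:60s} -> {device}")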

Any help would be greatly appreciated.

Thank you @baifanxxx, @Yuxin-Du-Lab

baifanxxx commented 2 months ago

Hi,

I'm not sure, but can a GTX 1650 really load the LLM?
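If VRAM is indeed the bottleneck, one possible workaround is to quantize the LLM so the whole model stays on cuda:0 and nothing is offloaded to the meta device. This is just a sketch, not tested with M3D's custom vision/segmentation modules, and it assumes bitsandbytes is installed and works on your platform; even 4-bit weights may be tight on a 4 GB card.

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # 4-bit NF4 quantization shrinks the ~7B weights toward a 4 GB budget;
    # device_map={"": 0} pins every submodule to the single GPU, so no
    # layer is offloaded to CPU/disk (the source of the meta tensors).
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "GoodBaiBai88/M3D-LaMed-Llama-2-7B",
        quantization_config=quant_config,
        device_map={"": 0},
        trust_remote_code=True,
    )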