I ran the demo.py with few modifications , as it was showing some errors,
the demo.py:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import numpy as np
import torch
from dataclasses import dataclass, field
import simple_slice_viewer as ssv
import SimpleITK as sikt
# from LaMed.src.model.language_model import *
import matplotlib.pyplot as plt
def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed_all(seed)
@dataclass
class AllArguments:
model_name_or_path: str = field(default="GoodBaiBai88/M3D-LaMed-Llama-2-7B")
proj_out_num: int = field(default=256, metadata={"help": "Number of output tokens in Projector."})
image_path: str = field(default="A:\\AL DL\\CA project\\M3D\\Data\\data\\examples\\example_06.npy")
def main():
seed_everything(42)
dtype = torch.float16 # or bfloat16, float16, float32
parser = transformers.HfArgumentParser(AllArguments)
args = parser.parse_args_into_dataclasses()[0]
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
model_max_length=512,
padding_side="right",
use_fast=False,
trust_remote_code=True
)
# Load the model with accelerate managing device mapping
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=dtype,
device_map="auto", # Let accelerate handle the device assignment
trust_remote_code=True
)
# Rest of the code for inference
question = "What is liver in this image? Please output the segmentation mask."
image_tokens = "<im_patch>" * args.proj_out_num
input_txt = image_tokens + question
input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to('cuda:0') # Explicitly move input to cuda:0
image_np = np.load(args.image_path)
image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device='cuda:0') # Explicitly move image to cuda:0
# Use model.generate with inputs moved to the correct device
generation, seg_logit = model.generate(image_pt, input_id, seg_enable=True, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
seg_mask = (torch.sigmoid(seg_logit) > 0.5) * 1.0
print('question', question)
print('generated_texts', generated_texts[0])
# Display the image and segmentation mask
image = sikt.GetImageFromArray(image_np)
ssv.display(image)
seg = sikt.GetImageFromArray(seg_mask.cpu().numpy()[0]) # Move mask back to CPU for display
ssv.display(seg)
if __name__ == "__main__":
main()
But Im still getting this error no matter what I have tried :
(m3d) A:\AL DL\CA project\M3D\LaMed\src\demo>python mydemo.py
C:\Users\asus.conda\envs\m3d\lib\site-packages\huggingface_hub\file_download.py:1150: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
warnings.warn(
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
C:\Users\asus.conda\envs\m3d\lib\site-packages\monai\utils\deprecate_utils.py:221: FutureWarning: monai.networks.blocks.patchembedding PatchEmbeddingBlock.init:pos_embed: Argument pos_embed has been deprecated since version 1.2. It will be removed in version 1.4. please use proj_type instead.
warn_deprecated(argname, msg, warning_category)
build_sam_vit_3d...
C:\Users\asus.conda\envs\m3d\lib\site-packages\monai\utils\deprecate_utils.py:221: FutureWarning: monai.networks.nets.vit ViT.init:pos_embed: Argument pos_embed has been deprecated since version 1.2. It will be removed in version 1.4. please use proj_type instead.
warn_deprecated(argname, msg, warning_category)
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [01:38<00:00, 16.50s/it]
C:\Users\asus.conda\envs\m3d\lib\site-packages\huggingface_hub\file_download.py:1150: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
warnings.warn(
Some parameters are on the meta device because they were offloaded to the disk and cpu.
C:\Users\asus.conda\envs\m3d\lib\site-packages\transformers\models\llama\modeling_llama.py:670: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:263.)
attn_output = torch.nn.functional.scaled_dot_product_attention(
Traceback (most recent call last):
File "A:\AL DL\CA project\M3D\LaMed\src\demo\mydemo.py", line 80, in
main()
File "A:\AL DL\CA project\M3D\LaMed\src\demo\mydemo.py", line 63, in main
generation, seg_logit = model.generate(image_pt, input_id, seg_enable=True, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\utils_contextlib.py", line 115, in decorate_context
return func(*args, kwargs)
File "C:\Users\asus.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 2085, in generate
logits = self.get_model().seg_module(images, seg_prompts)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, *kwargs)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(args, kwargs)
File "C:\Users\asus.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 1385, in forward
logits = self.forward_decoder(image_embedding, img_shape, text_emb=text_emb, text=text, boxes=boxes, points=points)
File "C:\Users\asus.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 1391, in forward_decoder
sparse_embeddings, dense_embeddings = self.prompt_encoder(
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, kwargs)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, *kwargs)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\accelerate\hooks.py", line 170, in new_forward
output = module._old_forward(args, kwargs)
File "C:\Users\asus.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 839, in forward
sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch_prims_common\wrappers.py", line 250, in _fn
result = fn(args, kwargs)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch_prims_common\wrappers.py", line 137, in _fn
result = fn(bound.arguments)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch_refs__init__.py", line 2702, in cat
utils.check_same_device(tensors, allow_cpu_scalar_tensors=False)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch_prims_common__init__.py", line 696, in check_same_device
raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!
===================
I have python 3.9.9, cuda 11.8 and the rest of the requirements are according to the requirements txt.
I have a gtx 1650 GPU.
I ran the demo.py with few modifications , as it was showing some errors, the demo.py:
But Im still getting this error no matter what I have tried :
(m3d) A:\AL DL\CA project\M3D\LaMed\src\demo>python mydemo.py C:\Users\asus.conda\envs\m3d\lib\site-packages\huggingface_hub\file_download.py:1150: FutureWarning:
main()
File "A:\AL DL\CA project\M3D\LaMed\src\demo\mydemo.py", line 63, in main
generation, seg_logit = model.generate(image_pt, input_id, seg_enable=True, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\utils_contextlib.py", line 115, in decorate_context
return func(*args, kwargs)
File "C:\Users\asus.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 2085, in generate
logits = self.get_model().seg_module(images, seg_prompts)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, *kwargs)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(args, kwargs)
File "C:\Users\asus.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 1385, in forward
logits = self.forward_decoder(image_embedding, img_shape, text_emb=text_emb, text=text, boxes=boxes, points=points)
File "C:\Users\asus.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 1391, in forward_decoder
sparse_embeddings, dense_embeddings = self.prompt_encoder(
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, kwargs)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, *kwargs)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\accelerate\hooks.py", line 170, in new_forward
output = module._old_forward(args, kwargs)
File "C:\Users\asus.cache\huggingface\modules\transformers_modules\GoodBaiBai88\M3D-LaMed-Llama-2-7B\87f6f56001c2dd4d005c0d10c17edba4ffbf36a5\modeling_m3d_lamed.py", line 839, in forward
sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch_prims_common\wrappers.py", line 250, in _fn
result = fn(args, kwargs)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch_prims_common\wrappers.py", line 137, in _fn
result = fn(bound.arguments)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch_refs__init__.py", line 2702, in cat
utils.check_same_device(tensors, allow_cpu_scalar_tensors=False)
File "C:\Users\asus.conda\envs\m3d\lib\site-packages\torch_prims_common__init__.py", line 696, in check_same_device
raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!
resume_download
is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, useforce_download=True
. warnings.warn( Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. C:\Users\asus.conda\envs\m3d\lib\site-packages\monai\utils\deprecate_utils.py:221: FutureWarning: monai.networks.blocks.patchembedding PatchEmbeddingBlock.init:pos_embed: Argumentpos_embed
has been deprecated since version 1.2. It will be removed in version 1.4. please useproj_type
instead. warn_deprecated(argname, msg, warning_category) build_sam_vit_3d... C:\Users\asus.conda\envs\m3d\lib\site-packages\monai\utils\deprecate_utils.py:221: FutureWarning: monai.networks.nets.vit ViT.init:pos_embed: Argumentpos_embed
has been deprecated since version 1.2. It will be removed in version 1.4. please useproj_type
instead. warn_deprecated(argname, msg, warning_category) Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [01:38<00:00, 16.50s/it] C:\Users\asus.conda\envs\m3d\lib\site-packages\huggingface_hub\file_download.py:1150: FutureWarning:resume_download
is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, useforce_download=True
. warnings.warn( Some parameters are on the meta device because they were offloaded to the disk and cpu. C:\Users\asus.conda\envs\m3d\lib\site-packages\transformers\models\llama\modeling_llama.py:670: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:263.) attn_output = torch.nn.functional.scaled_dot_product_attention( Traceback (most recent call last): File "A:\AL DL\CA project\M3D\LaMed\src\demo\mydemo.py", line 80, in=================== I have python 3.9.9, cuda 11.8 and the rest of the requirements are according to the requirements txt. I have a gtx 1650 GPU.
Any help would be greatly appreciated,
Thank you @baifanxxx , @Yuxin-Du-Lab