Open radna0 opened 2 months ago
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey! I don't think we officially test or support TPU for this model 🤗 I can't really reproduce 😢 @tengomucho might have an idea
@radna0 transformers does not officially support TPU, but I think things might work if:
System Info

transformers version: 4.45.0.dev0

Who can help?

No response

Information

Tasks

examples folder (such as GLUE/SQuAD, ...)

Reproduction
Following this Qwen2-VL quickstart guide: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct#quickstart
```python
import time

import torch
import torch_xla.core.xla_model as xm
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla import runtime as xr
from torch_xla.distributed.spmd import auto_policy
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    _prepare_spmd_partition_spec,
    SpmdFullyShardedDataParallel as FSDPv2,
)
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

start = time.time()
device = xm.xla_device()

# default: load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
).to(device)
print(model.device)

# default processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")

message = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "image1.jpg"},
            {"type": "text", "text": "Describe this image in detail."},
        ],
    }
]
all_messages = [[message] for _ in range(1)]
for messages in all_messages:
    ...  # inference body omitted in the original report

print(f"Time taken: {time.time() - start}")
```
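As a side note on the chat format used above: each `content` list mixes image and text entries keyed by `type`. A minimal sketch of walking that structure with a hypothetical helper (`split_content` is illustrative, not part of transformers or the quickstart):

```python
def split_content(message):
    """Separate image and text entries of a Qwen2-VL style chat message.

    Hypothetical helper for illustration only.
    """
    images = [c["image"] for c in message["content"] if c["type"] == "image"]
    texts = [c["text"] for c in message["content"] if c["type"] == "text"]
    return images, texts


msg = {
    "role": "user",
    "content": [
        {"type": "image", "image": "image1.jpg"},
        {"type": "text", "text": "Describe this image in detail."},
    ],
}
print(split_content(msg))  # → (['image1.jpg'], ['Describe this image in detail.'])
```

In the actual pipeline this separation is handled for you when the processor consumes the message list; the sketch only shows the shape the processor expects.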
```
kojoe@t1v-n-cb70f560-w-0:~/EasyAnimate/easyanimate/image_caption$ python caption.py
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████| 2/2 [00:00<00:00, 4.39it/s]
xla:0
```
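The `Setting PJRT_DEVICE=TPU` warning in the log comes from torch_xla auto-detecting `libtpu.so` at import time. If the auto-detection misbehaves, the backend can be pinned explicitly before torch_xla is imported — a minimal sketch (assuming the standard `PJRT_DEVICE` environment variable torch_xla reads):

```python
import os

# torch_xla picks its PJRT backend from PJRT_DEVICE at import time; when
# libtpu.so is present it defaults to TPU (as the warning above shows).
# Setting it explicitly before `import torch_xla` makes the choice deterministic.
os.environ["PJRT_DEVICE"] = "TPU"
print(os.environ["PJRT_DEVICE"])  # → TPU
```

This must run before the first `import torch_xla`, since the runtime is initialized on import.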