## 🐛 Bug

## To Reproduce

Steps to reproduce the behavior:

```python
import time

from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import numpy as np
import torch
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy
from torch_xla import runtime as xr
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    _prepare_spmd_partition_spec,
    SpmdFullyShardedDataParallel as FSDPv2,
)

xla.experimental.eager_mode(True)
start = time.time()
device = xla.device()

# Default: load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    device_map="auto",
).to(device)
print(model.device)

# Default processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")

message = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://w0.peakpx.com/wallpaper/607/308/HD-wallpaper-anime-girl-black-hair-guitar-instrument-red-eyes-school-uniform-skirt.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    }
]
all_messages = [[message] for _ in range(1)]
for messages in all_messages:
    # Preparation for inference
    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in messages
    ]
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    # Inference: generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    for i, text in enumerate(output_text):
        print(f"Output {i}: {text}")
    print(f"Time taken: {time.time() - start}")
```
<!-- If you have a code sample, error messages, stack traces, please provide it here as well. Or better use the Colab template: https://github.com/pytorch/xla/blob/master/contrib/colab/issue-report.ipynb -->
## Expected behavior
<!-- A clear and concise description of what you expected to happen. -->
Should run the Qwen2-VL-2B-Instruct model at https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct#quickstart
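For reference, a trimmed sketch of that quickstart's loading path, which does not involve torch_xla (the `torch_dtype="auto"` / `device_map="auto"` arguments here follow the linked page, not output from this environment):

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

# Plain CPU/GPU load per the linked quickstart; no torch_xla involved.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
```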
## Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: nightly 2.5
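A quick way to capture the exact versions for this report (a sketch; assumes the nightly wheel exposes `torch_xla.__version__` and `runtime.device_type`):

```python
import torch
import torch_xla
from torch_xla import runtime as xr

# Print the installed versions and the detected XLA backend.
print("torch:", torch.__version__)
print("torch_xla:", torch_xla.__version__)
print("backend:", xr.device_type())  # expected: 'TPU'
```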