huggingface / optimum-habana

Easy and lightning fast training of 🤗 Transformers on Habana Gaudi processor (HPU)
Apache License 2.0

LLaVA inference produces incorrect output if adapt_transformers_to_gaudi is called after the transformers import #1176

Open mattolson93 opened 1 month ago

mattolson93 commented 1 month ago

System Info

Optimum Habana version v1.12.1
Synapse 1.16.2
docker vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest

Reproduction

Running minimal LLaVA inference code produces incorrect output (on CPU or HPU). I have also verified this on the original LLaVA (I can share that example if you want). Change line 7 to fails = False to see the correct behavior.

import requests
from PIL import Image
from transformers import AutoProcessor
from habana_frameworks.torch import hpu
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

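fails = True  # set to False to call adapt_transformers_to_gaudi() before importing the model class
if fails:
    # Wrong order: the class is bound before the Gaudi patches are applied
    from transformers import LlavaForConditionalGeneration
    adapt_transformers_to_gaudi()
else:
    # Correct order: patch first, then import
    adapt_transformers_to_gaudi()
    from transformers import LlavaForConditionalGeneration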

checkpoint = "Intel/llava-gemma-2b"

# Load model
model = LlavaForConditionalGeneration.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

# Prepare inputs
# Use gemma chat template
prompt = processor.tokenizer.apply_chat_template(
    [{'role': 'user', 'content': "<image>\nWhat's the content of the image?"}],
    tokenize=False,
    add_generation_prompt=True
)
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt")

# Generate
generate_ids = model.generate(**inputs, max_length=30)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)

Expected behavior

The output should be "The image features a red stop sign on a" rather than " ss wasteAsian\n\n\nt".

regisss commented 1 month ago

Yes, that's expected. A workaround would be to fully re-import everything that adapt_transformers_to_gaudi imports, but that's not straightforward to do. Let's see if other people encounter this issue; if so, I'll give it higher priority.
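A minimal sketch of why the import order matters: in Python, "from module import Name" binds Name in the importing namespace to whatever object module.Name refers to at that moment. Assuming adapt_transformers_to_gaudi works by reassigning attributes on transformers modules, a name imported before the call keeps pointing at the unpatched class, while importing again after the call picks up the patched one, which is the intuition behind the re-import workaround. The module fake_lib and the classes below are hypothetical stand-ins, not transformers itself.

```python
import sys
import types

# Hypothetical stand-in for a library module such as transformers.
lib = types.ModuleType("fake_lib")
sys.modules["fake_lib"] = lib  # register it so "from fake_lib import ..." works


class OriginalModel:
    variant = "original"


class PatchedModel(OriginalModel):
    variant = "patched"


lib.Model = OriginalModel

# "fails = True" branch: import first, patch afterwards.
from fake_lib import Model as EarlyModel
lib.Model = PatchedModel  # a patcher reassigning the module attribute (assumed mechanism)

# "fails = False" branch: patch first, import afterwards.
from fake_lib import Model as LateModel

print(EarlyModel.variant)  # original
print(LateModel.variant)   # patched
```

The early name is never updated by the patch because it holds a reference to the old class object, not a lookup into the module.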

mattolson93 commented 1 month ago

Would it be hard to throw a warning? I can imagine myself running into this problem again in a few months and not realizing I put the call in the wrong spot.
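One way such a warning could work, as a minimal sketch: checking whether transformers is already in sys.modules would not be enough, since the reproduction imports AutoProcessor from transformers before the call in both branches. The sketch instead patches module attributes and then looks for names in a given namespace that are still bound to the objects it just replaced. patch_and_warn, fake_lib, Original and Patched are hypothetical names for illustration; this is not an existing optimum-habana API.

```python
import types
import warnings


def patch_and_warn(module, replacements, namespace):
    """Hypothetical helper: patch attributes on `module`, then warn if
    `namespace` (e.g. the caller's globals()) still holds pre-patch objects.

    replacements maps attribute name -> replacement object.
    Returns the sorted list of stale names found in `namespace`.
    """
    originals = []
    for attr, new_obj in replacements.items():
        originals.append(getattr(module, attr))  # remember the unpatched object
        setattr(module, attr, new_obj)           # monkey-patch the module attribute
    # A name is "stale" if it is bound to the very object that was just replaced.
    stale = sorted(
        var for var, value in namespace.items()
        if any(value is old for old in originals)
    )
    if stale:
        warnings.warn(
            "These names were imported before patching and still refer to "
            f"unpatched objects: {', '.join(stale)}. "
            "Call the patch function before importing them.",
            RuntimeWarning,
            stacklevel=2,
        )
    return stale


# Demo with a stand-in module (hypothetical, not transformers itself).
lib = types.ModuleType("fake_lib")


class Original:
    pass


class Patched(Original):
    pass


lib.Model = Original
user_globals = {"Model": lib.Model}  # simulates "from fake_lib import Model" done too early

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    stale = patch_and_warn(lib, {"Model": Patched}, user_globals)

print(stale)        # ['Model']
print(len(caught))  # 1
```

A real patcher could read the caller's namespace with inspect.currentframe().f_back.f_globals instead of taking it as an argument. Either way, the identity check only catches stale names in that one namespace, not in other modules that imported the class earlier.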