huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

LLava not working with accelerate dispatch: "Expected all tensors to be on the same device" #27917

Closed py4 closed 9 months ago

py4 commented 9 months ago

System Info

Who can help?

@pacman100 @ArthurZucker

Information

Tasks

Reproduction

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"

prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True, 
    device_map='auto'
)
processor = AutoProcessor.from_pretrained(model_id)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to('cuda', torch.float16)

output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))

Expected behavior

It should produce the output, but I get the following error. I believe something similar to this is needed to fix it.

  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/transformers/generation/utils.py", line 2579, in greedy_search
    outputs = self(
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/transformers/models/llava/modeling_llava.py", line 433, in forward
    outputs = self.language_model(
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
    outputs = self.model(
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1061, in forward
    layer_outputs = decoder_layer(
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 789, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  File "/home/pooyam/hf_llava/lib/python3.9/site-packages/transformers/cache_utils.py", line 127, in update
    self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)
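
When reading a trace like this, it can help to see how the modules were split across GPUs. The following is only a minimal sketch: it assumes a dict shaped like `model.hf_device_map` (module name mapped to device), and the sample map is made up for illustration rather than taken from this report. `group_by_device` is a hypothetical helper, not part of transformers or accelerate.

```python
from collections import defaultdict


def group_by_device(device_map):
    """Group module names by the device they were assigned to."""
    groups = defaultdict(list)
    for module_name, device in device_map.items():
        groups[device].append(module_name)
    return dict(groups)


# Hypothetical map, shaped like `model.hf_device_map`; not from this issue.
sample_map = {
    "vision_tower": 0,
    "multi_modal_projector": 0,
    "language_model.model.embed_tokens": 0,
    "language_model.model.layers.0": 0,
    "language_model.model.layers.1": 1,
    "language_model.model.norm": 1,
    "language_model.lm_head": 1,
}

for device, names in group_by_device(sample_map).items():
    print(f"device {device}: {len(names)} modules")
    for name in names:
        print(f"  {name}")
```

With a real model, passing `model.hf_device_map` instead of `sample_map` should show which decoder layers ended up on `cuda:0` and which on `cuda:1`.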

ArthurZucker commented 9 months ago

cc @younesbelkada I think this worked / works on main no?

FYI @gante and @tomaarsen we'll work on a fix with @younesbelkada

RylanSchaeffer commented 6 months ago

I'm hitting a similar error with Llava 1.6 + Mistral 7B. What was the fix previously and what changes (if any) need to be made for Mistral 7B?

    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_layer_norm)

RylanSchaeffer commented 6 months ago

The problem arises due to:

<bound method AlignDevicesHook.pre_forward of AlignDevicesHook(execution_device=1, offload=False, io_same_device=False, offload_buffers=False, place_submodules=True, skip_keys=['past_key_values', 'causal_mask'])>

which moves the input tensors from the "correct" device (in my setting, 0) to the wrong device (in my setting, 1).
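
To check which execution device each hook targets, something like the sketch below may help. It assumes the hook is stored on the module as `_hf_hook` and exposes `execution_device`, as the repr above suggests; `hook_execution_devices` and the `Fake*` classes are hypothetical stand-ins so the snippet runs without GPUs or a real model.

```python
def hook_execution_devices(model):
    """Map module name -> execution device for modules carrying a hook.

    Assumption: hooks live on `module._hf_hook` and expose
    `execution_device`; modules without either are skipped.
    """
    devices = {}
    for name, module in model.named_modules():
        hook = getattr(module, "_hf_hook", None)
        device = getattr(hook, "execution_device", None)
        if device is not None:  # device 0 is valid, so compare against None
            devices[name] = device
    return devices


# Stand-in objects (hypothetical) that mimic only the attributes used above.
class FakeHook:
    def __init__(self, execution_device):
        self.execution_device = execution_device


class FakeModule:
    def __init__(self, execution_device=None):
        if execution_device is not None:
            self._hf_hook = FakeHook(execution_device)


class FakeModel:
    def __init__(self, modules):
        self._modules = modules

    def named_modules(self):
        return iter(self._modules.items())


fake = FakeModel({
    "layers.0": FakeModule(0),
    "layers.1": FakeModule(1),
    "norm": FakeModule(),  # no hook attached
})
print(hook_execution_devices(fake))
```

On a model loaded with `device_map='auto'`, calling `hook_execution_devices(model)` should list the device each hooked submodule moves its inputs to, which can be compared against where the weights actually sit.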

ArthurZucker commented 6 months ago

@RylanSchaeffer, please do not re-open the same discussion on various threads: https://github.com/huggingface/transformers/pull/28051#issuecomment-1976118206. Let's keep the discussion on a single issue.