OpenGVLab / InternVL

[CVPR 2024 Oral] InternVL Family: A Pioneering Open-Source Alternative to GPT-4o. (An open-source multimodal chat model that approaches GPT-4o's performance.)
https://internvl.readthedocs.io/en/latest/
MIT License
6.09k stars, 475 forks

[Bug] Huggingface model works in cuda:0 but not cuda:1 #452

Open ControlNet opened 3 months ago

ControlNet commented 3 months ago

Checklist

Describe the bug

When I use the minimal reproduction from the documentation but load the model onto a GPU other than cuda:0 (for example, cuda:1), the chat method fails to generate a response.

In the documentation, the model is loaded as:

model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).eval().cuda()

This works.

But with the following code:

model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).eval().cuda("cuda:1")  # note the model is loaded to cuda:1 not cuda:0!

The model is loaded onto the correct device, but the chat method fails with:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

I think the cause is these lines: https://github.com/OpenGVLab/InternVL/blob/6a230b34cc04eb2ee51c3ea013362a57ab6a6dc9/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py#L288-L289

They should move the inputs with .to("<DEVICE>") (the model's device) rather than a bare .cuda(), which always targets the current CUDA device (cuda:0 unless it has been changed).
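A minimal sketch of the proposed change, assuming the whole model sits on one device. Tensor.cuda() with no argument targets the current CUDA device, whereas .to(model.device) follows wherever the weights actually are. Hugging Face PreTrainedModel already exposes a device property; TinyChat below is a hypothetical stand-in for InternVLChatModel, and its chat method only mimics the input-placement step, not the real signature.

```python
import torch
import torch.nn as nn


class TinyChat(nn.Module):
    """Hypothetical stand-in for InternVLChatModel; only the device handling matters here."""

    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)

    @property
    def device(self) -> torch.device:
        # Device of the first parameter; assumes all weights live on one device,
        # similar to the device property of Hugging Face PreTrainedModel.
        return next(self.parameters()).device

    def chat(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Buggy pattern: input_ids.cuda() goes to the *current* CUDA device
        # (cuda:0 by default), even when the weights were moved to cuda:1.
        # Fixed pattern: follow the model's weights instead.
        input_ids = input_ids.to(self.device)
        return self.embed(input_ids)


if __name__ == "__main__":
    # Use cuda:1 when it exists, otherwise fall back to CPU so the sketch still runs.
    target = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
    model = TinyChat().to(target)
    out = model.chat(torch.tensor([[1, 2, 3]]))
    print(out.device)  # same device as the model's weights
```

Applied to the real modeling file, the same idea would mean replacing each input_ids/attention_mask .cuda() call in chat with .to(self.device).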

Reproduction

import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate candidate tile grids (i columns x j rows) with min_num <= i * j <= max_num
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

# If you have an 80G A100 GPU, you can put the entire model on a single GPU.
# Otherwise, load the model across multiple GPUs; see the `Multiple GPUs` section.
path = 'OpenGVLab/InternVL2-8B'
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).eval().cuda("cuda:1")  # note the model is loaded to cuda:1 not cuda:0!
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

generation_config = dict(max_new_tokens=1024, do_sample=False)
# pure-text conversation
question = 'Hello, who are you?'
response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')

Environment

> Which model are you using?
`InternVL2-8B`, running on one GPU (RTX 4090)
> How you installed PyTorch [e.g., pip, conda, source]
via conda.

# install gcc, gxx compiler
conda install gcc=9 gxx=9 cxx-compiler -y -c conda-forge

# install pytorch, cuda and other dependencies
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
conda install nvidia/label/cuda-11.8.0::cuda -y

> Other environment variables that may be related
Others should be fine.

Error traceback

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 100
     98 # pure-text conversation (纯文本对话)
     99 question = 'Hello, who are you?'
--> 100 response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
    101 print(f'User: {question}\nAssistant: {response}')

File ~/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py:285, in InternVLChatModel.chat(self, tokenizer, pixel_values, question, generation_config, history, return_history, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, verbose)
    283 attention_mask = model_inputs['attention_mask'].cuda()
    284 generation_config['eos_token_id'] = eos_token_id
--> 285 generation_output = self.generate(
    286     pixel_values=pixel_values,
    287     input_ids=input_ids,
    288     attention_mask=attention_mask,
    289     **generation_config
    290 )
    291 response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
    292 response = response.split(template.sep)[0].strip()

File ~/miniconda3/envs/vlm/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/.cache/huggingface/modules/transformers_modules/InternVL2-8B/modeling_internvl_chat.py:333, in InternVLChatModel.generate(self, pixel_values, input_ids, attention_mask, visual_features, generation_config, output_hidden_states, return_dict, **generate_kwargs)
    331     input_embeds = input_embeds.reshape(B, N, C)
    332 else:
--> 333     input_embeds = self.language_model.get_input_embeddings()(input_ids)
    335 outputs = self.language_model.generate(
    336     inputs_embeds=input_embeds,
    337     attention_mask=attention_mask,
   (...)
    342     **generate_kwargs,
    343 )
    345 return outputs

File ~/miniconda3/envs/vlm/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/vlm/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/envs/vlm/lib/python3.10/site-packages/torch/nn/modules/sparse.py:162, in Embedding.forward(self, input)
    161 def forward(self, input: Tensor) -> Tensor:
--> 162     return F.embedding(
    163         input, self.weight, self.padding_idx, self.max_norm,
    164         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/miniconda3/envs/vlm/lib/python3.10/site-packages/torch/nn/functional.py:2233, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2227     # Note [embedding_renorm set_grad_enabled]
   2228     # XXX: equivalent to
   2229     # with torch.no_grad():
   2230     #   torch.embedding_renorm_
   2231     # remove once script supports set_grad_enabled
   2232     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2233 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

hshjerry commented 3 months ago

I tried setting CUDA_VISIBLE_DEVICES=xx before the command. :)
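This workaround works because CUDA_VISIBLE_DEVICES renumbers the GPUs a process can see: with CUDA_VISIBLE_DEVICES=1, physical GPU 1 appears as cuda:0, so the bare .cuda() calls land on the intended card. The hypothetical helper below sketches that renumbering; it assumes the variable holds a comma-separated list of integer indices (the variable also accepts other forms such as GPU UUIDs, which this sketch does not handle).

```python
import os
from typing import Dict, Optional


def logical_index_map(visible: Optional[str]) -> Optional[Dict[int, int]]:
    """Map physical GPU index -> logical cuda:N index seen inside the process.

    Hypothetical helper: assumes `visible` is a comma-separated list of integer
    GPU indices, e.g. "1" or "2,0".
    """
    if visible is None:
        # Variable unset: every GPU is visible and cuda:N is physical GPU N.
        return None
    physical = [int(tok) for tok in visible.split(",") if tok.strip()]
    # The i-th listed physical GPU becomes cuda:i; an empty string hides all GPUs.
    return {phys: logical for logical, phys in enumerate(physical)}


if __name__ == "__main__":
    print(logical_index_map(os.environ.get("CUDA_VISIBLE_DEVICES")))
```

For example, launching with CUDA_VISIBLE_DEVICES=2,0 makes cuda:0 refer to physical GPU 2 and cuda:1 to physical GPU 0.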

ControlNet commented 3 months ago

> I tried setting CUDA_VISIBLE_DEVICES=xx before the command. :)

Hi hshjerry, thank you for your reply. Yes, if only one GPU is exposed to the process, it works. However, we're building a compositional system that loads multiple models onto multiple GPUs, so we have to move each model to its corresponding "cuda:x". CUDA_VISIBLE_DEVICES applies to the whole process, so it cannot give different models different GPUs.
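Until the modeling code stops hard-coding .cuda(), one possible workaround that keeps several models in one process is to make each model's GPU the current CUDA device while calling it, since a bare .cuda() moves tensors to the current device. A minimal sketch, assuming each model lives entirely on one device; call_on_model_device is a hypothetical helper, while torch.cuda.device is PyTorch's context manager for temporarily switching the current device.

```python
import contextlib
from typing import Any, Callable

import torch
import torch.nn as nn


def call_on_model_device(model: nn.Module, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run fn(*args, **kwargs) with the model's GPU set as the current CUDA device.

    Hypothetical helper: assumes all of the model's parameters are on one device.
    Inside the context, a bare tensor.cuda() targets that GPU instead of cuda:0.
    """
    device = next(model.parameters()).device
    if device.type == "cuda":
        ctx = torch.cuda.device(device)  # temporarily switch the current CUDA device
    else:
        ctx = contextlib.nullcontext()  # CPU model: nothing to switch
    with ctx:
        return fn(*args, **kwargs)
```

With the reproduction above, the call would become call_on_model_device(model, model.chat, tokenizer, None, question, generation_config). This only changes which GPU counts as "current" during the call; the proper fix is still to use the model's device inside chat.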