OpenGVLab / InternVL

[CVPR 2024 Oral] InternVL Family: A Pioneering Open-Source Alternative to GPT-4o. An open-source multimodal dialogue model approaching GPT-4o performance.
https://internvl.readthedocs.io/en/latest/
MIT License

Please Add Support For Triton Flash Attention Inference #256

Open radna0 opened 3 months ago

radna0 commented 3 months ago

Please add support for config.attn_config['attn_impl'] = 'triton' so that inference can use a Triton Flash Attention implementation:

import torch
from PIL import Image
from transformers import AutoModel, AutoConfig, CLIPImageProcessor

# Define the model name
model_name = 'OpenGVLab/InternViT-6B-224px'

# Load the model configuration and set attention implementation to Triton
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'
config.init_device = 'cuda:0'  # For fast initialization directly on GPU

# Load the model with the updated configuration
model = AutoModel.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,  # Load model weights in bfloat16
    low_cpu_mem_usage=True,
    trust_remote_code=True
).cuda().eval()

# Load and process the image
image = Image.open('./examples/image1.jpg').convert('RGB')
image_processor = CLIPImageProcessor.from_pretrained(model_name)
pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()

# Run inference
outputs = model(pixel_values)

# Print outputs for debugging purposes
print(outputs)
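
For context, below is a rough sketch of the kind of backend dispatch this request would imply inside the modeling code. The names build_attn_fn and triton_flash_attn_fn, and the Triton module, are my own illustrative assumptions, not InternVL's actual code; only the flash_attn import and the PyTorch SDPA fallback refer to real APIs.

import torch.nn.functional as F

def build_attn_fn(attn_impl: str):
    # Hypothetical dispatcher: pick an attention kernel based on the
    # attn_impl value set in the config above. Imports are deferred so
    # only the selected backend's package needs to be installed.
    if attn_impl == 'flash':
        from flash_attn import flash_attn_func  # requires the flash_attn package
        return flash_attn_func
    if attn_impl == 'triton':
        # Assumed Triton kernel shipped alongside the modeling file; this
        # module does not exist today and is exactly what this issue asks for.
        from triton_flash_attn import triton_flash_attn_fn
        return triton_flash_attn_fn
    # Plain PyTorch fused attention fallback that needs no extra packages.
    return F.scaled_dot_product_attention

A real implementation would also have to reconcile the different Q/K/V layouts each backend expects; the sketch only shows the config-driven selection.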

When run the normal way (without that config), the flash_attn package is required even though I have Triton installed:

ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn. Run `pip install flash_attn`

import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

model = AutoModel.from_pretrained(
    'OpenGVLab/InternViT-6B-224px',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).cuda().eval()

image = Image.open('./examples/image1.jpg').convert('RGB')

image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-224px')

pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()

outputs = model(pixel_values)
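
As a quick sanity check of the environment described above (Triton present, flash_attn absent), something like the following can be run before loading the model. It is just an illustration using the standard library and PyTorch, not part of the InternVL API.

import importlib.util
import torch

# Report which optional attention backends are importable in this environment.
for pkg in ('flash_attn', 'triton'):
    found = importlib.util.find_spec(pkg) is not None
    print(f'{pkg}: {"installed" if found else "not installed"}')

# PyTorch's built-in fused attention (available since 2.0) needs no extra package.
print('torch SDPA available:',
      hasattr(torch.nn.functional, 'scaled_dot_product_attention'))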

zmyzxb commented 1 month ago

This issue has been inactive for over two weeks. If the problem is still unresolved, please feel free to open a new issue to ask your question. Thank you.