MDK8888 / GPTFast

Accelerate your Hugging Face Transformers 7.6-9x. Native to Hugging Face and PyTorch.
Apache License 2.0
686 stars 65 forks source link

Possible to use with a VL model like LLAVA? #16

Open aliencaocao opened 7 months ago

aliencaocao commented 7 months ago

I am trying to use this project with a vision-language model like https://huggingface.co/docs/transformers/en/model_doc/llava_next, but this repo currently does not support the vision part of the model. I have a separate script that works by splitting out the vision tower and compiling it separately from the language model. Do you think it would be possible to do the same using your project? My script does not yet fully use GPTFast, especially the int8 quantization part, so I would really like to use your awesome work here.

I am using https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf specifically.

aliencaocao commented 7 months ago

Based on my script below, it should be mostly out of the box to compile and run, and I get about a 4x speedup:

import os
from contextlib import contextmanager
from functools import partial
from time import perf_counter
from typing import Callable, Iterator, Optional

@contextmanager
def catchtime(s) -> Iterator[Callable[[], float]]:
    """Time the enclosed ``with`` block and print the elapsed wall-clock time.

    Yields a zero-argument callable that returns elapsed seconds. While the
    block is running it reports the time so far. After the block exits it
    reports the final, frozen duration.

    Args:
        s: Label included in the printed message.
    """
    start = perf_counter()
    end = None

    def elapsed() -> float:
        # Live reading until the block finishes, then the frozen final value.
        return (end if end is not None else perf_counter()) - start

    try:
        yield elapsed
    finally:
        # `finally` so the timing is still reported when the timed block raises.
        end = perf_counter()
        print(f'Time of {s=}: {end - start:.3f} seconds')

import requests
import torch
from PIL import Image
from tqdm import tqdm

from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache

# noinspection PyProtectedMember
torch._inductor.config.coordinate_descent_tuning = True
# noinspection PyProtectedMember
torch._inductor.config.triton.unique_kernel_names = True
# noinspection PyProtectedMember
torch._inductor.config.fx_graph_cache = True

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MODEL_NAME = 'models/llava-v1.6-mistral-7b-hf'

def mem():
    """Report CUDA allocator usage in GiB as ``(peak_allocated, currently_allocated)``."""
    bytes_per_gib = 1024 * 1024 * 1024
    peak_gib = torch.cuda.max_memory_allocated() / bytes_per_gib
    current_gib = torch.cuda.memory_allocated() / bytes_per_gib
    return peak_gib, current_gib

# Fail fast: the model, caches and generated ids below are all placed on "cuda",
# and mem() queries CUDA allocator statistics. An explicit raise (unlike
# `assert`) is not stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("This script requires a CUDA-capable GPU, but torch.cuda.is_available() is False.")
device = "cuda"

def multinomial_sample_one_no_sync(probs_sort):
    """Draw one index per row from ``probs_sort`` without a host/device sync.

    Uses the exponential-race trick: ``argmax(p / E)`` with ``E ~ Exp(1)`` is
    distributed like a multinomial draw from ``p``, and avoids the CUDA
    synchronization that ``torch.multinomial`` incurs.
    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L40C1-L42C82

    Returns:
        int32 tensor with the last dimension reduced to size 1.
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    race = probs_sort / noise
    winner = race.argmax(dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L44C1-L52C17
    with one change: float16/bfloat16 logits are processed in float32 and the
    result is cast back. The temperature is clamped to at least 1e-5, so
    ``temperature=0`` scales logits by 1e5. In float16 (max ~65504) that
    overflows to inf, and softmax would then return NaN.

    Args:
        logits: Unnormalized scores; softmax is taken over the last dimension.
        temperature: Divisor applied to the logits; clamped to >= 1e-5.
        top_k: If given, logits strictly below the k-th largest value are
            masked to -inf before the softmax.

    Returns:
        Probabilities with the same shape as ``logits`` (and the same dtype for
        float16/bfloat16 inputs).
    """
    input_dtype = logits.dtype
    upcast = input_dtype in (torch.float16, torch.bfloat16)
    if upcast:
        logits = logits.float()
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs.to(input_dtype) if upcast else probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the logits at the final sequence position.

    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L54C1-L57C27

    Returns:
        ``(idx_next, probs)``: the sampled token ids and the distribution they
        were drawn from.
    """
    final_step_logits = logits[:, -1]
    distribution = logits_to_probs(final_step_logits, temperature, top_k)
    return multinomial_sample_one_no_sync(distribution), distribution

def decode_one_tokens(model, cur_token, cache_position):
    """Run one decoding step and return the greedily selected next token.

    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L64C1-L68C45
    for a Hugging Face model. ``cur_token`` is fed at ``cache_position`` with
    ``use_cache=True``, and the first element of the tuple output (logits) is
    passed to ``sample`` with ``temperature=0``.

    Returns:
        The token-id tensor produced by ``sample``.
    """
    step_outputs = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)
    next_token, _unused_probs = sample(step_outputs[0], temperature=0)
    return next_token


# Rebind to a compiled version: "reduce-overhead" mode uses CUDA graphs, and
# fullgraph=True makes any graph break an error instead of a silent fallback.
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)

def gen(model, inputs, iters=100):
    """Greedily generate ``iters`` tokens with a LLaVA-NeXT-style model.

    The first token comes from a full forward pass of ``model`` over the
    processed prompt (text and image). The remaining ``iters - 1`` tokens are
    produced one at a time by the compiled ``decode_one_tokens`` on
    ``model.language_model``, which is expected to have a static KV cache set up.

    Args:
        model: Model called as ``model(**inputs)`` and exposing ``.language_model``.
        inputs: Processor output; must contain ``input_ids`` and ``image_sizes``.
        iters: Number of tokens to generate; must be >= 1.

    Returns:
        int32 tensor of shape ``(1, iters)`` on ``device`` holding the generated ids.

    Raises:
        ValueError: if ``iters < 1``.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    print(inputs['input_ids'].shape, inputs['image_sizes'])
    generated_ids = torch.zeros((1, iters), dtype=torch.int, device=device)
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False):
        output = model(**inputs)
        logits = output.logits
        # Stock transformers returns loss=None when no labels are passed, and
        # torch.tensor([None]) raises. A patched model may return the prefill
        # length in the loss slot, so prefer that when present. Otherwise use the
        # number of positions covered by the logits.
        # NOTE(review): assumes logits span the full merged (text + image) prefill
        # sequence; confirm for the installed transformers version.
        seq_len = output.loss if output.loss is not None else logits.shape[1]
        cache_position = torch.tensor([seq_len], device=device)
        input_id = sample(logits, temperature=0)[0]
        generated_ids[:, 0] = input_id[:, 0]
    gen_pos = torch.tensor([1], device=device)
    print('post-1st  ', mem())
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        for _ in tqdm(range(iters - 1)):
            # clone(): presumably so the input is not a buffer owned by the
            # CUDA-graph ("reduce-overhead") output of the previous step.
            input_id = decode_one_tokens(model.language_model, input_id.clone(), cache_position)
            generated_ids.index_copy_(1, gen_pos, input_id)
            cache_position += 1
            gen_pos += 1
    print('post-last ', mem())
    return generated_ids

# Benchmark: load LLaVA-NeXT in fp16, set up a static KV cache, compile the
# language model and vision tower, then time two greedy generations (the first
# run includes compilation).
with torch.inference_mode():
    processor = LlavaNextProcessor.from_pretrained(MODEL_NAME)
    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    # A timeout keeps a stalled download from hanging the script forever;
    # raise_for_status() avoids handing an HTTP error page to PIL.
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    image = Image.open(response.raw)
    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

    # Record allocator history so a memory snapshot can be dumped at the end.
    torch.cuda.memory._record_memory_history()
    print('pre-model', mem())
    model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
    print('pre-cache', mem())
    static_cache = partial(StaticCache, dtype=torch.float16)
    # Private transformers API: allocates a static KV cache (batch 1, up to 4096 positions).
    model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)
    print('pre-comp ', mem())
    model.language_model.compile(mode='reduce-overhead', fullgraph=True)
    model.vision_tower.compile(mode='reduce-overhead', fullgraph=True)
    print('pre-proc ', mem())
    inputs = processor(prompt, image, return_tensors="pt").to(device)

    print('pre-gen1 ', mem())
    with catchtime('first compile gen:'):
        out = gen(model, inputs, iters=10)
        print(processor.decode(out[0], skip_special_tokens=True))
    print('pre-gen2 ', mem())
    with catchtime('second compile gen:'):
        out = gen(model, inputs, iters=100)
        print(processor.decode(out[0], skip_special_tokens=True))
    # To inspect memory usage, dump the recorded history with:
    # torch.cuda.memory._dump_snapshot("snapshot_full.pickle")

The problem is that compiling the full fp16 model requires more than 16 GB of VRAM, which is more than I have available in production.

MDK8888 commented 7 months ago

Hey, apologies for the late response! I will look into this and get back to you soon :)