[Open] aliencaocao opened this issue 7 months ago
Based on my script below, it should be fairly straightforward to compile and run the model out of the box, and I do get about a 4x speedup:
import os
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from functools import partial
from time import perf_counter
from typing import Optional
@contextmanager
def catchtime(s) -> Iterator[Callable[[], float]]:
    """Time the enclosed block and print the elapsed time (``time.perf_counter``) on exit.

    Yields a zero-argument callable that returns the elapsed seconds. While the block is
    running it reports the time so far. Once the block has exited it keeps returning the
    final duration instead of continuing to grow.

    Args:
        s: Label for the report. It is printed with ``{s=}``, so it appears as ``s='<label>'``.
    """
    start = perf_counter()
    end = None

    def elapsed() -> float:
        return (perf_counter() if end is None else end) - start

    try:
        yield elapsed
    finally:
        # Runs on normal exit and when the block raises, so the duration is always
        # frozen and reported.
        end = perf_counter()
        print(f'Time of {s=}: {end - start:.3f} seconds')
import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache
# noinspection PyProtectedMember
torch._inductor.config.coordinate_descent_tuning = True
# noinspection PyProtectedMember
torch._inductor.config.triton.unique_kernel_names = True
# noinspection PyProtectedMember
torch._inductor.config.fx_graph_cache = True
os.environ["TOKENIZERS_PARALLELISM"] = "true"
MODEL_NAME = 'models/llava-v1.6-mistral-7b-hf'
def mem():
    """Return ``(peak, current)`` CUDA memory allocated by PyTorch tensors, in GiB.

    Both counters are read with no explicit device argument, so they refer to torch's
    default (current) CUDA device.
    """
    gib = 1024 ** 3
    peak_bytes = torch.cuda.max_memory_allocated()
    current_bytes = torch.cuda.memory_allocated()
    return peak_bytes / gib, current_bytes / gib
# Everything below runs on the GPU. Use an explicit raise rather than `assert`, because
# `python -O` strips asserts and the script would then fail later with a less clear error.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this script, but torch.cuda.is_available() returned False.")
device = "cuda"
def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    """Draw one index per row from the probabilities in ``probs_sort``.

    Uses the exponential-race trick: with i.i.d. ``E_i ~ Exp(1)``, ``argmax_i(p_i / E_i)``
    picks index ``i`` with probability ``p_i / sum(p)``. Returns the chosen indices as
    ``torch.int`` with the last dimension kept (size 1).

    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L40C1-L42C82
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    scores = probs_sort / noise
    winners = scores.argmax(dim=-1, keepdim=True)
    return winners.to(dtype=torch.int)
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    The logits are divided by ``temperature``, floored at 1e-5 so that 0 does not divide
    by zero. If ``top_k`` is given, every logit strictly below the k-th largest is set to
    -inf before the softmax. ``k`` is capped at the size of the last dimension.

    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L44C1-L52C17
    """
    scaled = logits / max(temperature, 1e-5)
    if top_k is None:
        return torch.nn.functional.softmax(scaled, dim=-1)
    k = min(top_k, scaled.size(-1))
    # The smallest of the top-k values, kept as a size-1 trailing dim for broadcasting.
    kth_largest = torch.topk(scaled, k).values[..., -1:]
    masked = torch.where(scaled < kth_largest, -float("Inf"), scaled)
    return torch.nn.functional.softmax(masked, dim=-1)
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the logits at the final sequence position.

    Returns a ``(token_ids, probs)`` pair: the sampled indices from
    ``multinomial_sample_one_no_sync`` and the distribution they were drawn from.

    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L54C1-L57C27
    """
    final_step = logits[:, -1]
    probs = logits_to_probs(final_step, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs
@torch.compile(mode="reduce-overhead", fullgraph=True)
def decode_one_tokens(model, cur_token, cache_position):
    """Run one decode step and return the next token id(s).

    Feeds ``cur_token`` to ``model`` at ``cache_position`` (with ``use_cache=True``), takes
    the first element of the tuple output as logits, and samples with ``temperature=0``.
    ``logits_to_probs`` floors that temperature at 1e-5, so the result is effectively greedy.
    The function is wrapped by ``torch.compile`` (reduce-overhead mode, full graph).

    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L64C1-L68C45
    """
    step_logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    next_token, _probs = sample(step_logits, temperature=0)
    return next_token
def gen(model, inputs, iters=100):
    """Generate ``iters`` tokens for one processed (prompt, image) input.

    The first token comes from a full forward pass of ``model`` over ``inputs`` (the
    prefill). Each remaining token comes from one call to the compiled
    ``decode_one_tokens`` on ``model.language_model``. That call presumably relies on the
    static cache installed on the language model beforehand.

    Args:
        model: The LlavaNext model; its ``language_model`` is used for the decode steps.
        inputs: Processor output containing ``input_ids`` and ``image_sizes``, already on ``device``.
        iters: Total number of tokens to produce, including the one from the prefill.

    Returns:
        A ``torch.int`` tensor of shape ``(1, iters)`` holding the generated token ids.

    Raises:
        ValueError: If ``iters`` is less than 1.
    """
    if iters < 1:
        # Slot 0 is always written by the prefill, so at least one slot is required.
        raise ValueError(f"iters must be at least 1, got {iters}")
    print(inputs['input_ids'].shape, inputs['image_sizes'])
    generated_ids = torch.zeros((1, iters), dtype=torch.int, device=device)

    # Prefill: one full forward pass over the prompt and image.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False):
        output = model(**inputs)
    logits = output.logits
    # The first decode step writes at the position right after the prefill. Read that
    # length from the logits' sequence axis (the axis `sample` indexes with `[:, -1]`),
    # which covers the image-expanded prompt.
    # NOTE(review): the previous code read it from `output.loss`. Without labels that field
    # is expected to be None unless the model is patched, so it could not be relied on.
    # Assumes batch size 1 with no padding.
    seq_len = logits.shape[1]
    cache_position = torch.tensor([seq_len], device=device)
    input_id = sample(logits, temperature=0)[0]
    generated_ids[:, 0] = input_id[:, 0]
    gen_pos = torch.tensor([1], device=device)
    print('post-1st ', mem())

    # Decode: one compiled single-token step per remaining slot.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        for _ in tqdm(range(iters - 1)):
            # Pass a clone: under mode="reduce-overhead" the compiled step's output buffer
            # may be reused on the next call.
            input_id = decode_one_tokens(model.language_model, input_id.clone(), cache_position)
            generated_ids.index_copy_(1, gen_pos, input_id)
            cache_position += 1
            gen_pos += 1
    print('post-last ', mem())
    return generated_ids
with torch.inference_mode():
    processor = LlavaNextProcessor.from_pretrained(MODEL_NAME)

    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    # Bound the download so a stalled connection errors out instead of hanging forever.
    # Check the HTTP status here rather than letting PIL fail to decode an error page.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw)
    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

    # NOTE(review): memory-history recording stays on for the whole run, including the
    # timed generations below, so it may add overhead to the reported times. The matching
    # snapshot dump at the bottom is commented out.
    torch.cuda.memory._record_memory_history()
    print('pre-model', mem())
    model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)

    print('pre-cache', mem())
    # Install an fp16 StaticCache (batch 1, up to 4096 positions) on the language model.
    # `_setup_cache` is a private transformers method; whether it exists depends on the
    # installed transformers version.
    static_cache = partial(StaticCache, dtype=torch.float16)
    model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)

    print('pre-comp ', mem())
    # Compile the language model and the vision tower as separate modules.
    model.language_model.compile(mode='reduce-overhead', fullgraph=True)
    model.vision_tower.compile(mode='reduce-overhead', fullgraph=True)

    print('pre-proc ', mem())
    inputs = processor(prompt, image, return_tensors="pt").to(device)

    print('pre-gen1 ', mem())
    # Short first run; per its label this is the one that pays the compilation cost.
    with catchtime('first compile gen:'):
        out = gen(model, inputs, iters=10)
    print(processor.decode(out[0], skip_special_tokens=True))

    print('pre-gen2 ', mem())
    with catchtime('second compile gen:'):
        out = gen(model, inputs, iters=100)
    print(processor.decode(out[0], skip_special_tokens=True))
    # torch.cuda.memory._dump_snapshot("snapshot_full.pickle")
But the issue here is that compiling the full fp16 model requires more than 16 GB of VRAM, which is more than I have available in production.
Hey, apologies for the late response! I will look into this and get back to you soon :)
I am trying to use this project with a vision-language model like https://huggingface.co/docs/transformers/en/model_doc/llava_next, but currently this repo does not support the vision part of the model. I have a separate script that works by splitting off the vision tower and compiling the two parts separately. Do you think it would be possible to do the same using your project? My separate script does not yet fully use gpt-fast, especially the int8 part, so I would really like to use your awesome work here.
I am using https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf specifically.