pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

I tried to speed up llava with torch.compile, but it is slower than eager mode. Why? #92

Open bleedingfight opened 9 months ago

bleedingfight commented 9 months ago

gpt-fast uses torch.compile and achieves significant acceleration, so I wanted to speed up my llava model with torch.compile as well. llava and llama are similar: a CLIP model processes the image, llama's embedding layer turns the prompt tokens into embeddings, the image features are merged into the same embedding sequence, and llama's 32 decoder layers then process the mixed image/text data. I won't go into detail about the llava and llama algorithms here because they are not important; what matters is the method of accelerating computation. gpt-fast just compiles decode_one_token (which includes forward and sample) and achieves a very high speedup. I applied the same method to my model, but the speed is very slow, even slower than eager mode. Compilation did not report any errors, but each round of computation takes almost the same time and is relatively slow.
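For reference, the pattern I am copying from gpt-fast is roughly the following (a minimal sketch with a placeholder model and a greedy sampler, not the repo's exact code): compile only the single-token decode step, at module scope, with fixed shapes so that mode="reduce-overhead" can reuse CUDA graphs.

import torch

def sample_greedy(logits: torch.Tensor) -> torch.Tensor:
    # simplest possible sampler: argmax over the last position
    return logits[:, -1].argmax(dim=-1, keepdim=True).to(torch.int)

def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # x: [B, 1], input_pos: [1]; the shapes are identical on every call
    logits = model(x, input_pos)
    return sample_greedy(logits)

# compile once at module scope and reuse the compiled callable inside the decode loop
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

My full script, adapted from gpt-fast's generate.py: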

import itertools
import sys
import os
import time
from pathlib import Path
from typing import Optional, Tuple

import torch
import torch._dynamo.config
import torch._inductor.config

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from sentencepiece import SentencePieceProcessor

from model import Transformer
from tp import maybe_init_dist
from functools import wraps
import contextlib
from simple import sample

def cost_time(fn):

    @wraps(fn)
    def wrap_cost(*kw, **kwargs):
        start = time.perf_counter()
        out = fn(*kw, **kwargs)
        end = time.perf_counter()
        print(f"{fn.__name__} cost time:{end-start:.4f}(s)")
        return out

    return wrap_cost

def logits_to_probs(logits,
                    temperature: float = 1.0,
                    top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor,
            **sampling_kwargs) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]

@cost_time
def model_forward(model, x, input_pos):
    return model(x, input_pos)

@cost_time
def encode_tokens(tokenizer, string, bos=True, device="cuda"):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

class LLamaModel(object):

    def __init__(self, **kwargs):
        self.prompt = kwargs.get("prompt")
        self.interactive = kwargs.get("interactive", False)
        self.num_samples = kwargs.get("num_samples", 1)
        self.max_new_tokens = kwargs.get("max_new_tokens", 200)
        self.top_k = kwargs.get("top_k", 200)
        self.temperature = kwargs.get("temperature", 0.8)
        self.checkpoint_path = kwargs.get("checkpoint_path")
        self.compile = kwargs.get("compile", False)
        self.profile = kwargs.get("profile", None)
        self.speculate_k = kwargs.get("speculate_k", 5)
        self.draft_checkpoint_path = kwargs.get("draft_checkpoint_path", None)
        self.compile_prefill = kwargs.get("compile_prefill", False)
        self.device = kwargs.get("device", "cuda")
        self.precision = kwargs.get("precision", torch.float16)
        self.use_tp = kwargs.get("use_tp", False)
        self.model = self._load_model(self.checkpoint_path, self.device,
                                      self.precision, self.use_tp)

    def decode_n_tokens(
        self,
        cur_token: torch.Tensor,
        input_pos: torch.Tensor,
        num_new_tokens: int,
        callback=lambda _: _,
        **sampling_kwargs,
    ):
        new_tokens, new_probs = [], []
        with torch.profiler.profile() as prof:
            for i in range(num_new_tokens):
                with torch.backends.cuda.sdp_kernel(
                        enable_flash=False,
                        enable_mem_efficient=False,
                        enable_math=True
                ):  # Actually better for Inductor to codegen attention here
                    t0 = time.time()
                    next_token, next_prob = decode_one_token(
                        cur_token, input_pos, **sampling_kwargs)
                    t1 = time.time()
                    print(f"compiled next_token cost time:{t1-t0:.4f}")
                input_pos += 1
                new_tokens.append(next_token.clone())
                callback(new_tokens[-1])
                new_probs.append(next_prob.clone())
                cur_token = next_token.view(1, -1)
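                # cur_token stays [1, 1] and input_pos stays [1] on every step,
                # so the compiled decode_one_token always sees the same shapes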
        prof.export_chrome_trace("fast_trace.json")
        return new_tokens, new_probs

    @cost_time
    @torch.no_grad()
    def generate(
        self,
        prompt: torch.Tensor,
        max_new_tokens: int,
        *,
        draft_model: Transformer,
        speculate_k: Optional[int] = 8,
        callback=lambda x: x,
        **sampling_kwargs,
    ) -> torch.Tensor:
        """
        Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
        """

        t0 = time.perf_counter()
        is_speculative = draft_model is not None
        T = prompt.size(0)
        T_new = T + max_new_tokens
        max_seq_length = min(T_new, self.model.config.block_size)

        device, dtype = prompt.device, prompt.dtype
        max_seq_length = (max_seq_length + speculate_k +
                          1 if is_speculative else max_seq_length)
        with torch.device(device):
            self.model.setup_caches(max_batch_size=1,
                                    max_seq_length=max_seq_length)

        input_pos = torch.arange(0, T, device=device)
        t0 = time.time()
        logits = self.model(prompt.view(1, -1), input_pos)
        t1 = time.time()
        print(f"Forward:{t1-t0:.4f}")

        # create an empty tensor of the expected final shape and fill in the current tokens
        empty = torch.empty(T_new, dtype=dtype, device=device)
        empty[:T] = prompt
        seq = empty
        next_token = sample(logits, **sampling_kwargs)[0]
        seq[T] = next_token

        input_pos = torch.tensor([T], device=device, dtype=torch.int)
        accept_counts = [0] * (speculate_k + 1)

        generated_tokens, _ = self.decode_n_tokens(
            next_token.view(1, -1),
            input_pos,
            max_new_tokens - 1,
            callback=callback,
            **sampling_kwargs,
        )
        seq[T + 1:] = torch.cat(generated_tokens)

        generate_stats = {"accept_counts": accept_counts}
        return seq, generate_stats

    def _load_model(self, checkpoint_path, device, precision, use_tp):
        t0 = time.time()
        with torch.device("meta"):
            model = Transformer.from_name(checkpoint_path.parent.name)

        if "int8" in str(checkpoint_path):
            print("Using int8 weight-only quantization!")
            from quantize import WeightOnlyInt8QuantHandler

            simple_quantizer = WeightOnlyInt8QuantHandler(model)
            model = simple_quantizer.convert_for_runtime()

        if "int4" in str(checkpoint_path):
            print("Using int4 quantization!")
            path_comps = checkpoint_path.name.split(".")
            assert path_comps[-2].startswith("g")
            groupsize = int(path_comps[-2][1:])
            from quantize import WeightOnlyInt4QuantHandler

            simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
            model = simple_quantizer.convert_for_runtime()

        checkpoint = torch.load(str(checkpoint_path),
                                mmap=True,
                                weights_only=True)
        model.load_state_dict(checkpoint, assign=True)

        if use_tp:
            from tp import apply_tp

            print("Applying tensor parallel to model ...")
            apply_tp(model)

        model = model.to(device=device, dtype=precision)
        torch.cuda.synchronize()
        print(f"Time to load model: {time.time() - t0:.02f} seconds")
        return model.eval()

    def forward(self, x, input_pos):
        return self.model(x, input_pos)

    def decode_one_token(
        self,
        x: torch.Tensor,
        input_pos: torch.Tensor,
        **sampling_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [B, 1]
        assert input_pos.shape[-1] == 1
        logits = self.forward(x, input_pos)
        return sample(logits, **sampling_kwargs)

    def get_model_size(self, model):
        model_size = sum([
            p.numel() * p.dtype.itemsize
            for p in itertools.chain(model.parameters(), model.buffers())
        ])
        return model_size

    @cost_time
    def pipeline(self):
        num_samples = self.num_samples
        profile = self.profile
        max_new_tokens = self.max_new_tokens
        prompt = self.prompt

        tokenizer_path = self.checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path

        global print
        t0 = time.time()
        draft_model = None

        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
        encoded = encode_tokens(tokenizer,
                                prompt,
                                bos=True,
                                device=self.device)
        prompt_length = encoded.size(0)

        torch.manual_seed(1234)
        model_size = self.get_model_size(self.model)

        if self.compile:
            global decode_one_token
            from torch._dynamo.utils import CompileProfiler

            with CompileProfiler() as prof:
                # decode_one_token = self.decode_one_token
                decode_one_token = torch.compile(self.decode_one_token,
                                                 mode="reduce-overhead",
                                                 fullgraph=True)
            print(prof.report())

        aggregate_metrics = {
            "tokens_per_sec": [],
            "accept_counts": [],
        }
        start = -1 if self.compile else 0

        for i in range(start, num_samples):
            torch.cuda.synchronize()
            callback = lambda x: x
            t0 = time.perf_counter()

            prof = contextlib.nullcontext()
            with prof:
                y, metrics = self.generate(
                    encoded,
                    max_new_tokens,
                    draft_model=draft_model,
                    speculate_k=self.speculate_k,
                    callback=callback,
                    temperature=self.temperature,
                    top_k=self.top_k,
                )
                aggregate_metrics["accept_counts"].append(
                    metrics["accept_counts"])
            if i == -1:
                print(
                    f"Compilation time: {time.perf_counter() - t0:.2f} seconds"
                )
                continue
            if hasattr(prof, "export_chrome_trace"):
                prof.export_chrome_trace(f"{profile}.json")
            torch.cuda.synchronize()
            t = time.perf_counter() - t0

            OUTPUT = os.environ.get("OUTPUT", True)
            OUTPUT = eval(OUTPUT) if isinstance(OUTPUT, str) else OUTPUT
            if OUTPUT:
                print(tokenizer.decode(y.tolist()))
            else:
                print()
            tokens_generated = y.size(0) - prompt_length
            tokens_sec = tokens_generated / t
            aggregate_metrics["tokens_per_sec"].append(tokens_sec)
            print(
                f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
            )
            print(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

        print(
            f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
        )
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

if __name__ == "__main__":
    params = {
        "prompt": "请用中文描述一下图像?",
        # "请用中文描述一下图像?" * 43 + "你好世界,中国制造,中",
        "interactive": False,
        "num_samples": 2,
        "max_new_tokens": 100,
        "top_k": 50,
        "temperature": 0.8,
        "checkpoint_path": Path("/home/user/workspace/llama-7B/model.pth"),
        "compile": True,
        "compile_prefill": False,
        "profile": None,
        "speculate_k": 5,
        "draft_checkpoint_path": None,
    }
    llama = LLamaModel(**params)
    llama.pipeline()

My modified llava code:

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import os
import time
from pathlib import Path
from torch._dynamo.utils import CompileProfiler
from typing import Optional, Tuple
from torch import nn

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch._dynamo.config
import torch._inductor.config
from clip_utils import (
    load_vison_tower,
    load_mmprojector,
    load_llama_tokenizer,
    merge_input_ids_with_image_features,
)
import logging

Logger = logging.getLogger(__name__)
FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(message)s"
logging.basicConfig(format=FORMAT)
Logger.setLevel(logging.DEBUG)

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

torch._dynamo.config.replay_record_enabled = True
torch._dynamo.config.verbose = True
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from model import Transformer
from functools import wraps
import contextlib

def multinomial_sample_one_no_sync(
    probs_sort, ):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1,
                        keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits,
                    temperature: float = 1.0,
                    top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
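
# NOTE: the following sample() definition shadows the top-k/temperature version above,
# so decode_one_token below ends up using plain softmax sampling.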

def sample(next_token_scores):
    probs = nn.functional.softmax(next_token_scores, dim=-1)
    # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    next_tokens = multinomial_sample_one_no_sync(probs)
    return next_tokens

def cost_time(fn):

    @wraps(fn)
    def wrap_cost(*kw, **kwargs):
        start = time.perf_counter()
        out = fn(*kw, **kwargs)
        end = time.perf_counter()
        print(f"{fn.__name__} cost time:{end-start:.4f}(s)")
        return out

    return wrap_cost

class LLamaModel(object):

    def __init__(self, **kwargs):
        self.prompt = kwargs.get("prompt")
        self.interactive = kwargs.get("interactive", False)
        self.num_samples = kwargs.get("num_samples", 1)
        self.max_new_tokens = kwargs.get("max_new_tokens", 200)
        self.top_k = kwargs.get("top_k", 200)
        self.temperature = kwargs.get("temperature", 0.8)
        self.checkpoint_path = kwargs.get("checkpoint_path")
        self.compile = kwargs.get("compile", False)
        self.profile = kwargs.get("profile", None)
        self.speculate_k = kwargs.get("speculate_k", 5)
        self.draft_checkpoint_path = kwargs.get("draft_checkpoint_path", None)
        self.compile_prefill = kwargs.get("compile_prefill", False)
        self.device = kwargs.get("device", "cuda")
        self.precision = kwargs.get("precision", torch.float16)
        self.use_tp = kwargs.get("use_tp", False)

        self.clip_config = kwargs.get("clip_config")
        self.mm_config = kwargs.get("mm_config")

        self.vision_tower, self.image_processor = load_vison_tower(
            self.clip_config)
        self.mm_projector = load_mmprojector(self.mm_config,
                                             dtype=self.precision)
        self.model, self.tokenizer = self._load_model(self.checkpoint_path,
                                                      self.device,
                                                      self.precision,
                                                      self.use_tp)

    def decode_n_tokens(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: Optional[int] = None,
        image_features=None,
    ):
        attention_mask = torch.ones_like(input_ids,
                                         dtype=torch.int32,
                                         device=self.device)

        input_pos = torch.arange(0, input_ids.shape[0],
                                 device=self.device).unsqueeze(dim=0)

        Logger.info(f"Decode {max_new_tokens} tokens")

        with torch.profiler.profile() as prof:
            for step in range(max_new_tokens):
                # token to embedding
                input_embed = self.model.tok_embeddings(input_ids)
                # the new input_ids require recomputing the merged image/text embeddings
                (
                    inputs_embeds,
                    attention_mask,
                    input_pos,
                ) = merge_input_ids_with_image_features(
                    image_features,
                    input_embed,
                    input_ids,
                    attention_mask,
                    input_pos,
                )
                with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                                    enable_mem_efficient=False,
                                                    enable_math=True):
                    # next_tokens
                    t0 = time.perf_counter()
                    next_tokens = decode_one_token(
                        input_pos,
                        inputs_embeds,
                        input_ids,
                    )
                    t1 = time.perf_counter()
                    Logger.debug(
                        f"Compiled decode_one_token cost time:{t1-t0:.5f}(s)")

                # input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
                input_ids = torch.cat([input_ids, next_tokens], dim=-1)
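                # input_ids (and therefore inputs_embeds) grows by one token each step, so the
                # compiled decode_one_token sees a different sequence length on every iteration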
        prof.export_chrome_trace("opt_trace.json")

        return input_ids

    @cost_time
    @torch.no_grad()
    def generate(
        self,
        input_ids,
        input_pos: torch.Tensor,
        input_embed: torch.Tensor,
        max_new_tokens: int,
        image_outputs=None,
        batch_size=1,
        **sampling_kwargs,
    ) -> torch.Tensor:
        # parameters needed for preprocessing
        top_k = sampling_kwargs.get("top_k")
        min_tokens_to_keep = sampling_kwargs.get("min_tokens_to_keep")
        eos_token_id = sampling_kwargs.get("eos_token_id")

        # prompt token length
        T = len(input_ids)
        # buffer for generated tokens; its length must cover the original prompt tokens plus the newly generated ones
        T_new = T + max_new_tokens
        t0 = time.perf_counter()
        # the number of generated tokens must not exceed the model's maximum sequence length
        max_seq_length = min(T + max_new_tokens + 24 * 24,
                             self.model.config.block_size)
        device, dtype = input_pos.device, input_pos.dtype
        with torch.device(device):
            # set up the caches needed for model computation
            self.model.setup_caches(max_batch_size=1,
                                    max_seq_length=max_seq_length)

        # llama forward pass for the prefill
        logits = self.model(None, input_pos, input_embed)
        print(
            f"==> Forward:{time.perf_counter()-t0:.4f}:current logits shape = {logits.shape}"
        )

        # buffer for the output sequence
        empty = torch.empty(T_new, dtype=dtype, device=device)
        # fill in the input tokens
        empty[:T] = input_ids

        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(axis=0)

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        input_pos = torch.arange(0, input_ids.shape[0],
                                 device=self.device).unsqueeze(dim=0)

        # from line_profiler import LineProfiler

        # lp = LineProfiler()
        # self.decode_n_tokens = lp(self.decode_n_tokens)
        selected_image_feature = image_outputs.hidden_states[-2]
        selected_image_feature = selected_image_feature[:, 1:]
        image_features = self.mm_projector(selected_image_feature)

        input_ids = self.decode_n_tokens(
            input_ids,
            max_new_tokens,
            image_features,
        )
        # lp.print_stats()
        print(f"next:{self.decode_token_to_text(input_ids)}")
        # append the generated tokens to the output sequence
        accept_counts = 1
        generate_stats = {"accept_counts": accept_counts}
        return input_ids, generate_stats

    def _load_model(self, checkpoint_path, device, precision, use_tp):
        t0 = time.time()
        with torch.device("meta"):
            model = Transformer.from_name(checkpoint_path.parent.name)

        if "int8" in str(checkpoint_path):
            print("Using int8 weight-only quantization!")
            from quantize import WeightOnlyInt8QuantHandler

            simple_quantizer = WeightOnlyInt8QuantHandler(model)
            model = simple_quantizer.convert_for_runtime()

        if "int4" in str(checkpoint_path):
            print("Using int4 quantization!")
            path_comps = checkpoint_path.name.split(".")
            assert path_comps[-2].startswith("g")
            groupsize = int(path_comps[-2][1:])
            from quantize import WeightOnlyInt4QuantHandler

            simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
            model = simple_quantizer.convert_for_runtime()

        checkpoint = torch.load(str(checkpoint_path),
                                mmap=True,
                                weights_only=True)
        model.load_state_dict(checkpoint, assign=True)

        if use_tp:
            from tp import apply_tp

            print("Applying tensor parallel to model ...")
            apply_tp(model)

        model = model.to(device=device, dtype=precision)
        torch.cuda.synchronize()
        print(f"Time to load model: {time.time() - t0:.02f} seconds")

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path

        # tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
        tokenizer = load_llama_tokenizer(
            "/home/liushuai9/workspace/llava-seg/split_cn/llama-7B")
        torch.manual_seed(1234)
        # model_size = self.get_model_size(self.model)

        return model.eval(), tokenizer

    def decode_one_token(
        self,
        input_pos: torch.Tensor,
        input_embed: torch.Tensor,
        input_ids,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.model(None, input_pos.squeeze(), input_embed)
        next_token_logits = logits[:, -1, :]

        # top_k = min(50, next_token_logits.size(-1))  # check top_k = 50
        # # Remove all tokens with a probability less than the last token of the top-k
        # indices_to_remove = (next_token_logits
        #                      < torch.topk(next_token_logits, top_k)[0][..., -1,
        #                                                                None])
        # next_token_scores = next_token_logits.masked_fill(
        #     indices_to_remove, -float("inf"))
        # return sample(next_token_scores)
        return sample(next_token_logits)
        # return sample(logits, temperature=0.8, top_k=50)[0]

    def get_model_size(self, model):
        """
        Compute the model weight size in bytes.
        """
        model_size = sum([
            p.numel() * p.dtype.itemsize
            for p in itertools.chain(model.parameters(), model.buffers())
        ])
        return model_size

    def decode_token_to_text(self, llama_outputs):
        """
        Convert the llama model's final output token ids into text strings.
        """
        records = []
        for output_ids in llama_outputs:
            record = {
                "generated_text":
                self.tokenizer.decode(
                    output_ids,
                    skip_special_tokens=True,
                )
            }
            records.append(record)
        return records

    @cost_time
    def pipeline(
        self,
        prompt,
        input_pos,
        input_embed,
        max_new_tokens=100,
        pixel_values=None,
    ):
        num_samples = self.num_samples
        input_ids = torch.LongTensor(self.tokenizer(prompt)["input_ids"]).to(
            self.device)

        model_size = self.get_model_size(self.model)

        # the compiled function must live at global scope to work here
        global decode_one_token
        with CompileProfiler() as prof:
            # decode_one_token = self.decode_one_token
            decode_one_token = torch.compile(
                self.decode_one_token,
                mode="reduce-overhead",
                fullgraph=True,
                backend="inductor",
            )
            print(prof.report())

        aggregate_metrics = {
            "tokens_per_sec": [],
            "accept_counts": [],
        }
        start = -1

        top_k = 50
        min_tokens_to_keep = 1
        pad_token_id = 32001
        eos_token_id = 2
        kwargs = {
            "top_k": top_k,
            "min_tokens_to_keep": min_tokens_to_keep,
            "pad_token_id": pad_token_id,
            "eos_token_id": eos_token_id,
            "temperature": self.temperature,
        }

        image_outputs = self.vision_tower(pixel_values,
                                          output_hidden_states=True)
        for i in range(start, num_samples):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            prof = contextlib.nullcontext()
            with prof:
                y, metrics = self.generate(
                    input_ids,
                    input_pos,
                    input_embed,
                    max_new_tokens,
                    image_outputs=image_outputs,
                    **kwargs,
                )
                aggregate_metrics["accept_counts"].append(
                    metrics["accept_counts"])
            if i == -1:
                print(
                    f"Compilation time: {time.perf_counter() - t0:.2f} seconds"
                )
                continue
            torch.cuda.synchronize()
            t = time.perf_counter() - t0

            OUTPUT = os.environ.get("OUTPUT", True)
            OUTPUT = eval(OUTPUT) if isinstance(OUTPUT, str) else OUTPUT
            if OUTPUT:
                print(self.decode_token_to_text(y))
            else:
                print()
            tokens_generated = y.size(0)
            tokens_sec = tokens_generated / t
            aggregate_metrics["tokens_per_sec"].append(tokens_sec)
            print(
                f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
            )
            print(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

        print(
            f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
        )
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

if __name__ == "__main__":
    params = {
        "prompt": "请用中文描述一下图像?",
        "interactive": False,
        "num_samples": 1,
        "max_new_tokens": 100,
        "top_k": 50,
        "temperature": 0.8,
        "checkpoint_path": Path("/home/user/llama-7B/model.pth"),
        "compile": True,
        "clip_config": "/home/user/clip-vit-large-patch14-336",


all-logs: I tried to check the results of the trace; you can download them from the link above. [screenshot: my llava trace] gpt-fast trace: [screenshot]

Why are there two streams in the trace? Why do the GPU kernels include both at:: functions and triton:: functions? Did part of the compilation fail?
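One way to check whether the compiled decode_one_token is recompiling or hitting graph breaks on every step (a sketch only, assuming PyTorch 2.1 or newer; `llama` stands for the LLamaModel instance, `generate_llava.py` is a placeholder for the script's filename, and the tensors are the ones from decode_n_tokens above):

# Option 1: re-run the script with dynamo logging enabled (no code changes needed):
#   TORCH_LOGS="graph_breaks,recompiles" python generate_llava.py
# Option 2: ask dynamo to explain one call of the eager (uncompiled) step:
import torch

explanation = torch._dynamo.explain(llama.decode_one_token)(input_pos, inputs_embeds, input_ids)
print(explanation)  # reports graph count, graph break reasons, and recompile reasons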

peace-zy commented 6 months ago

Has this problem been resolved?