intel-analytics / ipex-llm

Accelerate local LLM inference and finetuning (LLaMA, Mistral, ChatGLM, Qwen, Baichuan, Mixtral, Gemma, Phi, MiniCPM, etc.) on Intel XPU (e.g., local PC with iGPU and NPU, discrete GPU such as Arc, Flex and Max); seamlessly integrate with llama.cpp, Ollama, HuggingFace, LangChain, LlamaIndex, GraphRAG, DeepSpeed, vLLM, FastChat, Axolotl, etc.
Apache License 2.0

Too long warm-up time for `TorchScript` models #7062

Open liangs6212 opened 1 year ago

liangs6212 commented 1 year ago

TorchScript models require a long warm-up time after loading, especially for the first and second calls.

To Reproduce

import torch
from bigdl.chronos.forecaster import TCNForecaster
from bigdl.chronos.data import get_public_dataset
from bigdl.nano.pytorch.inference import InferenceOptimizer

train, _, test = get_public_dataset("nyc_taxi")
dummy_data = torch.randn(1, 48, 1), torch.randn(1, 5, 1)
_loader = test.to_torch_data_loader(lookback=48, horizon=5)
tcn = TCNForecaster.from_tsdataset(train, past_seq_len=48, future_seq_len=5)
tcn.fit(train, epochs=5)
_jit_fp32 = InferenceOptimizer.trace(model=tcn.internal, input_sample=dummy_data, accelerator="jit", use_ipex=True, thread_num=1)

# first time:
with torch.no_grad():
    %time yhat = tuple(_jit_fp32(*i) for i in _loader)
# CPU times: user 618 ms, sys: 10.9 ms, total: 629 ms
# Wall time: 251 ms

# second time:
with torch.no_grad():
    %time yhat = tuple(_jit_fp32(*i) for i in _loader)
# CPU times: user 859 ms, sys: 1.18 ms, total: 860 ms
# Wall time: 456 ms

# subsequent calls:
with torch.no_grad():
    %time yhat = tuple(_jit_fp32(*i) for i in _loader)
# CPU times: user 411 ms, sys: 6.31 ms, total: 417 ms
# Wall time: 43.8 ms

Environment: i9-7900 server (Ubuntu 22.04 LTS), Python 3.7.13, PyTorch 1.12.1, PyTorch Lightning 1.6.4

Comparing the inference times, the first two calls are noticeably slower than the subsequent ones. The problem occurs whenever the user calls trace or load, and it is not described in the PyTorch documentation. Similar issues: https://github.com/triton-inference-server/pytorch_backend/pull/24/files and https://github.com/pytorch/pytorch/issues/57894
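
For context (my reading of the linked issues): TorchScript's profiling graph executor profiles and re-optimizes the graph during the first couple of invocations, so those calls pay the compilation cost. A minimal user-side warm-up sketch, reusing _jit_fp32 and dummy_data from the snippet above; warm_up is a hypothetical helper, not an existing API:

import torch

def warm_up(jit_model, input_sample, num_iters=3):  # hypothetical helper
    # throwaway forward passes so profiling/optimization finishes before any timed inference
    with torch.no_grad():
        for _ in range(num_iters):
            jit_model(*input_sample)

warm_up(_jit_fp32, dummy_data)  # e.g. right after InferenceOptimizer.trace or load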

Solution

torch.jit.optimized_execution(False) works around this problem, but it is not mentioned in the documentation.

with torch.no_grad():
    with torch.jit.optimized_execution(False):
        %time yhat = tuple(_jit_fp32(*i) for i in _loader)
# CPU times: user 530 ms, sys: 15.8 ms, total: 546 ms
# Wall time: 49.6 ms

Judging by the results, optimized_execution also causes a small performance loss. Do we need to fix this? It seems like all of our forecasters have this problem, not just the Autoformer. @TheaperDeng @rnwang04 @plusbang

liangs6212 commented 1 year ago

Two possible solutions are shown below:

Chronos:

def predict_with_jit():
    ...
    # disable graph-executor optimization for the JIT forward passes
    with torch.jit.optimized_execution(False):
        return _pytorch_fashion_inference(model, ...)

Nano: Add optimized_execution to Contextmanager. Adapted from: https://github.com/pytorch/pytorch/blob/master/torch/jit/_fuser.py#L7

class BaseContextmanager:
    def __init__(self, should_optimize=False):
        # remember the current graph-executor optimization flag so it can be restored on exit
        self.stored_flag = torch._C._get_graph_executor_optimize()
        self.should_optimize = should_optimize
        self.no_grad = torch.no_grad()
        ...
    def __enter__(self):
        self.no_grad.__enter__()
        torch._C._set_graph_executor_optimize(self.should_optimize)
        ...
    def __exit__(self, exc_type, exc_value, exc_tb):
        torch._C._set_graph_executor_optimize(self.stored_flag)
        self.no_grad.__exit__(exc_type, exc_value, exc_tb)
        ...
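
A minimal usage sketch, assuming the class above; the call site is illustrative, not actual Nano code:

# disable graph-executor optimization for the JIT inference calls and
# restore the previous flag afterwards
with BaseContextmanager(should_optimize=False):
    yhat = tuple(_jit_fp32(*i) for i in _loader)
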
TheaperDeng commented 1 year ago

I think we could provide a "warmed-up" model to our users once they call InferenceOptimizer.trace with accelerator="jit". We already do this in InferenceOptimizer.optimize but not in InferenceOptimizer.trace.
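
A rough sketch of where that warm-up could live; the wrapper below and the _do_trace placeholder are assumptions for illustration, not the actual InferenceOptimizer.trace implementation:

def trace(model, input_sample=None, accelerator=None, **kwargs):
    jit_model = _do_trace(model, input_sample, accelerator, **kwargs)  # placeholder for the existing tracing/IPEX logic
    if accelerator == "jit" and input_sample is not None:
        # a couple of throwaway forward passes so the slow profiling/optimization
        # phase happens here rather than at the user's first predict call
        with torch.no_grad():
            for _ in range(2):
                jit_model(*input_sample)
    return jit_model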

TheaperDeng commented 1 year ago

I am not really sure whether we need to add torch.jit.optimized_execution(False) to our context manager, since it is poorly documented.

liangs6212 commented 1 year ago

> I think we could provide a "warmed-up" model to our users once they call InferenceOptimizer.trace with accelerator="jit". We already do this in InferenceOptimizer.optimize but not in InferenceOptimizer.trace.

By the way, when the user calls load, the model reverts to the "un-warmed-up" state.

liangs6212 commented 1 year ago

> I am not really sure whether we need to add torch.jit.optimized_execution(False) to our context manager, since it is poorly documented.

What about Chronos? I think the forecaster needs torch.jit.optimized_execution(False).

TheaperDeng commented 1 year ago

> I am not really sure whether we need to add torch.jit.optimized_execution(False) to our context manager, since it is poorly documented.

> What about Chronos? I think the forecaster needs torch.jit.optimized_execution(False).

I think putting it in Chronos is very reasonable, since we focus on a very specific solution (model) for each Chronos forecaster.