huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

Optimum ORTOptimizer inference runs slower than setfit.export_onnx + onnxruntime.InferenceSession inference #1885

Open geraldstanje opened 6 months ago

geraldstanje commented 6 months ago

System Info

Hi,

I ran a test comparing optimum ONNX export + ORTOptimizer inference against setfit.export_onnx + onnxruntime.InferenceSession.

It seems that inference with the optimum ORTOptimizer model runs slower than inference with setfit.export_onnx + onnxruntime.InferenceSession. Any idea why that is?

I also changed AutoOptimizationConfig.O2() to AutoOptimizationConfig.O4(); onnxruntime.InferenceSession is still faster.

Set train_model = True to fine-tune the model and save it before exporting. GPU: NVIDIA T4.

output:

python setfit-onnx-optimum-example.py
Repo card metadata block was not found. Setting CardData to empty.
Model size (MB) - 86.68
Accuracy on test set - 0.888
Average latency (ms) - 6.23 +\- 0.51
Framework not specified. Using pt to export the model.
Using the export variant default. Available variants are:
    - default: The default ONNX variant.

***** Exporting submodel 1/1: BertModel *****
Using framework PyTorch: 2.2.1+cu121
Overriding 1 configuration item(s)
        - use_cache -> False
2024-06-02 22:27:53.640590789 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-06-02 22:27:53.640623671 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/optimum/onnxruntime/configuration.py:770: FutureWarning: disable_embed_layer_norm will be deprecated soon, use disable_embed_layer_norm_fusion instead, disable_embed_layer_norm_fusion is set to True.
  warnings.warn(
Optimizing model...
Configuration saved in all-MiniLM-L6-v2_auto_opt_O2/ort_config.json
Optimized model saved at: all-MiniLM-L6-v2_auto_opt_O2 (external data format: False; saved all tensor to one file: True)
2024-06-02 22:27:55.548291362 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-06-02 22:27:55.548316947 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
Model size (MB) - 86.10
Accuracy on test set - 0.888
Average latency (ms) - 1.83 +\- 0.46
Speedup: 3.40x
2024-06-02 22:27:59.483816381 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 2 Memcpy nodes are added to the graph main_graph_ed6a60ecdb95455bac10d5392cf78d36 for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.
2024-06-02 22:27:59.485393795 [W:onnxruntime:, session_state.cc:1166 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2024-06-02 22:27:59.485413289 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
providers: ['CUDAExecutionProvider', 'CPUExecutionProvider']
Model size (MB) - 86.23
Accuracy on test set - 0.888
Average latency (ms) - 1.40 +\- 0.17
Speedup: 4.44x
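To put the three latency figures above side by side, here is a small sketch that recomputes the speedups from the rounded log values and adds a rough Welch t-statistic for the gap between the two ONNX paths. It assumes the 100 timed iterations per run used in time_model and treats the printed mean and standard deviation as exact. They are rounded to two decimals, and np.std reports the population rather than the sample standard deviation, so the numbers are approximate and can differ from the log in the last digit.

```python
import math

# Rounded values copied from the log output above (milliseconds).
N_RUNS = 100  # number of timed iterations in time_model
pytorch = {"mean": 6.23, "std": 0.51}
optimum_onnx = {"mean": 1.83, "std": 0.46}
setfit_onnx = {"mean": 1.40, "std": 0.17}


def speedup(baseline, candidate):
    """Ratio of mean latencies; a value above 1 means the candidate is faster."""
    return baseline["mean"] / candidate["mean"]


def welch_t(a, b, n=N_RUNS):
    """Welch's t-statistic for two runs of equal size n, from summary statistics."""
    standard_error = math.sqrt(a["std"] ** 2 / n + b["std"] ** 2 / n)
    return (a["mean"] - b["mean"]) / standard_error


optimum_vs_pytorch = speedup(pytorch, optimum_onnx)
setfit_vs_pytorch = speedup(pytorch, setfit_onnx)
setfit_vs_optimum = speedup(optimum_onnx, setfit_onnx)
gap_ms = optimum_onnx["mean"] - setfit_onnx["mean"]
t_stat = welch_t(optimum_onnx, setfit_onnx)

print(f"optimum ONNX vs PyTorch: {optimum_vs_pytorch:.2f}x")
print(f"setfit ONNX vs PyTorch: {setfit_vs_pytorch:.2f}x")
print(f"setfit ONNX vs optimum ONNX: {setfit_vs_optimum:.2f}x ({gap_ms:.2f} ms gap)")
print(f"Welch t-statistic for the gap: {t_stat:.1f}")
```

With these rounded inputs the gap between the two ONNX paths is about 0.43 ms per query, and the t-statistic is large enough to suggest the difference is more than timing noise. That said, this is one query on one GPU, so it is a narrow measurement.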

code:

# https://github.com/huggingface/setfit/blob/main/notebooks/setfit-onnx-optimum.ipynb
from pathlib import Path
from time import perf_counter

import evaluate
import numpy as np
import torch
from tqdm.auto import tqdm
import os

import matplotlib.pyplot as plt
import pandas as pd

from setfit import SetFitModel, Trainer, TrainingArguments

from datasets import load_dataset
from setfit.exporters.utils import mean_pooling
from optimum.onnxruntime import ORTModelForFeatureExtraction, AutoOptimizationConfig, ORTOptimizer
from transformers import AutoTokenizer
from setfit.exporters.onnx import export_onnx
import onnxruntime

metric = evaluate.load("accuracy")
train_model = False

class PerformanceBenchmark:
    def __init__(self, model, dataset, optim_type):
        self.model = model
        self.dataset = dataset
        self.optim_type = optim_type

    def compute_accuracy(self):
        preds = self.model.predict(self.dataset["text"])
        labels = self.dataset["label"]
        accuracy = metric.compute(predictions=preds, references=labels)
        print(f"Accuracy on test set - {accuracy['accuracy']:.3f}")
        return accuracy

    def compute_size(self):
        state_dict = self.model.model_body.state_dict()
        tmp_path = Path("model.pt")
        torch.save(state_dict, tmp_path)
        # Calculate size in megabytes
        size_mb = Path(tmp_path).stat().st_size / (1024 * 1024)
        # Delete temporary file
        tmp_path.unlink()
        print(f"Model size (MB) - {size_mb:.2f}")
        return {"size_mb": size_mb}

    def time_model(self, query="that loves its characters and communicates something rather beautiful about human nature"):
        latencies = []
        # Warmup
        for _ in range(10):
            _ = self.model([query])
        # Timed run
        for _ in range(100):
            start_time = perf_counter()
            _ = self.model([query])
            latency = perf_counter() - start_time
            latencies.append(latency)
        # Compute run statistics
        time_avg_ms = 1000 * np.mean(latencies)
        time_std_ms = 1000 * np.std(latencies)
        print(rf"Average latency (ms) - {time_avg_ms:.2f} +\- {time_std_ms:.2f}")
        return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

    def run_benchmark(self):
        metrics = {}
        metrics[self.optim_type] = self.compute_size()
        metrics[self.optim_type].update(self.compute_accuracy())
        metrics[self.optim_type].update(self.time_model())
        return metrics

def plot_metrics(perf_metrics):
    df = pd.DataFrame.from_dict(perf_metrics, orient="index")

    for idx in df.index:
        df_opt = df.loc[idx]
        plt.errorbar(
            df_opt["time_avg_ms"],
            df_opt["accuracy"] * 100,
            xerr=df_opt["time_std_ms"],
            fmt="o",
            alpha=0.5,
            ms=df_opt["size_mb"] / 15,
            label=idx,
            capsize=5,
            capthick=1,
        )

    plt.legend(loc="lower right")

    plt.ylim(63, 95)
    # Use the slowest model to define the x-axis range
    xlim = max(metrics["time_avg_ms"] for metrics in perf_metrics.values()) * 1.2
    plt.xlim(0, xlim)
    plt.ylabel("Accuracy (%)")
    plt.xlabel("Average latency with batch_size=1 (ms)")
    plt.show()

class OnnxPerformanceBenchmark(PerformanceBenchmark):
    def __init__(self, *args, model_path, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_path = model_path

    def compute_size(self):
        size_mb = Path(self.model_path).stat().st_size / (1024 * 1024)
        print(f"Model size (MB) - {size_mb:.2f}")
        return {"size_mb": size_mb}

class OnnxSetFitModel:
    def __init__(self, ort_model, tokenizer, model_head):
        self.ort_model = ort_model
        self.tokenizer = tokenizer
        self.model_head = model_head

    def predict(self, inputs):
        encoded_inputs = self.tokenizer(
            inputs, padding=True, truncation=True, return_tensors="pt"
        ).to(self.ort_model.device)

        outputs = self.ort_model(**encoded_inputs)
        embeddings = mean_pooling(
            outputs["last_hidden_state"], encoded_inputs["attention_mask"]
        )

        if embeddings.is_cuda:
            embeddings = embeddings.cpu()

        embeddings_np = embeddings.numpy()
        return self.model_head.predict(embeddings_np)
        #return self.model_head.predict(embeddings)

    def __call__(self, inputs):
        return self.predict(inputs)

class OnnxSetFitModelV2:
    def __init__(self, onnx_session, tokenizer, model_head):
        self.onnx_session = onnx_session
        self.tokenizer = tokenizer
        self.model_head = model_head

    def predict(self, inputs):
        #model_max_length=512
        #max_length = 512 #512 #model.model_body.max_seq_length

        # Create the same embeddings using the ONNX model
        encoded_inputs = self.tokenizer(
            inputs,
            #max_length=max_length,
            #padding="max_length",
            padding=True,
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=True,
            return_tensors="np",
        )

        return self.onnx_session.run(None, dict(encoded_inputs))[0]

    def __call__(self, inputs):
        return self.predict(inputs)

def main():
    # Set the TOKENIZERS_PARALLELISM environment variable
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    setfit_model_name = "sentence-transformers/all-MiniLM-L6-v2" #"BAAI/bge-small-en-v1.5"

    dataset = load_dataset("SetFit/sst2")
    #dataset
    train_dataset = dataset["train"]
    test_dataset = dataset["validation"]  # limit to the first 200 rows ([0:200]) if the ONNX session runs out of memory

    # Evaluate the uploaded model!
    #model = SetFitModel.from_pretrained("dkorat/bge-small-en-v1.5_setfit-sst2-english")
    #pb = PerformanceBenchmark(model=model, dataset=test_dataset, optim_type="bge-small (PyTorch)")
    #perf_metrics = pb.run_benchmark()

    if train_model:
        # Fine-tune the base model and Evaluate!
        # Load pretrained model from the Hub
        model = SetFitModel.from_pretrained(
            setfit_model_name
        )
        args = TrainingArguments(num_iterations=20)

        # Create trainer
        trainer = Trainer(
            model=model, args=args, train_dataset=train_dataset
        )
        # Train!
        trainer.train()

        # Save the fine-tuned model locally; the export steps below load it from this directory
        model.save_pretrained("setfit-test-model-example")

        # Evaluate!
        pb = PerformanceBenchmark(
            model=trainer.model, dataset=test_dataset, optim_type="all-MiniLM-L6-v2 (PyTorch)"
        )
    else:
        model = SetFitModel.from_pretrained(
            "setfit-test-model-example"
        )
        # Evaluate!
        pb = PerformanceBenchmark(
            model=model, dataset=test_dataset, optim_type="all-MiniLM-L6-v2 (PyTorch)"
        )

    perf_metrics = pb.run_benchmark()
    plot_metrics(perf_metrics)

    # Load a PyTorch model and export it to the ONNX format
    ort_model = ORTModelForFeatureExtraction.from_pretrained(
        "setfit-test-model-example",
        export=True,
        provider="CUDAExecutionProvider",
    )

    # Create the optimizer
    optimizer = ORTOptimizer.from_pretrained(ort_model)

    # Optimize using the appropriate optimization strategy
    opt_model_path = optimizer.optimize(save_dir="all-MiniLM-L6-v2_auto_opt_O2", optimization_config=AutoOptimizationConfig.O2())

    # Load the optimized ONNX model
    ort_model = ORTModelForFeatureExtraction.from_pretrained(opt_model_path, provider="CUDAExecutionProvider")

    # Load the tokenizer from the optimized model directory and wrap the model with the SetFit head
    tokenizer = AutoTokenizer.from_pretrained(opt_model_path, model_max_length=512)
    onnx_setfit_model = OnnxSetFitModel(ort_model, tokenizer, model.model_head)

    # Perform inference
    onnx_setfit_model(test_dataset["text"][:2])

    pb = OnnxPerformanceBenchmark(
        onnx_setfit_model,
        test_dataset,
        "all-MiniLM-L6-v2 (optimum ONNX)",
        model_path="all-MiniLM-L6-v2_auto_opt_O2/model_optimized.onnx",
    )

    perf_metrics.update(pb.run_benchmark())
    plot_metrics(perf_metrics)

    print(f"Speedup: {perf_metrics['all-MiniLM-L6-v2 (PyTorch)']['time_avg_ms'] / perf_metrics['all-MiniLM-L6-v2 (optimum ONNX)']['time_avg_ms']:.2f}x")

    output_path = "sklearn_model.onnx"
    export_onnx(model.model_body,
        model.model_head,
        opset=12,
        output_path=output_path)

    # Load the ONNX model
    onnx_model_path = output_path
    session = onnxruntime.InferenceSession(onnx_model_path, providers=['CUDAExecutionProvider'])

    # Report which execution providers the session is actually using
    providers = session.get_providers()
    print("providers:", providers)

    # Load the tokenizer for the model exported with export_onnx and wrap the session
    tokenizer_v2 = AutoTokenizer.from_pretrained("setfit-test-model-example", model_max_length=512)
    onnx_setfit_model_v2 = OnnxSetFitModelV2(session, tokenizer_v2, model.model_head)
    #print("model.model_body.max_seq_length:", model.model_body.max_seq_length)
    #onnx_setfit_model_v2 = OnnxSetFitModelV2(session, model.model_body.tokenizer, model.model_head)

    pb = OnnxPerformanceBenchmark(
        onnx_setfit_model_v2,
        test_dataset,
        "all-MiniLM-L6-v2 (setfit ONNX)",
        model_path=output_path,
    )

    perf_metrics.update(pb.run_benchmark())
    plot_metrics(perf_metrics)

    print(f"Speedup: {perf_metrics['all-MiniLM-L6-v2 (PyTorch)']['time_avg_ms'] / perf_metrics['all-MiniLM-L6-v2 (setfit ONNX)']['time_avg_ms']:.2f}x")

if __name__ == "__main__":
    main()
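One way to narrow down where the extra time goes is to time each stage of a predict call separately. In the script above, OnnxSetFitModel runs tokenization, the ONNX model, mean pooling, and the model head as separate steps, while OnnxSetFitModelV2 only tokenizes and calls session.run. The sketch below is a generic, standard-library-only harness. `time_stages` and the stand-in stage functions are hypothetical names, not part of optimum or setfit, and would need to be replaced by the real calls. On a GPU, per-stage numbers can also be skewed by asynchronous execution, so treat them as a rough guide.

```python
from statistics import mean, pstdev
from time import perf_counter


def time_stages(stages, query, warmup=10, runs=100):
    """Time each stage of a pipeline separately.

    `stages` is an ordered list of (name, callable) pairs; each callable
    receives the previous stage's output. Returns {name: (avg_ms, std_ms)}.
    """
    timings = {name: [] for name, _ in stages}
    for i in range(warmup + runs):
        value = query
        for name, fn in stages:
            start = perf_counter()
            value = fn(value)
            elapsed_ms = 1000 * (perf_counter() - start)
            if i >= warmup:  # discard warmup iterations
                timings[name].append(elapsed_ms)
    # pstdev is the population standard deviation, matching np.std's default
    return {name: (mean(ts), pstdev(ts)) for name, ts in timings.items()}


# Stand-in stages so the sketch runs on its own. In the script above they
# would wrap the tokenizer call, ort_model(...) or session.run(...),
# mean_pooling, and model_head.predict.
def fake_tokenize(texts):
    return [text.split() for text in texts]


def fake_model(token_lists):
    return [[float(len(token)) for token in tokens] for tokens in token_lists]


def fake_pool(vectors):
    return [sum(vector) / len(vector) for vector in vectors]


stages = [("tokenize", fake_tokenize), ("model", fake_model), ("pool", fake_pool)]
report = time_stages(stages, ["a short example query"], warmup=2, runs=20)
for name, (avg_ms, std_ms) in report.items():
    print(f"{name}: {avg_ms:.4f} ms +/- {std_ms:.4f} ms")
```

Running the same harness against both wrappers would show whether the difference comes from the ONNX graph itself or from the surrounding Python work (tensor conversion, device transfer, pooling, and the head).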

Versions:
optimum: 1.19.2
onnx: 1.16.1
onnxruntime-gpu: 1.16.0
transformers: 4.40.2
setfit: 1.0.3
torch: 2.2.1+cu121
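For anyone trying to reproduce this, a version list like the one above can be collected with the standard library rather than typed by hand. This is a minimal sketch. The distribution names in `PACKAGES` are assumed to match what pip installed, and `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as they would appear in `pip list` (assumed).
PACKAGES = ["optimum", "onnx", "onnxruntime-gpu", "transformers", "setfit", "torch"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


versions = installed_versions(PACKAGES)
for name, ver in versions.items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```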

Who can help?

@JingyaHuang, @echarlaix

Reproduction (minimal, reproducible, runnable)

The code is above.

Expected behavior

Optimum inference should run at the same speed as setfit.export_onnx + onnxruntime.InferenceSession.

geraldstanje commented 5 months ago

@MosheWasserb any idea about the above?