microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
MIT License

Error converting Microsoft Phi3 model to ONNX using Python and Transformers #21518

Closed: junssashu closed this issue 1 month ago

junssashu commented 1 month ago

### Describe the issue


I encountered an error while attempting to convert a Microsoft Phi3 model to ONNX format using Python and the Transformers library. The conversion fails with a KeyError stating that the phi3 model type is not supported by the transformers.onnx exporter.

Additional Context

I followed the standard procedure for model conversion using the Transformers library. However, it appears that the Phi3 model is not listed among the supported models for conversion. I am looking for guidance on how to proceed with this conversion or potential workarounds.


Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/", line 242, in <module>
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/", line 234, in main
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/", line 79, in export_with_transformers
    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/", line 728, in check_supported_model_or_raise
    model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/", line 575, in get_supported_features_for_model_type
    raise KeyError(
KeyError: "phi3 is not supported yet. Only ['albert', 'bart', 'beit', 'bert', 'big-bird', 'bigbird-pegasus', 'blenderbot', 'blenderbot-small', 'bloom', 'camembert', 'clip', 'codegen', 'convbert', 'convnext', 'data2vec-text', 'data2vec-vision', 'deberta', 'deberta-v2', 'deit', 'detr', 'distilbert', 'electra', 'flaubert', 'gpt2', 'gptj', 'gpt-neo', 'groupvit', 'ibert', 'imagegpt', 'layoutlm', 'layoutlmv3', 'levit', 'longt5', 'longformer', 'marian', 'mbart', 'mobilebert', 'mobilenet-v1', 'mobilenet-v2', 'mobilevit', 'mt5', 'm2m-100', 'owlvit', 'perceiver', 'poolformer', 'rembert', 'resnet', 'roberta', 'roformer', 'segformer', 'squeezebert', 'swin', 't5', 'vision-encoder-decoder', 'vit', 'whisper', 'xlm', 'xlm-roberta', 'yolos'] are supported. If you want to support phi3 please propose a PR or open up an issue."

Expected Behavior

Successful conversion of the Phi3 model to ONNX format.

Validating ONNX model...
        -[✓] ONNX model output names match reference model ({'last_hidden_state'})
        - Validating ONNX Model output "last_hidden_state":
                -[✓] (2, 8, 768) matches (2, 8, 768)
                -[✓] all values close (atol: 1e-05)
All good, model saved at: onnx/model.onnx
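The validation log above compares the exported model against the reference model in two stages: first the output shapes, then the values within an absolute tolerance (`atol: 1e-05`). A minimal pure-Python sketch of that kind of check, assuming outputs are given as nested lists; the real exporter works on NumPy arrays, and its exact comparison may also apply a relative tolerance:

```python
def shape_of(x) -> tuple:
    """Shape of a nested list, assuming every level is rectangular."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def flatten(x) -> list:
    """Flatten a nested list into a single list of scalars."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def outputs_match(reference, candidate, atol: float = 1e-5) -> bool:
    """Shape check first, then an element-wise absolute-tolerance check."""
    if shape_of(reference) != shape_of(candidate):
        return False
    return all(
        abs(r - c) <= atol
        for r, c in zip(flatten(reference), flatten(candidate))
    )
```

For example, `outputs_match([[1.0, 2.0]], [[1.0, 2.000001]])` passes, while a shape mismatch fails before any values are compared.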


### To reproduce

Steps to Reproduce the Behavior

  1. Install Python and required libraries (onnxruntime, transformers).
  2. Attempt to convert the Phi3 model using the transformers.onnx module.

Code Snippet

python -m transformers.onnx --model=microsoft/Phi-3-mini-4k-instruct onnx/

### Urgency

This issue is critical for my project, which relies on the conversion of the Phi3 model to ONNX format for deployment in a mobile application. Timely resolution of this issue is essential to meet project deadlines and ensure the application functions as intended.

### Platform

Other / Unknown

### OS Version

ubuntu 22.04

### ONNX Runtime Installation

Released Package

### Compiler Version (if 'Built from Source')

_No response_

### Package Name (if 'Released Package')


### ONNX Runtime Version or Commit ID


### ONNX Runtime API

Other / Unknown

### Architecture


### Execution Provider

Default CPU

### Execution Provider Library Version

_No response_
skottmckay commented 1 month ago

There's a converted onnx model available here:

junssashu commented 1 month ago

Yes, thanks, but here is the error I get when I initialize that converted ONNX Phi3 model with the onnxruntime Flutter package (screenshot: image_2024-07-26_115930447):

import 'dart:io';

import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';
import 'package:path_provider/path_provider.dart';

class Phi3ChatModel {
  late final OrtSessionOptions _sessionOptions;
  late OrtSession _session;

  Phi3ChatModel() {
    // The ONNX Runtime environment must be initialized before options or sessions are created.
    OrtEnv.instance.init();
    _sessionOptions = OrtSessionOptions();
  }

  Future<void> initModel() async {
    print("Model initialization: start");

    const assetFileName = 'assets/models/phi3_cpu.onnx';
    final rawAssetFile = await rootBundle.load(assetFileName);
    final bytes = rawAssetFile.buffer.asUint8List();
    print("Model initialization: creating session");
    _session = OrtSession.fromBuffer(bytes, _sessionOptions);

    print("Model initialization: end");
  }

  Future<String> predict(String inputData) async {
    // NOTE: the input name ('input') and element type must match the model's
    // declared inputs; decoder LLM exports typically expect int64 'input_ids'.
    final inputTensor = OrtValueTensor.createTensorWithDataList([inputData], [1]);
    final inputs = {'input': inputTensor};
    final runOptions = OrtRunOptions();
    final outputs = await _session.runAsync(runOptions, inputs);

    final response = outputs?[0]?.value as List<String>;
    // Release native resources once the result has been read.
    inputTensor.release();
    runOptions.release();
    outputs?.forEach((element) => element?.release());

    return response.first;
  }

  void release() {
    _session.release();
    _sessionOptions.release();
    OrtEnv.instance.release();
  }
}
import 'package:flutter/material.dart';
import 'package:onnxruntime_example/features/phi3_chat_model.dart';

class ONNXChatScreen extends StatefulWidget {
  const ONNXChatScreen({super.key});

  @override
  _ONNXChatScreenState createState() => _ONNXChatScreenState();
}

class _ONNXChatScreenState extends State<ONNXChatScreen> {
  final TextEditingController _controller = TextEditingController();
  late Phi3ChatModel _chatbotModel;
  final List<String> _messages = [];
  bool _isModelInitialized = false;

  @override
  void initState() {
    super.initState();
    _initializeModel();
  }

  Future<void> _initializeModel() async {
    _chatbotModel = Phi3ChatModel();
    try {
      await _chatbotModel.initModel();
      setState(() {
        _isModelInitialized = true;
      });
    } catch (e) {
      print('Error while initializing the model: $e');
    }
  }

  @override
  void dispose() {
    // Only release the session if it was actually created.
    if (_isModelInitialized) {
      _chatbotModel.release();
    }
    _controller.dispose();
    super.dispose();
  }

  void _sendMessage() async {
    final message = _controller.text;
    if (message.isEmpty) return;

    try {
      final response = await _chatbotModel.predict(message);
      setState(() {
        _messages.add('You: $message');
        _messages.add('Bot: $response');
      });
    } catch (e) {
      print('Error while sending the message: $e');
    }
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: const Text('ONNX Chatbot')),
      body: !_isModelInitialized
          ? const Center(child: CircularProgressIndicator())
          : Column(
              children: <Widget>[
                Expanded(
                  child: ListView.builder(
                    itemCount: _messages.length,
                    itemBuilder: (context, index) => ListTile(
                      title: Text(_messages[index]),
                    ),
                  ),
                ),
                Padding(
                  padding: const EdgeInsets.all(8.0),
                  child: Row(
                    children: <Widget>[
                      Expanded(
                        child: TextField(controller: _controller),
                      ),
                      IconButton(
                        icon: const Icon(Icons.send),
                        onPressed: _sendMessage,
                      ),
                    ],
                  ),
                ),
              ],
            ),
    );
  }
}

I downloaded the model file and renamed it in my Flutter assets (screenshot: image_2024-07-26_120642831).

What could I have done wrong?

junssashu commented 1 month ago

> There's a converted onnx model available here:

Thanks for the quick reply :)

kunal-vaishnavi commented 1 month ago

Here are some alternative options to export Phi-3 mini to ONNX.

ONNX Runtime

You can use ONNX Runtime's convert_to_onnx tool for large language models (LLMs) to convert, optimize, and/or quantize a model in one command. This README has more information; it uses LLaMA-2 as the example LLM.

python -m onnxruntime.transformers.models.llama.convert_to_onnx -m microsoft/Phi-3-mini-4k-instruct --output ./phi3_mini_4k --precision fp32 --execution_provider cpu

Hugging Face's Optimum

Instead of transformers.onnx, you can use Hugging Face's Optimum to export Phi-3 mini from your terminal:

optimum-cli export onnx --model microsoft/Phi-3-mini-4k-instruct ./phi3_mini_4k

or with a simple Python script:

from optimum.onnxruntime import ORTModelForCausalLM

model_name = "microsoft/Phi-3-mini-4k-instruct"
cache_dir = "./cache_dir"

model = ORTModelForCausalLM.from_pretrained(model_name, export=True, cache_dir=cache_dir)
MaanavD commented 1 month ago

@junssashu Let me know if Kunal's solution works for you! If not, feel free to re-ping on this issue :)

junssashu commented 1 month ago

> @junssashu Let me know if Kunal's solution works for you! If not, feel free to re-ping on this issue :)

Sorry for the slow response. I've abandoned that approach and switched to the GPT-2 Mini model, which does load successfully in my app. My next problem is how to run inference on the loaded model to get coherent output for a chatbot app.
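Getting text out of a decoder-only model such as GPT-2 usually means an autoregressive loop: run the model, keep only the logits at the *last* position, pick one token, append it, and repeat until an end-of-sequence token or a length limit. A minimal Python sketch of that loop, where `run_model` is a hypothetical stand-in for one ONNX Runtime session call (it does not use the onnxruntime API), prompt ids are assumed to be already encoded, and `max_context` plays the role of the model's `n_positions`:

```python
import math
import random
from typing import Callable, List, Optional, Sequence

# Hypothetical stand-in for one model call: token ids in,
# logits shaped [seq_len][vocab_size] out (batch dimension dropped).
LogitsFn = Callable[[Sequence[int]], List[List[float]]]


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next(logits: Sequence[float], temperature: float, rng: random.Random) -> int:
    if temperature <= 0:
        # Greedy decoding: pick the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax(logits, temperature)
    r = rng.random()  # draw ONE number, then walk the cumulative distribution
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return i
    return len(probs) - 1  # guard against floating-point round-off


def generate(run_model: LogitsFn, prompt_ids: List[int], max_new_tokens: int,
             eos_id: int, max_context: int, temperature: float = 0.0,
             seed: Optional[int] = None) -> List[int]:
    rng = random.Random(seed)
    ids = list(prompt_ids)
    new_ids: List[int] = []
    for _ in range(max_new_tokens):
        context = ids[-max_context:]  # stay within the model's context window
        logits = run_model(context)
        next_id = sample_next(logits[-1], temperature, rng)  # last position only
        if next_id == eos_id:
            break
        ids.append(next_id)
        new_ids.append(next_id)
    return new_ids
```

This recomputes the whole context on every step. Exports that expose past key/value inputs (for example an Optimum `decoder_with_past` model) can avoid that, at the cost of feeding the cache tensors back in on each call.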

I'll mark the issue as closed for now, but I'm still trying the suggested solution.

Here is my code to load and use the model. I don't know what I'm doing wrong; if anyone can help, I'd be grateful.

import 'dart:convert';
import 'dart:io';
import 'dart:math';
import 'dart:typed_data';

import 'package:flutter/services.dart' show rootBundle;
import 'package:onnxruntime/onnxruntime.dart';
import 'package:path_provider/path_provider.dart';

class Phi3ChatModel {

  late OrtSession _session;
  late OrtSessionOptions _sessionOptions;
  late Map<String, dynamic> _tokenizerConfig;
  late Map<String, dynamic> _generationConfig;
  late Map<String, dynamic> _modelConfig;
  late Map<String, dynamic> _tokenizer;
  late Map<String, dynamic> _vocab;
  late Map<String, dynamic> _specialTokensMap;
  late int _vocabSize;
  late String _bosToken;
  late String _eosToken;
  late String _unkToken;
  late int _maxLength;

  Phi3ChatModel() {
    // The ONNX Runtime environment must be initialized once before any session is created.
    OrtEnv.instance.init();
  }

  Future<void> initModel() async {
    try {
      print("Initializing GPT-2 model...");
      _sessionOptions = OrtSessionOptions();

      final appDocDir = await getApplicationDocumentsDirectory();

      // Load ONNX model: copy it out of the asset bundle once, then load it from disk
      final modelFile = File('${appDocDir.path}/decoder_model.onnx');
      if (!await modelFile.exists()) {
        final data = await rootBundle.load('assets/models/gpt2/decoder_model.onnx');
        await modelFile.writeAsBytes(data.buffer.asUint8List());
      }
      _session = OrtSession.fromFile(modelFile, _sessionOptions);

      // Load configuration files
      _tokenizerConfig = await _loadJsonFile('assets/models/gpt2/tokenizer_config.json');
      _generationConfig = await _loadJsonFile('assets/models/gpt2/generation_config.json');
      _modelConfig = await _loadJsonFile('assets/models/gpt2/config.json');
      _tokenizer = await _loadJsonFile('assets/models/gpt2/tokenizer.json');
      _specialTokensMap = await _loadJsonFile('assets/models/gpt2/special_tokens_map.json');
      _vocab = await _loadJsonFile('assets/models/gpt2/vocab.json');

      // Initialize model parameters
      _vocabSize = _modelConfig['vocab_size'];
      _bosToken = _specialTokensMap['bos_token'];
      _eosToken = _specialTokensMap['eos_token'];
      _unkToken = _specialTokensMap['unk_token'];
      _maxLength = _modelConfig['n_positions'];

      print("GPT-2 model initialized successfully.");
    } catch (e) {
      print("Error initializing GPT-2 model: $e");
    }
  }
  Future<Map<String, dynamic>> _loadJsonFile(String path) async {
    try {
      print("Loading JSON file $path...");
      final jsonString = await rootBundle.loadString(path);
      print("JSON file $path loaded successfully.");
      return json.decode(jsonString) as Map<String, dynamic>;
    } catch (e) {
      print("Error loading JSON file $path: $e");
      // Propagate the failure: this function must either return a map or throw.
      rethrow;
    }
  }

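  // NOTE: GPT-2 uses byte-level BPE, so splitting on spaces only matches whole
  // words that happen to be vocabulary entries (and ignores the 'Ġ' space prefix).
  List<int> encode(String text) {
    print("Encoding text: \"$text\"");
    final vocab = _tokenizer['model']['vocab'] as Map<String, dynamic>;
    List<int> tokens = [];
    for (String word in text.split(' ')) {
      if (vocab.containsKey(word)) {
        tokens.add(vocab[word] as int);
      } else {
        // TODO: handle unknown words
      }
    }
    print("Encoded text: $tokens");
    return tokens;
  }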

  String decode(List<int> tokens) {
    print("Decoding tokens: $tokens");
    final vocab = _tokenizer['model']['vocab'] as Map<String, dynamic>;
    // Invert the vocabulary once instead of scanning it for every token.
    final idToToken = {
      for (final entry in vocab.entries) (entry.value as int): entry.key
    };
    String text = '';
    for (int token in tokens) {
      final word = idToToken[token];
      if (word != null) {
        text += '$word ';
      } else {
        // TODO: handle unknown tokens
      }
    }
    print("Decoded text: \"$text\"");
    return text.trim();
  }

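  // NOTE: the first output (typically 'logits') has shape
  // [batch, sequence_length, vocab_size], so this recursion samples one token
  // for EVERY position. Next-token generation only needs the last position.
  void processOutputTokens(Object outputTokensObject, List<int> generatedTokens) {
    if (outputTokensObject is List && outputTokensObject.isNotEmpty) {
      if (outputTokensObject[0] is List) {
        for (var item in outputTokensObject) {
          processOutputTokens(item, generatedTokens);
        }
      } else if (outputTokensObject[0] is double) {
        var probs = _softmax(List<double>.from(outputTokensObject));
        var sample = _sampleFromProbs(probs);
        generatedTokens.add(sample);
      }
    }
  }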

  Future<String> generateText(String prompt, {int maxNewTokens = 50}) async {
    try {
      print("Generating text with prompt: \"$prompt\"...");

      // Encode the prompt
      List<int> inputIds = encode(prompt);
      inputIds.insert(0, _modelConfig['bos_token_id']); // Beginning of sequence token

      // Initialize an empty list to store the generated tokens
      List<int> generatedTokens = [];

      var inputTensor = OrtValueTensor.createTensorWithDataList(
        Int64List.fromList(inputIds),
        [1, inputIds.length],
      );

      // Create attention mask (all ones since there is no padding)
      var attentionMask = OrtValueTensor.createTensorWithDataList(
        Int64List.fromList(List.filled(inputIds.length, 1)),
        [1, inputIds.length],
      );

      var ortInput = {
        'input_ids': inputTensor,
        'attention_mask': attentionMask,
      };

      // NOTE: the model is run a single time and maxNewTokens is unused,
      // so there is no token-by-token generation loop yet.
      final runOptions = OrtRunOptions();
      final outputs = await _session.runAsync(runOptions, ortInput);

      var out1 = outputs?[0]?.value as List;

      processOutputTokens(out1, generatedTokens);

      // Release native resources after the output values have been read.
      inputTensor.release();
      attentionMask.release();
      runOptions.release();
      outputs?.forEach((element) => element?.release());

      // Decode the generated tokens
      String result = decode(generatedTokens);
      print("Generated text: \"$result\"");
      return result;
    } catch (e) {
      print("Error generating text: $e");
      return "Error: Unable to generate text.";
    }
  }

  // Helper function to calculate softmax
  List<double> _softmax(List<double> logits) {
    // Subtract the max logit before exponentiating for numerical stability.
    double maxLogit = logits.reduce((a, b) => a > b ? a : b);
    List<double> expLogits = logits.map((logit) => exp(logit - maxLogit)).toList();
    double sum = expLogits.reduce((a, b) => a + b);
    return expLogits.map((expLogit) => expLogit / sum).toList();
  }

  // Helper function to sample from probabilities (inverse-CDF sampling)
  int _sampleFromProbs(List<double> probs) {
    final rand = Random();
    // Draw a single random number, then find where it falls in the cumulative
    // distribution. Drawing a fresh number at every index would skew the result.
    final r = rand.nextDouble();
    double cumulativeProb = 0.0;
    for (int i = 0; i < probs.length; i++) {
      cumulativeProb += probs[i];
      if (r < cumulativeProb) {
        print("Sampled token: $i");
        return i;
      }
    }
    int result = probs.length - 1; // Fallback to last token (floating-point round-off)
    print("Sampled token: $result");
    return result;
  }

  void release() {
    print("Releasing resources...");
    _session.release();
    _sessionOptions.release();
    OrtEnv.instance.release();
    print("Resources released.");
  }
}

I'm using the onnxruntime Flutter package.
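GPT-2's `tokenizer.json` uses byte-level BPE, so vocabulary entries are not plain words: a leading space is stored as `Ġ`, and non-ASCII bytes are remapped to printable characters (for example, `é` appears as `Ã©`). That is why splitting on spaces in `encode` and joining with spaces in `decode` will not round-trip. A minimal Python sketch of the decode side, using the byte-to-character table from OpenAI's original GPT-2 `encoder.py` (`bytes_to_unicode`) and assuming `vocab` is the `model.vocab` mapping from `tokenizer.json`. Encoding additionally needs the BPE merge ranks (`model.merges`) and the pre-tokenization regex, which this sketch does not cover:

```python
from typing import Dict, List


def bytes_to_unicode() -> Dict[int, str]:
    """GPT-2's reversible byte -> printable-character table."""
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    # Bytes without a printable form (space, control chars, ...) are shifted above 255.
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))


BYTE_DECODER = {ch: b for b, ch in bytes_to_unicode().items()}


def build_id_to_token(vocab: Dict[str, int]) -> Dict[int, str]:
    # Invert once up front instead of scanning the vocab for every token.
    return {i: tok for tok, i in vocab.items()}


def decode(ids: List[int], id_to_token: Dict[int, str]) -> str:
    # Concatenate tokens WITHOUT separators; spaces are encoded inside them.
    text = "".join(id_to_token[i] for i in ids)
    raw = bytes(BYTE_DECODER[ch] for ch in text)
    return raw.decode("utf-8", errors="replace")
```

With a toy vocabulary `{"Hello": 0, "Ġworld": 1, "!": 2}`, `decode([0, 1, 2], ...)` yields `"Hello world!"`, because `Ġ` maps back to the space byte.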