microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Error converting Microsoft Phi3 model to ONNX using Python and Transformers #21518

Closed junssashu closed 1 month ago

junssashu commented 1 month ago

Describe the issue

Context

I encountered an error while attempting to convert a Microsoft Phi3 model to ONNX format using Python and the Transformers library. The conversion process fails with a KeyError indicating that the Phi3 model is not supported.

Additional Context

I followed the standard procedure for model conversion using the Transformers library. However, it appears that the Phi3 model is not listed among the supported models for conversion. I am looking for guidance on how to proceed with this conversion or potential workarounds.

Error

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/__main__.py", line 242, in <module>
    main()
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/__main__.py", line 234, in main
    export_with_transformers(args)
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/__main__.py", line 79, in export_with_transformers
    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/features.py", line 728, in check_supported_model_or_raise
    model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/juns/my_venv/lib/python3.12/site-packages/transformers/onnx/features.py", line 575, in get_supported_features_for_model_type
    raise KeyError(
KeyError: "phi3 is not supported yet. Only ['albert', 'bart', 'beit', 'bert', 'big-bird', 'bigbird-pegasus', 'blenderbot', 'blenderbot-small', 'bloom', 'camembert', 'clip', 'codegen', 'convbert', 'convnext', 'data2vec-text', 'data2vec-vision', 'deberta', 'deberta-v2', 'deit', 'detr', 'distilbert', 'electra', 'flaubert', 'gpt2', 'gptj', 'gpt-neo', 'groupvit', 'ibert', 'imagegpt', 'layoutlm', 'layoutlmv3', 'levit', 'longt5', 'longformer', 'marian', 'mbart', 'mobilebert', 'mobilenet-v1', 'mobilenet-v2', 'mobilevit', 'mt5', 'm2m-100', 'owlvit', 'perceiver', 'poolformer', 'rembert', 'resnet', 'roberta', 'roformer', 'segformer', 'squeezebert', 'swin', 't5', 'vision-encoder-decoder', 'vit', 'whisper', 'xlm', 'xlm-roberta', 'yolos'] are supported. If you want to support phi3 please propose a PR or open up an issue."

Expected Behavior

Successful conversion of the Phi3 model to ONNX format, with validation output similar to:

Validating ONNX model...
        -[✓] ONNX model output names match reference model ({'last_hidden_state'})
        - Validating ONNX Model output "last_hidden_state":
                -[✓] (2, 8, 768) matches (2, 8, 768)
                -[✓] all values close (atol: 1e-05)
All good, model saved at: onnx/model.onnx


To reproduce

Steps to Reproduce the Behavior

  1. Install Python and required libraries (onnxruntime, transformers).
  2. Attempt to convert the Phi3 model using the transformers.onnx module.

Code Snippet


python -m transformers.onnx --model=microsoft/Phi-3-mini-4k-instruct onnx/

### Urgency

This issue is critical for my project, which relies on the conversion of the Phi3 model to ONNX format for deployment in a mobile application. Timely resolution of this issue is essential to meet project deadlines and ensure the application functions as intended.

### Platform

Other / Unknown

### OS Version

ubuntu 22.04

### ONNX Runtime Installation

Released Package

### Compiler Version (if 'Built from Source')

_No response_

### Package Name (if 'Released Package')

Microsoft.ML.OnnxRuntime

### ONNX Runtime Version or Commit ID

7801d794

### ONNX Runtime API

Other / Unknown

### Architecture

X64

### Execution Provider

Default CPU

### Execution Provider Library Version

_No response_

skottmckay commented 1 month ago

There's a converted onnx model available here: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx
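
For example, the CPU/mobile int4 variant can be fetched programmatically with huggingface_hub (a minimal sketch; the repo id and file path are taken from the model card linked above, and file names may change between releases):

from huggingface_hub import hf_hub_download

# Download the int4 CPU/mobile build of Phi-3 mini from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="microsoft/Phi-3-mini-4k-instruct-onnx",
    filename="cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx",
)
print(model_path)

Note that these builds may ship a companion .onnx.data external-weights file in the same folder; if so, it must sit next to the .onnx file for the session to load.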

junssashu commented 1 month ago

Yeah, thanks, but here is the error I get when using that converted ONNX Phi3 model with the onnxruntime Flutter package (https://pub.dev/packages/onnxruntime) to initialize the model (screenshot attached):

import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';

class Phi3ChatModel {
  final OrtSessionOptions _sessionOptions;
  late OrtSession _session;

  Phi3ChatModel() : _sessionOptions = OrtSessionOptions() {
    OrtEnv.instance.init();
  }

  Future<void> initModel() async {
    print("-----------------------------------------initiation start -");
    _sessionOptions.setInterOpNumThreads(1);
    print("-----------------------------------------initiation step 1 -");
    _sessionOptions.setIntraOpNumThreads(1);
    print("-----------------------------------------initiation step 2 -");
    _sessionOptions.setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
    print("-----------------------------------------initiation step 3 -");

    const assetFileName = 'assets/models/phi3_cpu.onnx';
    final rawAssetFile = await rootBundle.load(assetFileName);
    final bytes = rawAssetFile.buffer.asUint8List();
    print("-----------------------------------------initiation step 4 -");
    _session = OrtSession.fromBuffer(bytes, _sessionOptions);

    print("-----------------------------------------initiation step end -");
  }

  Future<String> predict(String inputData) async {
    final inputTensor = OrtValueTensor.createTensorWithDataList([inputData], [1]);
    final inputs = {'input': inputTensor};
    final outputs = await _session.runAsync(OrtRunOptions(), inputs);
    inputTensor.release();

    final response = outputs?[0]?.value as List<String>;
    outputs?.forEach((element) => element?.release());

    return response.first;
  }

  void release() {
    _sessionOptions.release();
    _session.release();
    OrtEnv.instance.release();
  }
}

import 'package:flutter/material.dart';
import 'package:onnxruntime_example/features/phi3_chat_model.dart';

class ONNXChatScreen extends StatefulWidget {
  const ONNXChatScreen({super.key});

  @override
  _ONNXChatScreenState createState() => _ONNXChatScreenState();
}

class _ONNXChatScreenState extends State<ONNXChatScreen> {
  final TextEditingController _controller = TextEditingController();
  late Phi3ChatModel _chatbotModel;
  List<String> _messages = [];
  bool _isModelInitialized = false;

  @override
  void initState() {
    super.initState();
    _initializeModel();
  }

  Future<void> _initializeModel() async {
    _chatbotModel = Phi3ChatModel();
    try {
      await _chatbotModel.initModel();
      setState(() {
        _isModelInitialized = true;
      });
    } catch (e) {
      print('Error initializing the model: $e');
    }
  }

  @override
  void dispose() {
    _chatbotModel.release();
    super.dispose();
  }

  void _sendMessage() async {
    final message = _controller.text;
    if (message.isEmpty) return;

    try {
      final response = await _chatbotModel.predict(message);
      setState(() {
        _messages.add('You: $message');
        _messages.add('Bot: $response');
        _controller.clear();
      });
    } catch (e) {
      print('Error sending the message: $e');
    }
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(title: const Text('ONNX Chatbot')),
      body: !_isModelInitialized
          ? const Center(child: CircularProgressIndicator())
          : Column(
        children: <Widget>[
          Expanded(
            child: ListView.builder(
              itemCount: _messages.length,
              itemBuilder: (context, index) => ListTile(
                title: Text(_messages[index]),
              ),
            ),
          ),
          Padding(
            padding: const EdgeInsets.all(8.0),
            child: Row(
              children: <Widget>[
                Expanded(
                  child: TextField(controller: _controller),
                ),
                IconButton(
                  icon: const Icon(Icons.send),
                  onPressed: _sendMessage,
                ),
              ],
            ),
          ),
        ],
      ),
    );
  }
}

Here I downloaded the file at https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx?download=true and renamed it in my Flutter assets (screenshot attached).

What could I have done wrong, please?

junssashu commented 1 month ago

> There's a converted onnx model available here: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx

thanks for the quick reply :)

kunal-vaishnavi commented 1 month ago

Here are some alternative options to export Phi-3 mini to ONNX.

ONNX Runtime

You can use ONNX Runtime's convert_to_onnx tool for large language models (LLMs) to convert, optimize, and/or quantize in one command. See the README for more information; it uses LLaMA-2 as the example LLM.

python -m onnxruntime.transformers.models.llama.convert_to_onnx -m microsoft/Phi-3-mini-4k-instruct --output ./phi3_mini_4k --precision fp32 --execution_provider cpu

Hugging Face's Optimum

Instead of using transformers.onnx, you can use Hugging Face's Optimum to export Phi-3 mini via your terminal:

optimum-cli export onnx --model microsoft/Phi-3-mini-4k-instruct ./phi3_mini_4k

or you can use a simple Python script:

from optimum.onnxruntime import ORTModelForCausalLM

model_name = "microsoft/Phi-3-mini-4k-instruct"
cache_dir = "./cache_dir"

# export=True converts the PyTorch checkpoint to ONNX while loading
model = ORTModelForCausalLM.from_pretrained(model_name, export=True, cache_dir=cache_dir)
model.save_pretrained("phi3_onnx/")
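
After the export, a quick smoke test (a sketch; it assumes the export wrote model.onnx into phi3_onnx/, and the file name may differ) confirms that the model loads in ONNX Runtime:

import onnxruntime as ort

# Creating the session is enough to surface loading problems
# (missing external-data files, unsupported opsets, etc.).
session = ort.InferenceSession("phi3_onnx/model.onnx")
print([inp.name for inp in session.get_inputs()])
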
MaanavD commented 1 month ago

@junssashu Let me know if Kunal's solution works for you! If not, feel free to re-ping on this issue :)

junssashu commented 1 month ago

> @junssashu Let me know if Kunal's solution works for you! If not, feel free to re-ping on this issue :)

Sorry for the time it took to respond. I've just abandoned that method. I switched to the GPT-2 Mini model, which actually loads successfully in my app. My next issue is how to infer on that loaded model to get a logical output for a chatbot app.

For now, I'll mark the issue as closed, but I'm still trying the solution.

Here is my code to load and use the model. I don't know what I'm doing wrong; if someone can be of any help, I'll be grateful.

import 'dart:io';
import 'dart:math';
import 'dart:typed_data';
import 'package:flutter/services.dart';
import 'package:onnxruntime/onnxruntime.dart';
import 'package:path_provider/path_provider.dart';
import 'dart:convert';

class Phi3ChatModel {

  late OrtSession _session;
  late OrtSessionOptions _sessionOptions;
  late Map<String, dynamic> _tokenizerConfig;
  late Map<String, dynamic> _generationConfig;
  late Map<String, dynamic> _modelConfig;
  late Map<String, dynamic> _tokenizer;
  late Map<String, dynamic> _vocab;
  late Map<String, dynamic> _specialTokensMap;
  late int _vocabSize;
  late String _bosToken;
  late String _eosToken;
  late String _unkToken;
  late int _maxLength;

  Phi3ChatModel() : _sessionOptions = OrtSessionOptions() {
    OrtEnv.instance.init();
  }

  Future<void> initModel() async {
    try {
      print("Initializing GPT-2 model...");
      _sessionOptions = OrtSessionOptions()
        ..setInterOpNumThreads(1)
        ..setIntraOpNumThreads(1)
        ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);

      final appDocDir = await getApplicationDocumentsDirectory();

      // Load ONNX model
      final modelFile = File('${appDocDir.path}/decoder_model.onnx');
      if (!await modelFile.exists()) {
        final data = await rootBundle.load('assets/models/gpt2/decoder_model.onnx');
        await modelFile.writeAsBytes(data.buffer.asUint8List());
      }
      _session = OrtSession.fromFile(modelFile, _sessionOptions);

      // Load configuration files
      _tokenizerConfig = await _loadJsonFile('assets/models/gpt2/tokenizer_config.json');
      _generationConfig = await _loadJsonFile('assets/models/gpt2/generation_config.json');
      _modelConfig = await _loadJsonFile('assets/models/gpt2/config.json');
      _tokenizer = await _loadJsonFile('assets/models/gpt2/tokenizer.json');
      _specialTokensMap = await _loadJsonFile('assets/models/gpt2/special_tokens_map.json');
      _vocab = await _loadJsonFile('assets/models/gpt2/vocab.json');

      // Initialize model parameters
      _vocabSize = _modelConfig['vocab_size'];
      _bosToken = _specialTokensMap['bos_token'];
      _eosToken = _specialTokensMap['eos_token'];
      _unkToken = _specialTokensMap['unk_token'];
      _maxLength = _modelConfig['n_positions'];

      print("GPT-2 model initialized successfully.");
    } catch (e) {
      print("Error initializing GPT-2 model: $e");
      rethrow;
    }
  }

  Future<Map<String, dynamic>> _loadJsonFile(String path) async {
    try {
      print("Loading JSON file $path...");
      String jsonString = await rootBundle.loadString(path);
      print("JSON file $path loaded successfully.");
      return json.decode(jsonString);
    } catch (e) {
      print("Error loading JSON file $path: $e");
      rethrow;
    }
  }

  List<int> encode(String text) {
    print("Encoding text: \"$text\"");
    List<int> tokens = [];
    for (String word in text.split(' ')) {
      if (_tokenizer['model']['vocab'].containsKey(word)) {
        tokens.add(_tokenizer['model']['vocab'][word]);
      } else {
        // TODO: handle unknown words
      }
    }
    print("Encoded text: $tokens");
    return tokens;
  }

  String decode(List<int> tokens) {
    print("Decoding tokens: $tokens");
    String text = '';
    for (int token in tokens) {
      if (_tokenizer['model']['vocab'].containsValue(token)) {
        String word = _tokenizer['model']['vocab'].keys.firstWhere((key) => _tokenizer['model']['vocab'][key] == token);
        text += '$word ';
      } else {
        // TODO: handle unknown tokens
      }
    }
    print("Decoded text: \"$text\"");
    return text.trim();
  }

  void processOutputTokens(Object outputTokensObject, List<int> generatedTokens) {
    if (outputTokensObject is List) {
      if (outputTokensObject[0] is List) {
        for (var item in outputTokensObject) {
          processOutputTokens(item, generatedTokens);
        }
      }
      if (outputTokensObject[0] is double) {
        var probs = _softmax(outputTokensObject as List<double>);
        var sample = _sampleFromProbs(probs);
        generatedTokens.add(sample);
      }
    }
  }

  Future<String> generateText(String prompt, {int maxNewTokens = 50}) async {
    try {
      print("Generating text with prompt: \"$prompt\"...");

      // Encode the prompt
      List<int> inputIds = encode(prompt);
      inputIds.insert(0, _modelConfig['bos_token_id']); // Beginning of sequence token
      inputIds.add(_modelConfig['eos_token_id']);

      // Initialize an empty list to store the generated tokens
      List<int> generatedTokens = [];

      var inputTensor = OrtValueTensor.createTensorWithDataList(
        inputIds,
        [1, inputIds.length],
      );

      // Create attention mask (all ones since there is no padding)
      var attentionMask = OrtValueTensor.createTensorWithDataList(
        Int64List.fromList(List.filled(inputIds.length, 1)),
        [1, inputIds.length],
      );

      var ortInput = {
        'input_ids': inputTensor,
        'attention_mask': attentionMask
      };

      final outputs = _session.run(
        OrtRunOptions(),
        ortInput
      );

      var out1 = outputs[0]?.value as List;

      processOutputTokens(out1, generatedTokens);

      inputTensor.release();
      attentionMask.release();

      // Decode the generated tokens
      String result = decode(generatedTokens);
      print("Generated text: \"$result\"");
      return result;
    } catch (e) {
      print("Error generating text: $e");
      return "Error: Unable to generate text.";
    }
  }

// Helper function to calculate softmax
  List<double> _softmax(List<double> logits) {
    print("Calculating softmax for logits: $logits");
    double maxLogit = logits.reduce((a, b) => a > b ? a : b);
    List<double> expLogits = logits.map((logit) => exp(logit - maxLogit)).toList();
    double sum = expLogits.reduce((a, b) => a + b);
    List<double> result = expLogits.map((expLogit) => expLogit / sum).toList();
    print("Softmax result: $result");
    return result;
  }

// Helper function to sample from probabilities
  int _sampleFromProbs(List<double> probs) {
    print("Sampling from probabilities: $probs");
    Random rand = Random();
    double cumulativeProb = 0.0;
    for (int i = 0; i < probs.length; i++) {
      cumulativeProb += probs[i];
      if (rand.nextDouble() < cumulativeProb) {
        int result = i;
        print("Sampled token: $result");
        return result;
      }
    }
    int result = probs.length - 1; // Fallback to last token
    print("Sampled token: $result");
    return result;
  }

  void release() {
    print("Releasing resources...");
    _sessionOptions.release();
    _session.release();
    OrtEnv.instance.release();
    print("Resources released.");
  }
}

I'm using the onnxruntime Flutter package.
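
For reference, generation with a decoder-only model like GPT-2 is autoregressive: run the model, take the logits for the last position only, sample one token, append it to the input, and repeat until EOS or a token budget is hit. The processOutputTokens method above instead samples a token for every input position in a single pass, which will not yield a coherent reply. A minimal sketch of the loop in Python with onnxruntime (it assumes the Optimum-exported decoder_model.onnx, whose inputs are input_ids and attention_mask and whose first output is logits, plus the matching Hugging Face tokenizer):

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
session = ort.InferenceSession("decoder_model.onnx")

input_ids = tokenizer("Hello, my name is", return_tensors="np")["input_ids"].astype(np.int64)

for _ in range(50):  # generate at most 50 new tokens
    attention_mask = np.ones_like(input_ids)
    logits = session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})[0]
    next_logits = logits[0, -1].astype(np.float64)   # logits for the LAST position only
    probs = np.exp(next_logits - next_logits.max())  # numerically stable softmax
    probs /= probs.sum()
    next_token = int(np.random.choice(len(probs), p=probs))
    if next_token == tokenizer.eos_token_id:
        break
    input_ids = np.concatenate([input_ids, [[next_token]]], axis=1)

print(tokenizer.decode(input_ids[0]))

The same structure carries over to Dart: feed int64 token ids (not raw strings) as input_ids, read the logits output, softmax and sample only the last row, append the sampled token, and loop.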