huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

Optional `subfolder` if model repository contains one ONNX model behind a subfolder #2008

Open: tomaarsen opened this issue 1 month ago

tomaarsen commented 1 month ago

Hello!

The Quirk

I've noticed some interesting behaviour, and I think there's a chance that it's unintended. Let's start with this snippet:

from optimum.onnxruntime import ORTModelForFeatureExtraction

model = ORTModelForFeatureExtraction.from_pretrained("BAAI/bge-small-en-v1.5")

Perhaps surprisingly, perhaps not, this fails:

Entry Not Found for url: https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main/model.onnx.

This file indeed does not exist; the only model.onnx is in the onnx subfolder: https://huggingface.co/BAAI/bge-small-en-v1.5/tree/main

When file_name is not specified, as in the snippet above, from_pretrained tries to infer it: https://github.com/huggingface/optimum/blob/8cb6832a2797f54ec1221ff5014a81d961016b6b/optimum/onnxruntime/modeling_ort.py#L509-L529

In our case, we take the else branch (as the model is remote): https://github.com/huggingface/optimum/blob/8cb6832a2797f54ec1221ff5014a81d961016b6b/optimum/onnxruntime/modeling_ort.py#L513-L519

Here, repo_files is:

['.gitattributes', '1_Pooling/config.json', 'README.md', 'config.json', 'config_sentence_transformers.json', 'model.safetensors', 'modules.json', 'onnx/model.onnx', 'pytorch_model.bin', 'sentence_bert_config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt']

which makes onnx_files:

[WindowsPath('onnx/model.onnx')]

This bypasses the if len(...) == 0 and if len(...) > 1 errors, and sets file_name as onnx_files[0].name, i.e. "model.onnx".

Loading then fails, because there is no "model.onnx" at the root of the repository, even though we can be quite sure that the user intended to load this ONNX model. Currently, the user has to specify either subfolder="onnx" or file_name="onnx/model.onnx".
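To see why the nested file passes both checks, the remote branch can be reproduced without the Hub. This is a minimal sketch, not optimum's code: `infer_file_name_current` is a hypothetical name, the result of TasksManager.get_model_files is replaced by a plain list of repo paths, and PurePosixPath is used so the result does not depend on the OS (the issue shows a WindowsPath). The key detail is that PurePath.match with a relative pattern matches from the right, so "*.onnx" also matches "onnx/model.onnx", and .name then drops the directory.

```python
from pathlib import PurePosixPath
from typing import List


def infer_file_name_current(repo_files: List[str], subfolder: str = "") -> str:
    """Hypothetical mirror of the current remote-repo branch of the file_name inference."""
    paths = map(PurePosixPath, repo_files)
    pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
    # A relative pattern is matched from the right, so "*.onnx" also matches
    # nested entries such as "onnx/model.onnx".
    onnx_files = [p for p in paths if p.match(pattern)]
    if len(onnx_files) == 0:
        raise FileNotFoundError("Could not find any ONNX model file")
    if len(onnx_files) > 1:
        raise RuntimeError("Too many ONNX model files were found, specify file_name")
    # .name keeps only the final path component, dropping the "onnx/" directory.
    return onnx_files[0].name


repo_files = ["config.json", "model.safetensors", "onnx/model.onnx", "tokenizer.json"]
print(infer_file_name_current(repo_files))  # model.onnx
```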

Potential Fixes

Fix A

        if file_name is None:
            if model_path.is_dir():
                onnx_files = list(model_path.glob("*.onnx"))
            else:
                repo_files, _ = TasksManager.get_model_files(
                    model_id, revision=revision, cache_dir=cache_dir, token=token
                )
                repo_files = map(Path, repo_files)

                pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
                onnx_files = [p for p in repo_files if p.match(pattern)]

            if len(onnx_files) == 0:
                raise FileNotFoundError(f"Could not find any ONNX model file in {model_path}")
            elif len(onnx_files) > 1:
                raise RuntimeError(
                    f"Too many ONNX model files were found in {model_path}, specify which one to load by using the "
                    "file_name argument."
                )
            else:
-               file_name = onnx_files[0].name
+               file_name = onnx_files[0].relative_to(subfolder).as_posix()

This works in the usual case as well as when the only ONNX file sits in a subfolder. Thanks to relative_to, it also works when a subfolder is passed explicitly. There may still be some edge cases this misses.

The downside is that this results in the following warning:

The ONNX file onnx/model.onnx is not a regular name used in optimum.onnxruntime, the ORTModel might not behave as expected.
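Fix A can be tried in isolation with a pathlib-only sketch. `infer_file_name_fix_a` is a hypothetical stand-in for the patched lines, and the Hub file listing is again replaced by a plain list; it only illustrates what file_name would become under the proposed change.

```python
from pathlib import PurePosixPath
from typing import List


def infer_file_name_fix_a(repo_files: List[str], subfolder: str = "") -> str:
    """Hypothetical sketch of Fix A: keep the path relative to `subfolder`, not just its name."""
    paths = map(PurePosixPath, repo_files)
    pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
    onnx_files = [p for p in paths if p.match(pattern)]
    if len(onnx_files) == 0:
        raise FileNotFoundError("Could not find any ONNX model file")
    if len(onnx_files) > 1:
        raise RuntimeError("Too many ONNX model files were found, specify file_name")
    # With subfolder == "", relative_to("") returns the path unchanged, so the
    # "onnx/" prefix survives; with a subfolder, that prefix is stripped.
    return onnx_files[0].relative_to(subfolder).as_posix()


repo_files = ["config.json", "onnx/model.onnx", "tokenizer.json"]
print(infer_file_name_fix_a(repo_files))                    # onnx/model.onnx
print(infer_file_name_fix_a(repo_files, subfolder="onnx"))  # model.onnx
print(infer_file_name_fix_a(["model.onnx"]))                # model.onnx
```

The first call returns a file_name that includes the directory, which is exactly the kind of name that triggers the "not a regular name" warning.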

Fix B

        if file_name is None:
            if model_path.is_dir():
                onnx_files = list(model_path.glob("*.onnx"))
            else:
                repo_files, _ = TasksManager.get_model_files(
                    model_id, revision=revision, cache_dir=cache_dir, token=token
                )
                repo_files = map(Path, repo_files)

                pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
                onnx_files = [p for p in repo_files if p.match(pattern)]

            if len(onnx_files) == 0:
                raise FileNotFoundError(f"Could not find any ONNX model file in {model_path}")
            elif len(onnx_files) > 1:
                raise RuntimeError(
                    f"Too many ONNX model files were found in {model_path}, specify which one to load by using the "
                    "file_name argument."
                )
            else:
                file_name = onnx_files[0].name
+               subfolder = onnx_files[0].parent.as_posix()

This sets (or overrides) subfolder so that, for example, model.onnx is loaded from whichever subfolder contains it. There might still be some missed edge cases.
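A similar sketch for Fix B returns the (subfolder, file_name) pair instead. `infer_location_fix_b` is hypothetical. One detail worth checking: for a root-level file, PurePath.parent is ".", so the sketch maps that back to "". That normalization is an assumption about what the downstream loader expects, not something confirmed in this thread.

```python
from pathlib import PurePosixPath
from typing import List, Tuple


def infer_location_fix_b(repo_files: List[str], subfolder: str = "") -> Tuple[str, str]:
    """Hypothetical sketch of Fix B: return (subfolder, file_name) for the single ONNX file."""
    paths = map(PurePosixPath, repo_files)
    pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
    onnx_files = [p for p in paths if p.match(pattern)]
    if len(onnx_files) == 0:
        raise FileNotFoundError("Could not find any ONNX model file")
    if len(onnx_files) > 1:
        raise RuntimeError("Too many ONNX model files were found, specify file_name")
    found = onnx_files[0]
    parent = found.parent.as_posix()
    # Assumption: a root-level file has parent ".", which is mapped back to ""
    # so that no "./" prefix ends up in the path that is later downloaded.
    return ("" if parent == "." else parent), found.name


print(infer_location_fix_b(["config.json", "onnx/model.onnx"]))  # ('onnx', 'model.onnx')
print(infer_location_fix_b(["config.json", "model.onnx"]))       # ('', 'model.onnx')
```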

Will you consider a fix for this quirk?

IlyasMoutawwakil commented 1 month ago

> This bypasses the if len(...) == 0 and if len(...) > 1 errors, and sets file_name as onnx_files[0].name, i.e. "model.onnx".

You mean 'onnx/model.onnx'? What about adding subfolder=subfolder to TasksManager.get_model_files(...) and removing the subfolder logic from modeling_ort.py?
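One way to read this suggestion is that the file listing itself becomes subfolder-aware, so the caller no longer builds a subfolder pattern. The sketch below assumes that reading. `list_relative` is a hypothetical stand-in for a get_model_files call that returns paths relative to subfolder; the actual behaviour of TasksManager.get_model_files is not confirmed here. Restricting candidates to top-level entries is an added assumption: a bare p.match("*.onnx") would still match nested files from the right.

```python
from pathlib import PurePosixPath
from typing import List


def list_relative(repo_files: List[str], subfolder: str = "") -> List[PurePosixPath]:
    """Hypothetical stand-in for a subfolder-aware listing: paths relative to `subfolder`."""
    paths = [PurePosixPath(f) for f in repo_files]
    if subfolder == "":
        return paths
    root = PurePosixPath(subfolder)
    return [p.relative_to(root) for p in paths if root in p.parents]


def infer_file_name_suggested(repo_files: List[str], subfolder: str = "") -> str:
    """Hypothetical sketch: the subfolder is handled once, by the listing, not by the caller."""
    # Keep only top-level entries; a bare p.match("*.onnx") would still match
    # nested files such as "onnx/model.onnx" because it matches from the right.
    onnx_files = [
        p for p in list_relative(repo_files, subfolder) if len(p.parts) == 1 and p.suffix == ".onnx"
    ]
    if len(onnx_files) == 0:
        raise FileNotFoundError(f"Could not find any ONNX model file in {subfolder or 'the repository root'}")
    if len(onnx_files) > 1:
        raise RuntimeError("Too many ONNX model files were found, specify file_name")
    return onnx_files[0].name


repo_files = ["config.json", "onnx/model.onnx", "tokenizer.json"]
print(infer_file_name_suggested(repo_files, subfolder="onnx"))  # model.onnx
```

Under these assumptions, the repository from the issue loaded without subfolder would raise FileNotFoundError during inference instead of failing later at download time.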

IlyasMoutawwakil commented 1 month ago

@tomaarsen does the suggestion work for you?

tomaarsen commented 1 month ago

@IlyasMoutawwakil Apologies for the delay, I totally missed the notification from your previous message.

> You mean 'onnx/model.onnx'?

No, it sets it to model.onnx: Path("onnx/model.onnx").name is "model.onnx", because .name returns only the final path component.

> You mean 'onnx/model.onnx'? What about adding subfolder=subfolder to TasksManager.get_model_files(...) and removing the subfolder logic from modeling_ort.py?

I think that might be fine as well.

I'm also wondering about a potential unexpected crash for users. Imagine a user loading an ONNX model from a remote repository on HF. The repository has a model.onnx file, and it is the only ONNX file in the project, so they can load it without specifying file_name. Later, the model author adds an optimized ONNX model to the repository as well. Instead of simply gaining an extra option, the user's project stops working, because file_name is now required. Although the fix on the user's side is simple, this risk of breakage is not ideal.

Would you consider replacing the error with a warning if a file named model.onnx exists? The warning could also list which alternative ONNX files exist, which would be very useful. Otherwise, I will likely implement this myself in Sentence Transformers when I move forward with adding an ONNX backend.
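The proposed behaviour (keep loading model.onnx, warn, and list the alternatives) can be sketched as follows. `choose_onnx_file` and `DEFAULT_ONNX_FILE` are hypothetical names, "model_optimized.onnx" is an invented example file, and treating a top-level "model.onnx" as the default is an assumption based on the scenario described above.

```python
import warnings
from pathlib import PurePosixPath
from typing import List

# Assumption: the conventional default is a top-level "model.onnx".
DEFAULT_ONNX_FILE = PurePosixPath("model.onnx")


def choose_onnx_file(repo_files: List[str]) -> str:
    """Hypothetical sketch: warn (not fail) when several ONNX files exist but the default is present."""
    onnx_files = [PurePosixPath(f) for f in repo_files if f.endswith(".onnx")]
    if not onnx_files:
        raise FileNotFoundError("Could not find any ONNX model file")
    if len(onnx_files) == 1:
        return onnx_files[0].as_posix()
    alternatives = sorted(p.as_posix() for p in onnx_files if p != DEFAULT_ONNX_FILE)
    if DEFAULT_ONNX_FILE in onnx_files:
        # Keep loading the default so existing user code does not break, but
        # tell the user which other files could be selected with file_name.
        warnings.warn(
            f"Multiple ONNX files found; loading {DEFAULT_ONNX_FILE.as_posix()!r}. "
            f"Pass file_name to load one of: {', '.join(alternatives)}",
            UserWarning,
        )
        return DEFAULT_ONNX_FILE.as_posix()
    raise RuntimeError(f"Too many ONNX model files were found, specify file_name: {', '.join(alternatives)}")


print(choose_onnx_file(["config.json", "model.onnx", "model_optimized.onnx"]))  # model.onnx, plus a warning
```

The error is kept for the case where no default exists, so ambiguous repositories still require an explicit file_name.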