fpgaminer / joytag

The JoyTag Image Tagging Model
Apache License 2.0

ONNX format model? #5

Open dfl opened 8 months ago

dfl commented 8 months ago

Thank you for this model! Could you possibly provide the model in ONNX format? I want to use it in extensions for Stable Diffusion (e.g. ComfyUI) and have been unable to figure out how to convert it.

dfl commented 8 months ago

This is as far as I have gotten:

    python -m optimum.exporters.onnx --model joytag/ onnx/ --library transformers --task image-classification --framework tf

I had to add "model_type": "vit" to the config.json. Now it complains about missing metadata. Perhaps I am approaching it wrong; I am still learning about all this stuff.

Loading PyTorch model in TensorFlow before exporting.
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniconda/base/lib/python3.11/site-packages/optimum/exporters/tasks.py", line 1822, in get_model_from_task
    model = model_class.from_pretrained(model_name_or_path, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniconda/base/lib/python3.11/site-packages/transformers/modeling_tf_utils.py", line 2879, in from_pretrained
    raise OSError(
OSError: The safetensors archive passed at joytag/model.safetensors does not contain the valid metadata. Make sure you save your model with the save_pretrained method.
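
For context on the error above: a .safetensors file starts with an 8-byte little-endian header length followed by a JSON header, and that header may carry an optional `__metadata__` dict. My understanding (an assumption, not checked against the transformers source for this version) is that the OSError is raised when `__metadata__` has no `format` entry of `pt`, `tf`, or `flax`, which is what `save_pretrained` normally writes. A minimal sketch for inspecting a file; both function names are hypothetical helpers, not part of any library:

```python
import json
import struct


def read_safetensors_metadata(path):
    """Return the optional __metadata__ dict from a .safetensors header."""
    with open(path, "rb") as f:
        # First 8 bytes: unsigned little-endian length of the JSON header
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return header.get("__metadata__") or {}


def has_framework_format(metadata):
    """Assumed check: transformers seems to expect format in {pt, tf, flax}."""
    return metadata.get("format") in {"pt", "tf", "flax"}
```

Calling `has_framework_format(read_safetensors_metadata("joytag/model.safetensors"))` would show whether the entry is present. Even with the metadata added, this route may not work, since the weights are loaded through the repository's own `VisionModel` class rather than a transformers model class.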

fpgaminer commented 7 months ago

I tried to export an ONNX the other day, but it looks like PyTorch's ONNX support is still very alpha at the moment. I'll try again if I get time, but no guarantees for now.

SmilingWolf commented 7 months ago

Converting through dynamo is a royal PITA (remove/comment out the SDPA context managers, replace torch.nn.functional.scaled_dot_product_attention with the equivalent Python code from the docs), but the "legacy" ONNX exporter seems to work fairly well.
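
For anyone attempting the dynamo route anyway, the replacement mentioned above could look roughly like this. It is a simplified sketch of the reference computation in the PyTorch docs, assuming the model calls attention without a mask and without causal masking (not confirmed against the JoyTag source). `sdpa_reference` is a hypothetical name:

```python
import math

import torch


def sdpa_reference(query, key, value, dropout_p=0.0):
    """Plain-PyTorch attention: softmax(Q @ K^T / sqrt(d)) @ V.

    Assumes no attention mask and no causal masking.
    """
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    if dropout_p > 0.0:
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


# Sanity check against the fused op on random (batch, heads, tokens, dim) tensors
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(torch.allclose(sdpa_reference(q, k, v), expected, atol=1e-5))
```

None of this is needed for the legacy exporter.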

The graph isn't the prettiest and I haven't checked the output for correctness, but this should do the trick:

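import torch
from Models import VisionModel  # Models.py from the JoyTag repository

# Load the pretrained weights from a local directory
model = VisionModel.load_model("../joytag_weights/")

# Dummy input used to trace the graph: one 3-channel 448x448 image
torch_input = torch.randn(1, 3, 448, 448)
torch.onnx.export(
    model,
    # The model takes a dict as its positional argument; the trailing {}
    # keeps the exporter from treating that dict as keyword arguments
    ({"image": torch_input}, {}),
    "joytag.onnx",
    export_params=True,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    # Allow a variable batch size at inference time
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

Particle1904 commented 6 months ago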

Can you please be a bit more specific about how to do this? I can't figure it out.

Particle1904 commented 6 months ago

OK, the problem was the Torch version. I had 2.0.0 on my machine since I can only run DirectML stuff; upgrading to Torch 2.2.0 fixed the problem and the code executed successfully.
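
To fail early instead of hitting an exporter error, the export script could check the installed version first. A minimal sketch, with hypothetical helper names. The 2.2.0 minimum is only what worked in this thread; versions between 2.0.0 and 2.2.0 were not tried here:

```python
import re


def parse_release(version):
    """Leading numeric release of a version string, e.g. '2.2.0+cpu' -> (2, 2, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component is treated as 0
    return tuple(int(part or 0) for part in match.groups())


def meets_minimum(version, minimum=(2, 2, 0)):
    """True if the version's release tuple is at least the given minimum."""
    return parse_release(version) >= minimum
```

In the export script this would be used as `meets_minimum(torch.__version__)`, raising or printing a hint to upgrade when it returns False.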

fpgaminer commented 6 months ago

https://huggingface.co/fancyfeast/joytag/blob/main/model.onnx

Thank you, SmilingWolf, that code worked on PyTorch 2.2.0. I just double-checked the ONNX model's outputs to make sure it is working correctly.

I'll get some example usage code up and then close this issue.

Particle1904 commented 6 months ago

I have a working implementation of JoyTag up and running in my tools with C# and ONNX Runtime. There is a small variation compared to the Python implementation with the Safetensors model, but that seems to be the case for any model. ONNX Runtime usually gets 1 or 2 extra tags (and they are not even false positives) when running inference with the same threshold.
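
One plausible explanation for the extra tags is that tiny numeric differences push borderline scores across the threshold. A minimal sketch for comparing two runs, assuming both produce a score for the same set of tags. The function names, the 0.4 threshold, and the scores below are all made up for illustration:

```python
def tags_above(scores, threshold):
    """Set of tags whose score exceeds the threshold."""
    return {tag for tag, score in scores.items() if score > threshold}


def compare_tag_sets(reference_scores, onnx_scores, threshold=0.4):
    """Report tags that only one runtime emits, plus the largest score gap."""
    ref = tags_above(reference_scores, threshold)
    onnx = tags_above(onnx_scores, threshold)
    return {
        "only_reference": sorted(ref - onnx),
        "only_onnx": sorted(onnx - ref),
        "max_abs_diff": max(
            abs(reference_scores[tag] - onnx_scores[tag]) for tag in reference_scores
        ),
    }


# Hypothetical scores: "outdoors" sits right at the threshold
pytorch_scores = {"1girl": 0.98, "smile": 0.71, "outdoors": 0.3996}
onnx_scores = {"1girl": 0.98, "smile": 0.71, "outdoors": 0.4003}
print(compare_tag_sets(pytorch_scores, onnx_scores))
```

If `max_abs_diff` stays very small while the tag sets differ, the difference is a threshold effect rather than a conversion problem.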