huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

ONNX export may fail for Hiera due to `math.sqrt` and Python `int` casts #33181

Closed: xenova closed this 1 week ago

xenova commented 2 weeks ago

This is similar to a previous issue we had with other vision models (https://github.com/huggingface/transformers/pull/31311). cc @merveenoyan @amyeroberts @NielsRogge

Offending lines:

https://github.com/huggingface/transformers/blob/5c1027bf09717f664b579e01cbb8ec3ef5aeb140/src/transformers/models/hiera/modeling_hiera.py#L340

https://github.com/huggingface/transformers/blob/5c1027bf09717f664b579e01cbb8ec3ef5aeb140/src/transformers/models/hiera/modeling_hiera.py#L344

We should probably factor `interpolate_pos_encoding` out into a shared helper, since it is reused across many different vision architectures.
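Here is a minimal sketch of what such a shared helper could look like. It is written in plain Python so it runs without PyTorch. The name `interpolate_grid`, its signature, and the use of bilinear resampling (Hiera itself uses bicubic) are assumptions made for illustration. None of this is the transformers API. The caller passes the target `(height, width)` explicitly. Each output pixel is mapped back to an input coordinate using the half-pixel convention associated with `align_corners=False`.

```python
import math
from typing import List, Tuple

Grid = List[List[float]]


def interpolate_grid(grid: Grid, size: Tuple[int, int]) -> Grid:
    """Hypothetical shared helper: bilinearly resize a 2D grid to an explicit (height, width).

    Source coordinates use the half-pixel mapping of align_corners=False:
    src = (dst + 0.5) * in_len / out_len - 0.5, clamped to [0, in_len - 1].
    """
    in_h, in_w = len(grid), len(grid[0])
    out_h, out_w = size

    def source_coord(dst: int, in_len: int, out_len: int) -> Tuple[int, int, float]:
        # Map an output index to a fractional input coordinate and its two neighbours.
        src = (dst + 0.5) * in_len / out_len - 0.5
        src = min(max(src, 0.0), in_len - 1)
        lo = math.floor(src)
        hi = min(lo + 1, in_len - 1)
        return lo, hi, src - lo

    out: Grid = []
    for y in range(out_h):
        y0, y1, fy = source_coord(y, in_h, out_h)
        row = []
        for x in range(out_w):
            x0, x1, fx = source_coord(x, in_w, out_w)
            # Blend horizontally on the two neighbouring rows, then vertically.
            top = grid[y0][x0] * (1 - fx) + grid[y0][x1] * fx
            bottom = grid[y1][x0] * (1 - fx) + grid[y1][x1] * fx
            row.append(top * (1 - fy) + bottom * fy)
        out.append(row)
    return out


# Usage: upsample a 2x2 grid of (toy) position values to 4x3.
resized = interpolate_grid([[1.0, 2.0], [3.0, 4.0]], (4, 3))
```

The target size is an argument, so every architecture could call the same function. The helper never has to recover a scale factor from the sequence length.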

qubvel commented 2 weeks ago

Hi @xenova, there are more issues with the current implementation of `interpolate_pos_encoding`.

Maybe it's better to rewrite the interpolation using the `size` argument instead of a scale factor, to avoid the errors described in the issue above. Would that be compatible with ONNX export? WDYT?
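A small illustration of why an explicit `size` is more robust than a scale factor. The example assumes the output length is computed as `floor(input_len * scale)`, which is how PyTorch's `interpolate` derives the output size from `scale_factor`. Under that assumption, a scale computed as `target / input_len` can lose a pixel to floating-point rounding. The helper names below are hypothetical.

```python
import math


def output_len_from_scale(input_len: int, target_len: int) -> int:
    # Derive a float scale factor, then recover the length by flooring,
    # as a scale-based interpolation call would.
    scale = target_len / input_len
    return math.floor(input_len * scale)


def output_len_from_size(input_len: int, target_len: int) -> int:
    # With an explicit size there is nothing to round: the target is used as-is.
    return target_len


def rounding_mismatches(max_len: int) -> list:
    # Collect (input_len, target_len) pairs where the scale round trip comes out short.
    return [
        (i, t)
        for i in range(1, max_len + 1)
        for t in range(1, max_len + 1)
        if output_len_from_scale(i, t) != t
    ]


print(output_len_from_scale(49, 1))  # → 0, because 49 * (1 / 49) == 0.9999999999999999
print(output_len_from_size(49, 1))   # → 1
```

Passing the target size directly removes this failure mode. It also removes the need to compute a scale factor from `math.sqrt` of the sequence length.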

xenova commented 2 weeks ago

> Maybe it's better to rewrite interpolation using size argument instead of scale to avoid errors described in the issue above, will it be compatible with ONNX export? WDYT?

Yes, I agree! Some other issues/PRs where this is discussed:

I'll try to draft a PR for this soon.