**Open** · pcuenca opened this issue 1 year ago

Thank you for the super detailed bug report! Yes, I can confirm that I can reproduce this issue: the `dtype` is ignored when lowering `torch.arange`.

I will mark this issue as triaged and add it to the fix pipeline; scheduling will depend on the team's task prioritization. Thanks!
## 🐞 Describing the bug
`torch.arange(value, dtype=torch.float)` yields an `int` tensor instead of a `float` one if `value` is an integer, as demonstrated in the code snippet below. This produces conversion errors in several transformers models where sinusoidal positions can be created like this: https://github.com/huggingface/transformers/blob/b61d5b47f640308068139561f673765b2af39874/src/transformers/models/gptj/modeling_gptj.py#L60

## Stack Trace
```
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniforge/base/envs/sdcoreml/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/homebrew/Caskroom/miniforge/base/envs/sdcoreml/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/pedro/.vscode/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in
```

## To Reproduce
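The issue's original reproduction snippet does not appear above. A minimal sketch in its spirit (module name, input shapes, and the conversion call are my assumptions, built from the `einsum` pattern quoted in the workarounds):

```python
import torch

# Hypothetical wrapper module: the sinusoidal-position pattern from the
# linked transformers code reduces to an einsum over a torch.arange whose
# dtype is explicitly requested as float.
class SinusoidalPositions(torch.nn.Module):
    def forward(self, x):
        return torch.einsum("i,j->ij", torch.arange(5, dtype=torch.float), x)

model = SinusoidalPositions().eval()
example = torch.rand(3)
traced = torch.jit.trace(model, example)

# In PyTorch itself the dtype kwarg is honored and the result is float32:
print(traced(example).dtype)  # torch.float32

# The failure reported above appears when lowering the traced graph, e.g.:
#   import coremltools as ct
#   ct.convert(traced, inputs=[ct.TensorType(shape=example.shape)])
# There the integer argument `5` wins over dtype=torch.float, the einsum
# receives an int tensor, and the conversion error in the stack trace follows.
```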
Observe that the previous example works if any of the following modifications is applied:

- Pass a float argument to `arange`: `return torch.einsum("i,j->ij", torch.arange(5.0, dtype=torch.float), x)`
- Cast to `float` anyway: `return torch.einsum("i,j->ij", torch.arange(5, dtype=torch.float).float(), x)`
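Both workarounds have the same effect: they make the element type recoverable from the `arange` arguments alone, which is consistent with the lowering inferring the dtype from the start/stop values and dropping the `dtype` kwarg. A quick eager-mode sketch of PyTorch's inference rules:

```python
import torch

# With an integer bound and no dtype, arange infers int64.
assert torch.arange(5).dtype == torch.int64

# A float bound flips the inferred type to the default float32
# (workaround 1 relies on this by writing 5.0 instead of 5).
assert torch.arange(5.0).dtype == torch.float32

# An explicit dtype kwarg also yields float32 in eager mode; the bug is
# that the lowering drops this kwarg, so workaround 2 adds a .float()
# cast that survives the conversion.
assert torch.arange(5, dtype=torch.float).dtype == torch.float32
```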
In addition, the following also works:
## System environment (please complete the following information):
## Additional context
In the presence of the bug, registering this custom op and setting a breakpoint confirms that the type of `a` is integer (and therefore `dtype` was ignored):
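The custom op referenced above is not shown in this excerpt. A hypothetical reconstruction of what such a debugging probe might look like, assuming coremltools' `register_torch_op` decorator (the variable names and the choice to override `einsum` are my assumptions, not the issue's actual code):

```python
# Hedged sketch: override the einsum translation so the dtype of the
# arange result can be inspected during conversion. Guarded with
# try/except so the snippet is harmless where coremltools is absent.
try:
    from coremltools.converters.mil.frontend.torch.ops import _get_inputs
    from coremltools.converters.mil.frontend.torch.torch_op_registry import (
        register_torch_op,
    )

    @register_torch_op(override=True)
    def einsum(context, node):
        inputs = _get_inputs(context, node)
        # Setting a breakpoint() here during ct.convert shows the arange
        # operand carries an int dtype even though the traced graph
        # requested dtype=torch.float: the kwarg was dropped.
        breakpoint()
except ImportError:
    # coremltools is not installed in this environment; the probe is
    # only meaningful while running a conversion.
    pass
```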