apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io

`torch.arange` ignores `dtype` when used statically #1852

Open · pcuenca opened 1 year ago

pcuenca commented 1 year ago

🐞 Describing the bug

`torch.arange(value, dtype=torch.float)` yields an int tensor instead of a float one when `value` is an integer, as demonstrated in the code snippet below. This causes conversion errors in several transformers models where sinusoidal position embeddings are created this way: https://github.com/huggingface/transformers/blob/b61d5b47f640308068139561f673765b2af39874/src/transformers/models/gptj/modeling_gptj.py#L60
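For reference, eager PyTorch honors the `dtype` argument, so the mismatch is introduced during conversion. A quick check (not part of the original report):

```python
import torch

# Eager mode honors dtype: this is a float32 tensor.
t = torch.arange(5, dtype=torch.float)
assert t.dtype == torch.float32

# During coremltools conversion, however, the lowered arange produces an
# int32 tensor, which is what later trips the matmul dtype check.
```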

Stack Trace

```
Traceback (most recent call last):
  File "/opt/homebrew/Caskroom/miniforge/base/envs/sdcoreml/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/homebrew/Caskroom/miniforge/base/envs/sdcoreml/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/pedro/.vscode/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/Users/pedro/.vscode/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/Users/pedro/.vscode/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/Users/pedro/.vscode/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/Users/pedro/.vscode/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/Users/pedro/.vscode/extensions/ms-python.python-2023.6.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/Users/pedro/code/hf/apple/exporters/einsum_test.py", line 20, in <module>
    mlmodel = ct.convert(
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/_converters_entry.py", line 492, in convert
    mlmodel = mil_convert(
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/converter.py", line 285, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 63, in load
    return _perform_torch_convert(converter, debug)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/frontend/torch/load.py", line 102, in _perform_torch_convert
    prog = converter.convert()
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/frontend/torch/converter.py", line 284, in convert
    convert_nodes(self.context, self.graph)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    add_op(context, node)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/frontend/torch/ops.py", line 1122, in einsum
    x = build_einsum_mil(a, b, equation, node.name)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/frontend/_utils.py", line 171, in build_einsum_mil
    x = solve_generic_einsum(parsed_vectors, a_var, b_var, name)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/frontend/_utils.py", line 390, in solve_generic_einsum
    ab = mb.matmul(x=a, y=b)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/mil/ops/registry.py", line 182, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/mil/builder.py", line 166, in _add_op
    new_op = op_cls(**kwargs)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/mil/operation.py", line 187, in __init__
    self._validate_and_set_inputs(input_kv)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/mil/operation.py", line 496, in _validate_and_set_inputs
    self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
  File "/Users/pedro/code/hf/apple/coremltools/coremltools/converters/mil/mil/input_type.py", line 137, in validate_inputs
    raise ValueError(msg)
ValueError: In op, of type matmul, named matmul_0, the named input `y` must have the same data type as the named input `x`. However, y has dtype fp32 whereas x has dtype int32.
```

To Reproduce

```python
import torch
import torch.nn as nn
import coremltools as ct

class MinimalRepro(nn.Module):
    def forward(self, x):
        # Outer product
        return torch.einsum("i,j->ij", torch.arange(5, dtype=torch.float), x)

model = MinimalRepro()

x = torch.randn(3)
assert model(x).shape == (5, 3)

jitted_model = torch.jit.trace(model, x)
jitted_model.eval()
x_j = jitted_model(x)
assert (x_j - model(x)).abs().max() < 1e-4

# This call raises the ValueError shown in the stack trace above:
# the arange output is lowered as int32 while x is fp32.
mlmodel = ct.convert(
    jitted_model,
    inputs=[ct.TensorType(name="x", shape=x.shape)],
)
```

Observe that the previous example works if either of the following modifications is applied:

```python
return torch.einsum("i,j->ij", torch.arange(5.0, dtype=torch.float), x)
```

```python
return torch.einsum("i,j->ij", torch.arange(5, dtype=torch.float).float(), x)
```
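For completeness, here is the first repro script with the explicit `.float()` cast applied; per the observation above it converts without error (the `FixedRepro` name is ours, for illustration):

```python
import torch
import torch.nn as nn
import coremltools as ct

class FixedRepro(nn.Module):
    def forward(self, x):
        # Casting after arange restores the intended fp32 dtype, so the
        # converted einsum/matmul sees two operands of the same type.
        return torch.einsum("i,j->ij", torch.arange(5, dtype=torch.float).float(), x)

x = torch.randn(3)
jitted_model = torch.jit.trace(FixedRepro(), x)

# Converts successfully with the cast in place.
mlmodel = ct.convert(
    jitted_model,
    inputs=[ct.TensorType(name="x", shape=x.shape)],
)
```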

In addition, the following also works; here `arange` only builds the example input, so it never enters the traced graph at all:

```python
import torch
import torch.nn as nn
import coremltools as ct

class MinimalRepro(nn.Module):
    def __init__(self):
        super().__init__()
        self.y = torch.randn(3)

    def forward(self, x):
        # Outer product
        return torch.einsum("i,j->ij", x, self.y)

model = MinimalRepro()

x = torch.arange(5, dtype=float)
assert model(x).shape == (5, 3)

jitted_model = torch.jit.trace(model, x)
jitted_model.eval()
x_j = jitted_model(x)
assert (x_j - model(x)).abs().max() < 1e-4

mlmodel = ct.convert(
    jitted_model,
    inputs=[ct.TensorType(name="x", shape=x.shape)],
)
```
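One way to see why this variant avoids the bug is to print the traced graph; `arange` does not appear in it, because it only constructed the example input (a quick diagnostic, not part of the original report):

```python
# No aten::arange node shows up here: the traced forward only contains
# the einsum over its input and the baked-in constant self.y.
print(jitted_model.graph)
```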

System environment (please complete the following information):

Additional context

In the presence of the bug, registering this custom op and setting a breakpoint confirms that the dtype of `a` is integer (and therefore `dtype` was ignored):

```python
from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY, register_torch_op

# Remove the built-in einsum lowering so the override below is used instead.
del _TORCH_OPS_REGISTRY["einsum"]

@register_torch_op
def einsum(context, node):
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.frontend._utils import build_einsum_mil
    from coremltools.converters.mil.mil import types

    # node.inputs[1] holds the list of einsum operands; a breakpoint here
    # shows that `a` (the arange output) arrives with an int dtype.
    a = context[node.inputs[1]][0]
    b = context[node.inputs[1]][1]
    equation = context[node.inputs[0]].val
    equation = "".join(equation.split(" "))
    # Work around the bug by casting the int operand to fp32.
    if equation == "i,j->ij" and types.is_int(a.dtype):
        a = mb.cast(x=a, dtype="fp32")
    x = build_einsum_mil(a, b, equation, node.name)

    context.add(x)
```
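With this override registered, re-running the failing `ct.convert` call routes `einsum` through the patched function, and the cast lets the conversion finish, so the probe doubles as a temporary workaround (assuming `jitted_model` and `x` from the first repro are still in scope):

```python
# The patched einsum casts the int32 arange output to fp32, so the
# matmul dtype check no longer fires.
mlmodel = ct.convert(
    jitted_model,
    inputs=[ct.TensorType(name="x", shape=x.shape)],
)
```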
junpeiz commented 1 year ago

Thank you for the super detailed bug report! Yes, I can confirm that I can reproduce this issue: the `dtype` is ignored when lowering `torch.arange`.

I will mark this issue as triaged and add it to the fix pipeline; scheduling will depend on the team's task prioritization. Thanks!