tenstorrent / tt-metal

TT-NN operator library and TT-Metalium low-level kernel programming model.
Apache License 2.0

Argument mismatch while converting `arange` #11925

Closed: jdh8 closed this issue 2 months ago

jdh8 commented 2 months ago

Argument type mismatch while converting aten.arange to ttnn.arange:

self = FastOperation(python_fully_qualified_name='ttnn.arange', function=<ttnn._ttnn.operations.creation.arange_t object at 0...<function default_postprocess_golden_function_outputs at 0x7fcf98535940>, is_cpp_operation=True, is_experimental=False)
function_args = (4,), function_kwargs = {'device': <ttnn._ttnn.deprecated.device.Device object at 0x7fcf58757970>, 'end': 100, 'step': 3}

    def __call__(self, *function_args, **function_kwargs):
>       return self.function(*function_args, **function_kwargs)
E       TypeError: __call__(): incompatible function arguments. The following argument types are supported:
E           1. (self: ttnn._ttnn.operations.creation.arange_t, stop: int, dtype: ttnn._ttnn.deprecated.tensor.DataType = <DataType.BFLOAT16: 0>, device: Optional[ttnn._ttnn.deprecated.device.Device] = None, memory_config: ttnn._ttnn.deprecated.tensor.MemoryConfig = MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt)) -> ttnn._ttnn.deprecated.tensor.Tensor
E           2. (self: ttnn._ttnn.operations.creation.arange_t, start: int, stop: int, step: int = 1, dtype: ttnn._ttnn.deprecated.tensor.DataType = <DataType.BFLOAT16: 0>, device: Optional[ttnn._ttnn.deprecated.device.Device] = None, memory_config: ttnn._ttnn.deprecated.tensor.MemoryConfig = MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt)) -> ttnn._ttnn.deprecated.tensor.Tensor
E       
E       Invoked with: <ttnn._ttnn.operations.creation.arange_t object at 0x7fcf9895b230>, 4; kwargs: end=100, step=3, device=<ttnn._ttnn.deprecated.device.Device object at 0x7fcf58757970>
E       
E       Did you forget to `#include <pybind11/stl.h>`? Or <pybind11/complex.h>,
E       <pybind11/functional.h>, <pybind11/chrono.h>, etc. Some automatic
E       conversions are optional and require extra headers to be included
E       when compiling your pybind11 module.

../tt-metal/ttnn/ttnn/decorators.py:328: TypeError
------------------------------------------------------------------------------ Captured stdout teardown ------------------------------------------------------------------------------
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0
============================================================================== short test summary info ===============================================================================
FAILED tests/lowering/creation/test_arange.py::test_arange_start[input_shapes0] - TypeError: __call__(): incompatible function arguments. The following argument types are supported:
FAILED tests/lowering/creation/test_arange.py::test_arange_start_step[input_shapes0] - TypeError: __call__(): incompatible function arguments. The following argument types are supported:
============================================================================ 2 failed, 1 xfailed in 2.57s ============================================================================
                 Device | INFO     | Closing user mode device drivers
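To see why neither overload in the error message accepts the call, the two signatures can be mimicked with plain Python stubs and checked with `inspect.signature(...).bind`. This is a minimal sketch, assuming that Python's own argument binding is a close enough stand-in for pybind11's overload resolution here; the stub names are hypothetical, and `dtype`/`memory_config` defaults are simplified to `None` because only parameter names and order matter for the mismatch.

```python
import inspect
from typing import Any, List, Optional


# Hypothetical stand-ins mirroring the parameter names of the two
# ttnn.arange overloads listed in the error message (not the real API).
def arange_overload_1(stop: int, dtype: Any = None,
                      device: Optional[Any] = None,
                      memory_config: Any = None) -> None:
    ...


def arange_overload_2(start: int, stop: int, step: int = 1, dtype: Any = None,
                      device: Optional[Any] = None,
                      memory_config: Any = None) -> None:
    ...


OVERLOADS = [arange_overload_1, arange_overload_2]


def matching_overloads(*args: Any, **kwargs: Any) -> List[str]:
    """Return the names of the stub overloads whose signature accepts the call."""
    matches = []
    for fn in OVERLOADS:
        try:
            inspect.signature(fn).bind(*args, **kwargs)
        except TypeError:
            continue  # this overload cannot bind the given arguments
        matches.append(fn.__name__)
    return matches


# The failing call: start passed positionally, PyTorch's `end` as a keyword.
# Neither stub has an `end` parameter, so nothing matches.
print(matching_overloads(4, end=100, step=3, device=None))
# Renaming `end` to `stop` lets the (start, stop, step) stub bind.
print(matching_overloads(4, stop=100, step=3, device=None))
```

With `end`, the first call matches no overload, mirroring the `TypeError` in the log; with `stop`, only the `(start, stop, step)` stub binds.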
ayerofieiev-tt commented 2 months ago

Can you please include what you expect the ttnn call to look like after the conversion?

jdh8 commented 2 months ago

For other ops, the conversion later expands the (kw)args when calling the underlying ttnn op.

https://github.com/tenstorrent/pytorch2.0_ttnn/blob/8c4bd2cf4e1f9abd183aa866e4c68f1264710778/torch_ttnn/passes/lowering/to_tt_pass.py#L231-L233

However, judging from the error message above, kwargs appears to be passed as is (as a dictionary).

Invoked with: <ttnn._ttnn.operations.creation.arange_t object at 0x7fcf9895b230>, 4; kwargs: end=100, step=3, device=<ttnn._ttnn.deprecated.device.Device object at 0x7fcf58757970>

Moreover, ttnn ops generally follow PyTorch function signatures. Per the spec and the error messages, however, ttnn.arange takes `stop` instead of `end`, while the op conversion tries to fill in `end` (the PyTorch name).

arange(start: int = 0, stop: int, step: int = 1, dtype: ttnn.DataType = ttnn.bfloat16, device: ttnn.Device = None, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG)

-- https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/ttnn/arange.html
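A remapping step for this naming difference could look roughly like the sketch below. It assumes, based only on the signatures quoted in this thread, that `end` is the sole parameter name that differs between `aten.arange` and `ttnn.arange`. The helper `remap_arange_kwargs` and the mapping table are hypothetical, not part of torch_ttnn or ttnn, and positional arguments are left untouched because their order already lines up.

```python
from typing import Any, Dict, Tuple

# Hypothetical mapping from PyTorch parameter names to ttnn.arange names.
# Assumption: `end` -> `stop` is the only rename needed for arange.
TORCH_TO_TTNN_ARANGE = {"end": "stop"}


def remap_arange_kwargs(
    args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Rename PyTorch-style keyword arguments to ttnn.arange's names."""
    renamed = {TORCH_TO_TTNN_ARANGE.get(name, name): value
               for name, value in kwargs.items()}
    return args, renamed


args, kwargs = remap_arange_kwargs((4,), {"end": 100, "step": 3})
print(args, kwargs)
# A lowering pass could then call, e.g.:
#   ttnn.arange(*args, **kwargs, device=device)
```

Keeping the rename in a small table keeps the fix local to `arange`; other ops whose names already match PyTorch would not need an entry.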

ayerofieiev-tt commented 2 months ago

It is great to track such issues, but remapping the arguments to enable lowering is trivial.

jdh8 commented 2 months ago

This issue is no longer valid. I got the errors because I was using an old version of tt-metal. Everything was fixed when I updated tt-metal.