tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

[CCL] Support rank 2 tensor for reduce_scatter TTNN cpp api #15010

Open wooseokTT opened 3 days ago

wooseokTT commented 3 days ago

A TTNN call to reduce_scatter with a 2-dim (rank-2) tensor currently fails with the runtime error shown below. The error can be worked around by adding reshape operations before and after the call that change the rank from 2 to 4 and from 4 back to 2, respectively.

Original MLIR op that causes the problem:

%9 = "ttnn.reduce_scatter"(%8) <{math_op = #tt.reduce_type, num_links = 1 : si32, scatter_split_dim = 1 : si32}> : (tensor<8192x16384xf32, #layout7>) -> tensor<8192x8192xf32, #layout7>

Updated MLIR ops with reshapes inserted in front and back:

%9 = "ttnn.reshape"(%8) <{shape = [1 : i32, 1 : i32, 8192 : i32, 16384 : i32]}> : (tensor<8192x16384xf32, #layout7>) -> tensor<1x1x8192x16384xf32, #layout7> %10 = "ttnn.reduce_scatter"(%9) <{math_op = #tt.reduce_type, num_links = 1 : si32, scatter_split_dim = 3 : si32}> : (tensor<1x1x8192x16384xf32, #layout7>) -> tensor<1x1x8192x8192xf32, #layout7> %11 = "ttnn.all_gather"(%10) <{dim = 3 : si32, num_links = 1 : si32}> : (tensor<1x1x8192x8192xf32, #layout7>) -> tensor<1x1x8192x16384xf32, #layout7> %12 = "ttnn.reshape"(%11) <{shape = [8192 : i32, 16384 : i32]}> : (tensor<1x1x8192x16384xf32, #layout7>) -> tensor<8192x16384xf32, #layout7>


2024-11-13 16:59:57,311 - ERROR - ERROR: test=./test_ttnn.ttnn experienced an error with exception=
TT_FATAL @ /proj_sw/user_dev/wooseoklee/tt-mlir/third_party/tt-metal/src/tt-metal/ttnn/cpp/ttnn/tensor/types.cpp:239: normalized_index >= 0 and normalized_index < rank
info: Index is out of bounds for the rank, should be between 0 and 1 however is 3
backtrace:
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x121b448) [0x7e58b812a448]
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x140d63b) [0x7e58b831c63b]
--- ttnn::ccl::RingReduceScatterBaseTensorSlicer::RingReduceScatterBaseTensorSlicer(tt::tt_metal::Tensor const&, tt::tt_metal::Tensor const&, int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int)
--- ttnn::ccl::RingReduceScatterWrappedTensorSlicer::RingReduceScatterWrappedTensorSlicer(tt::tt_metal::Tensor const&, tt::tt_metal::Tensor const&, int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int)
--- ttnn::ccl::reduce_scatter_detail::reduce_scatter_with_workers(tt::tt_metal::Tensor const&, tt::tt_metal::Tensor const&, ttnn::operations::binary::BinaryOpType, unsigned int, unsigned int, unsigned int, unsigned int, std::optional, std::optional, ttnn::ccl::Topology, std::optional, std::optional)
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x4b08ea) [0x7e58b73bf8ea]
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x40e9c0) [0x7e58b731d9c0]
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x41981d) [0x7e58b732881d]
--- void ttnn::device_operation::detail::launch_on_worker_thread<tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >, unsigned char, long, tt::tt_metal::operation::DeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >, tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >::tensor_args_t, std::vector<tt::tt_metal::Tensor, std::allocator >, tt::tt_metal::v0::Device>(unsigned char, long, tt::tt_metal::operation::DeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > > const&, tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >::tensor_args_t const&, std::vector<tt::tt_metal::Tensor, std::allocator >&, tt::tt_metal::v0::Device&)
--- tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >::tensor_return_value_t ttnn::device_operation::detail::launch_on_single_device<tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > > >(unsigned char, tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >::operation_attributes_t const&, tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >::tensor_args_t const&)
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x40b72e) [0x7e58b731a72e]
--- tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >::tensor_return_value_t ttnn::device_operation::detail::invoke<tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > > >(unsigned char, tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >::operation_attributes_t const&, tt::tt_metal::operation::OldInfraDeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >::tensor_args_t const&)
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x40a27f) [0x7e58b731927f]
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x4061d5) [0x7e58b73151d5]
--- std::vector<tt::tt_metal::Tensor, std::allocator > tt::tt_metal::operation::run<std::vector<tt::tt_metal::Tensor, std::allocator > >(tt::tt_metal::operation::DeviceOperation<std::vector<tt::tt_metal::Tensor, std::allocator > >&&, std::vector<tt::tt_metal::Tensor, std::allocator > const&, std::vector<std::optional, std::allocator<std::optional > > const&, std::vector<std::optional, std::allocator<std::optional > > const&, unsigned char)
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x4ae769) [0x7e58b73bd769]
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x4af223) [0x7e58b73be223]
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_ttnn.so(+0x4b00c4) [0x7e58b73bf0c4]
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/libtt_metal.so(+0x172c8a) [0x7e58b6dd1c8a]
--- tt::tt_metal::v0::Device::push_work(std::shared_ptr<std::function<void ()> >, bool)
--- ttnn::operations::ccl::reduce_scatter(tt::tt_metal::Tensor const&, unsigned int, ttnn::operations::reduction::ReduceType, unsigned int, tt::tt_metal::MemoryConfig const&, ttnn::ccl::Topology, std::optional, std::optional)
--- ttnn::operations::ccl::ExecuteReduceScatter::invoke(tt::tt_metal::Tensor const&, unsigned int, ttnn::operations::reduction::ReduceType, unsigned int, std::optional const&, ttnn::ccl::Topology, std::optional, std::optional)
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_C.cpython-310-x86_64-linux-gnu.so(+0xa8293) [0x7e58b979c293]
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_C.cpython-310-x86_64-linux-gnu.so(+0xa7cbb) [0x7e58b979bcbb]
--- tt::runtime::ttnn::operations::ccl::run(tt::target::ttnn::ReduceScatterOp const, tt::runtime::ttnn::ProgramContext&)
--- tt::runtime::ttnn::runProgram(tt::tt_metal::distributed::MeshDevice&, tt::target::ttnn::Program const, std::vector<tt::tt_metal::Tensor, std::allocator<tt::tt_metal::Tensor> > const&, std::vector<tt::tt_metal::Tensor, std::allocator<tt::tt_metal::Tensor> > const&)
--- tt::runtime::ttnn::submit(tt::runtime::Device, tt::runtime::Binary, unsigned int, std::vector<tt::runtime::Tensor, std::allocator > const&, std::vector<tt::runtime::Tensor, std::allocator > const&)
--- tt::runtime::submit(tt::runtime::Device, tt::runtime::Binary, unsigned int, std::vector<tt::runtime::Tensor, std::allocator > const&, std::vector<tt::runtime::Tensor, std::allocator > const&)
--- /opt/ttmlir-toolchain/venv/lib/python3.10/site-packages/ttrt/runtime/_C.cpython-310-x86_64-linux-gnu.so(+0x6d185) [0x7e58b9761185]
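For reference, the same front-and-back reshape workaround can also be sketched at the TTNN Python level. This is only an illustrative sketch, not the confirmed API: the helper name is hypothetical, and the ttnn.reduce_scatter call shape (positional scatter dim, default math op and link count) is inferred from the C++ signature visible in the backtrace above and may differ from the actual Python bindings.

```python
import ttnn

def reduce_scatter_rank2(tensor_2d, scatter_dim):
    """Hypothetical helper: run reduce_scatter on a rank-2 tensor by padding it to rank 4.

    Mirrors the reshape-front-and-back MLIR rewrite above. The reduce_scatter
    argument layout is an assumption, not the confirmed TTNN Python API.
    """
    h, w = tensor_2d.shape[0], tensor_2d.shape[1]

    # Pad the rank up front: [H, W] -> [1, 1, H, W], so the CCL tensor slicer
    # sees the rank-4 shape it currently expects.
    t4 = ttnn.reshape(tensor_2d, (1, 1, h, w))

    # Scatter along the same logical dim, shifted by the two padded dims
    # (dim 1 of the rank-2 tensor becomes dim 3 of the rank-4 tensor).
    out4 = ttnn.reduce_scatter(t4, scatter_dim + 2)

    # Drop the padded dims again: [1, 1, H, W'] -> [H, W'].
    return ttnn.reshape(out4, (out4.shape[2], out4.shape[3]))
```

(The ttnn.all_gather in the updated MLIR above is left out of this sketch, which only covers the reshape / reduce_scatter / reshape portion.)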

SeanNijjar commented 2 days ago

Marking this as P2 because: (a) there is a workaround, and (b) the VVL team will revisit this after our upgrade to V2 CCLs (currently underway over the next couple of weeks). After the CCL migration, this will become higher priority.