tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low-level kernel programming model.
Apache License 2.0

Need support for ttnn.max_pool2d to accept block and width sharded input. #12810

Status: Open. punithsekar opened this issue 1 week ago.

punithsekar commented 1 week ago

Describe the bug

ttnn.max_pool2d supports only HEIGHT_SHARDED input tensors. Support is needed for BLOCK_SHARDED and WIDTH_SHARDED input.
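For readers unfamiliar with the three layouts: an NHWC activation can be viewed as a 2D matrix of N*H*W rows by C columns, which is then split across a core grid. The sketch below is a simplified model, not TT-NN code, showing which core would own a given element under each scheme. The function name `owning_core`, the even-division assumption, and the block-sharding axis convention (grid rows split the matrix rows, grid columns split the channels, as assumed for ROW_MAJOR orientation) are all assumptions for illustration.

```python
from enum import Enum


class Scheme(Enum):
    HEIGHT = "height"
    WIDTH = "width"
    BLOCK = "block"


def owning_core(row, col, rows, cols, scheme, grid_x, grid_y):
    """Hypothetical helper: return the (x, y) core holding element (row, col)
    of a rows x cols matrix under a simplified model of each sharding scheme.

    Assumes the matrix divides evenly into shards (real shards may be padded).
    """
    num_cores = grid_x * grid_y
    if scheme is Scheme.HEIGHT:
        # Rows are split across all cores in row-major order; every core keeps all channels.
        core = row // (rows // num_cores)
        return (core % grid_x, core // grid_x)
    if scheme is Scheme.WIDTH:
        # Channels are split across all cores; every core keeps all rows.
        core = col // (cols // num_cores)
        return (core % grid_x, core // grid_x)
    # BLOCK (assumed convention): grid y splits the rows, grid x splits the channels.
    return (col // (cols // grid_x), row // (rows // grid_y))


# The 1x10x10x512 input from this issue is a 100 x 512 matrix.
print(owning_core(30, 300, 100, 512, Scheme.HEIGHT, 4, 1))
print(owning_core(30, 300, 100, 512, Scheme.WIDTH, 8, 1))
print(owning_core(30, 300, 100, 512, Scheme.BLOCK, 4, 2))
```

Under height sharding, each core must hold every channel of its rows; width and block sharding are the layouts that let the channel dimension itself be divided.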

To Reproduce

Steps to reproduce the behavior:

  1. Check out the branch punith/maxpool_issue
  2. Run the command pytest tests/ttnn/integration_tests/yolov4/test_ttnn_neck.py

Expected behavior

ttnn.max_pool2d should accept BLOCK_SHARDED and WIDTH_SHARDED input layouts.

Screenshots

E       RuntimeError: TT_FATAL @ ../ttnn/cpp/ttnn/operations/pool/maxpool/max_pool2d.cpp:58: shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED
E       info:
E       Only height sharded tensors are supported.
E       backtrace:
E        --- /home/ubuntu/punith/tt-metal/ttnn/ttnn/_ttnn.so(+0x458ec9) [0x7f090cd77ec9]


Additional context

The input shape passed to max_pool2d is 1,10,10,512 (NHWC). Because the channel dimension (512) is large compared with the 10x10 spatial extent, the op should be able to run with width or block sharding to improve performance.
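To make the channel-count argument concrete, the sketch below computes per-core shard shapes for this input viewed as a 100 x 512 matrix. The 4-core height-sharded grid mirrors the current run; the 8-core width grid and 4x2 block grid are hypothetical choices for comparison, and `shard_shape` is an illustrative helper using a simplified model (ceiling division, with the last shard possibly padded), not a TT-NN API.

```python
def ceil_div(a, b):
    """Integer ceiling division."""
    return -(-a // b)


def shard_shape(rows, cols, scheme, grid_x, grid_y):
    """Hypothetical helper: per-core (shard_rows, shard_cols) under a simplified model."""
    num_cores = grid_x * grid_y
    if scheme == "height":
        # Every core holds a slice of rows with all channels.
        return (ceil_div(rows, num_cores), cols)
    if scheme == "width":
        # Every core holds all rows with a slice of channels.
        return (rows, ceil_div(cols, num_cores))
    if scheme == "block":
        # Assumed convention: grid rows split the matrix rows, grid columns split the channels.
        return (ceil_div(rows, grid_y), ceil_div(cols, grid_x))
    raise ValueError(f"unknown scheme: {scheme}")


N, H, W, C = 1, 10, 10, 512
rows, cols = N * H * W, C  # 100 x 512

for scheme, gx, gy in [("height", 4, 1), ("width", 8, 1), ("block", 4, 2)]:
    r, c = shard_shape(rows, cols, scheme, gx, gy)
    print(f"{scheme:>6} on {gx * gy} cores: shard {r} x {c} = {r * c} elements per core")
```

With these assumed grids, each height shard holds 25 x 512 = 12,800 elements, while the 8-core width and block layouts hold 100 x 64 and 50 x 128 = 6,400 each. Height sharding can only use more cores by shrinking each core's row count, and every core still carries all 512 channels.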

Current values with height sharding:

# output_tensor: the 1x10x10x512 (NHWC) height-sharded activation being pooled
pool_1 = ttnn.max_pool2d(
    input_tensor=output_tensor,
    batch_size=1,
    input_h=10,
    input_w=10,
    channels=512,
    kernel_size=[5, 5],
    stride=[1, 1],
    padding=[2, 2],
    dilation=[1, 1],
    device=device,
)

Attributes:
  memory_config: MemoryConfig(memory_layout=TensorMemoryLayout::HEIGHT_SHARDED;buffer_type=BufferType::L1;shard_spec=ShardSpec(grid={[(x=0;y=0) - (x=3;y=0)]};shape={25; 0};orientation=ShardOrientation::ROW_MAJOR;halo=0))
  output_dtype: DataType::BFLOAT16
  sliding_window_config: SlidingWindowConfig(batch_size=1; input_hw=(10;10); window_hw=(5;5); stride_hw=(1;1); pad_hw=(2;2); dilation_hw=(1;1); num_cores_nhw=4; core_rangeset={[(x=0;y=0) - (x=3;y=0)]})
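The SlidingWindowConfig values can be sanity-checked with the standard pooling output-size formula (the one PyTorch documents for MaxPool2d), assuming the same formula applies to ttnn.max_pool2d. The sketch below confirms that a 5x5 window with stride 1, padding 2, and dilation 1 keeps the 10x10 spatial size, so the flattened output has 100 rows, which split over num_cores_nhw=4 gives the shard height of 25. `pool_output_size` is an illustrative helper, not a TT-NN function.

```python
def pool_output_size(in_size, kernel, stride, pad, dilation):
    """Standard output-size formula for 2D pooling/convolution along one axis."""
    return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


# Values taken from the SlidingWindowConfig in this issue.
batch, in_h, in_w = 1, 10, 10
kernel_h, kernel_w = 5, 5
stride, pad, dilation = 1, 2, 1
num_cores_nhw = 4

out_h = pool_output_size(in_h, kernel_h, stride, pad, dilation)
out_w = pool_output_size(in_w, kernel_w, stride, pad, dilation)
out_rows = batch * out_h * out_w  # flattened N*H*W output rows
rows_per_core = -(-out_rows // num_cores_nhw)  # ceiling division

print(out_h, out_w, out_rows, rows_per_core)  # 10 10 100 25
```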

Core count: 4

Kernel duration: 1077197 ns (about 1.08 ms)

punithsekar commented 1 week ago

fyi @saichandax