tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

TTNN `pad` does not support last dimension #12896

Open · esmalTT opened 1 month ago

esmalTT commented 1 month ago

Summary

Pad on device is required to improve end-to-end performance of UNet Shallow. The sharded input tensor needs to be padded from 4 -> 16 channels.

Padding a device tensor from shape {1, 1, 2 * 1056 * 160, 4} to {1, 1, 2 * 1056 * 160, 16} throws the following error:

```
sharded pad does not support pad on last dim currently as that will cause perf degradation
```
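
For context, the host-side fallback this is meant to replace looks roughly like the sketch below: pad the channel dim with zeros before the tensor ever reaches the device. This is an illustration using the standard `torch.nn.functional.pad` API, not code from UNet Shallow itself.

```python
import torch
import torch.nn.functional as F

# Pad channels 4 -> 16 on host before transferring to device.
# F.pad pads the last dimension first: (0, 12) appends 12 zero channels.
x = torch.rand([1, 1, 2 * 1056 * 160, 4])
x = F.pad(x, (0, 12))  # shape becomes [1, 1, 337920, 16]
```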
jaykru-tt commented 1 month ago

@esmalTT Can you give an example test? I have been working on this for a bit and I'm actually now confused about how you obtained a sharded tensor of that shape in the first place :)

esmalTT commented 1 month ago

> @esmalTT Can you give an example test? I have been working on this for a bit and I'm actually now confused about how you obtained a sharded tensor of that shape in the first place :)

I'm away from my computer until next week (so sadly I can't run this), but this is what I need for UNet:


```python
import pytest
import torch
import ttnn


def test_unet_pad(device, use_program_cache):
    torch_input = torch.rand([1, 1, 337920, 4])  # 337920 = 2 * 1056 * 160
    x = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16)

    # Height-shard the tensor across an 8x8 core grid.
    sharded_memory_config = ttnn.create_sharded_memory_config(
        [1, 1, 337920, 4], ttnn.CoreGrid(x=8, y=8), ttnn.ShardStrategy.HEIGHT
    )
    x = ttnn.to_device(x, device, memory_config=sharded_memory_config)

    # Pad the channel dim from 4 up to 16; ttnn.pad also takes a fill value.
    x = ttnn.pad(x, ((0, 0), (0, 0), (0, 0), (0, 12)), value=0.0)
```
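
Once pad-on-device works, the test could presumably finish by checking against a host-side golden. A hedged sketch of that continuation (the `torch.nn.functional.pad` reference and the tolerance are my assumptions, not part of the original test):

```python
    # Hypothetical continuation of test_unet_pad: compare against a torch golden.
    expected = torch.nn.functional.pad(torch_input, (0, 12))  # 4 -> 16 channels
    actual = ttnn.to_torch(x)
    assert actual.shape == expected.shape
    assert torch.allclose(actual.float(), expected, atol=0.02)
```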
jaykru-tt commented 1 month ago

This is perfect, thanks @esmalTT :)