tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

Eltwise Binary between sharded and interleaved tensors does not work. #738

Open LPanosTT opened 1 week ago

LPanosTT commented 1 week ago

In Resnet50, the first convolution is followed by a multiply op. The output tensor from conv2d is sharded, while the constant it is multiplied with is not. This causes TTNN to attempt broadcasting on the constant, I think by assuming the constant is sharded. That results in a bad optional access on line 73 of

third_party/tt-mlir/third_party/tt-metal/src/tt-metal/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_program_factory.cpp

That code attempts to access the shard spec of a tensor which does not have one.
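For context, this is the failure mode in miniature: the shard spec is stored in an optional, and calling value() on it for an interleaved tensor (where the optional is empty) throws std::bad_optional_access. The snippet below is a standalone sketch of that pattern, not the actual tt-metal code; ShardSpec here is a stand-in struct.

#include <iostream>
#include <optional>

// Stand-in for tt-metal's ShardSpec type, just for illustration.
struct ShardSpec { int num_cores; };

int main() {
    // An interleaved tensor carries no shard spec, analogous to an empty optional.
    std::optional<ShardSpec> shard_spec = std::nullopt;
    try {
        // Mirrors the access on line 73 of the sharded broadcast program factory.
        auto spec = shard_spec.value();
        std::cout << spec.num_cores << "\n";
    } catch (const std::bad_optional_access &e) {
        std::cout << "bad optional access: " << e.what() << "\n";
    }
    return 0;
}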

To repro, run this test in forge:

import tensorflow as tf
import forge
# to_pt_tensors and compare_tensor_to_golden come from forge's test utilities;
# the exact import paths are not shown in this issue.


def test_resnet_first_conv2d():

    tf.random.set_seed(0)

    in_c = 3
    out_c = 64
    batch_size = 1
    input_height = 112
    input_width = 112
    stride = (2, 2)
    padding = ((3, 3), (3, 3))
    dilation = 1

    class ResnetFirstConv2d(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.pad = tf.keras.layers.ZeroPadding2D(padding=((3, 3), (3, 3)))
            self.conv = tf.keras.layers.Conv2D(out_c, (7, 7), strides=2, use_bias=False, padding="valid")
            self.const = tf.random.uniform((64,))

        def call(self, x):
            x = self.pad(x)
            x = self.conv(x)
            return x * self.const

    input_shape = (batch_size, input_height, input_width, in_c)
    inputs = [tf.random.uniform(input_shape)]

    framework_model = ResnetFirstConv2d()
    fw_out = to_pt_tensors(framework_model(*inputs))

    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)
    co_out = [co.to("cpu").to(fw_out[0].dtype) for co in co_out]
    assert compare_tensor_to_golden("dual_conv2d", fw_out[0], co_out[0].reshape(fw_out[0].shape))

One workaround fixes this: convert the conv2d output to interleaved. I can do that by placing this code in the runtime execution function of conv2d, applied to the conv2d's output:

// Copy the output's memory config and switch its layout to interleaved.
auto new_memconfig = out.memory_config();
new_memconfig.memory_layout = TensorMemoryLayout::INTERLEAVED;
// Re-lay out the sharded conv2d output as an interleaved tensor.
auto sharded_to_interleaved = ::ttnn::operations::data_movement::ShardedToInterleavedOperation();
out = sharded_to_interleaved.invoke(0, out, new_memconfig, out.dtype());

This isn't ideal because the conversion isn't necessary in all cases. For example, if a maxpool2d op immediately follows the conv2d, leaving the conv2d output sharded works just fine; a conditional version of the workaround is sketched below. @nvukobratTT, when we come up with a solution for this I might want to include some of those test cases in this PR: https://github.com/tenstorrent/tt-forge-fe/pull/304
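A minimal sketch of what that conditional version could look like, assuming is_sharded() is available on the runtime tensor and using a hypothetical consumerRequiresInterleaved() predicate; that helper is not an existing tt-mlir or TTNN API, just a placeholder for whatever analysis decides the consumer cannot take a sharded input.

// Hypothetical guard: only pay for the layout conversion when the consuming op
// cannot accept a sharded tensor. consumerRequiresInterleaved() is a placeholder.
if (out.is_sharded() && consumerRequiresInterleaved()) {
  auto new_memconfig = out.memory_config();
  new_memconfig.memory_layout = TensorMemoryLayout::INTERLEAVED;
  auto sharded_to_interleaved = ::ttnn::operations::data_movement::ShardedToInterleavedOperation();
  out = sharded_to_interleaved.invoke(0, out, new_memconfig, out.dtype());
}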

@nsmithtt @nvukobratTT thoughts?

LPanosTT commented 1 week ago

On another note: when I hardcode this interleave op, I can actually get resnet50 to execute half of the model. There's a conv2d op halfway through that causes a circular buffer / L1 memory clash. Metal issue here: https://github.com/tenstorrent/tt-metal/issues/12790

nsmithtt commented 1 week ago

@LPanosTT, can you link the issue that this one is blocked by?

Also, I think this should be supported; let's sync with someone on the TTNN side to figure out how to supply a broadcast sharded input. This should be possible, and we can take a look at their resnet implementation for reference too, I think here: models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py

LPanosTT commented 1 week ago

can you link the issue that this one is blocked by?

@nsmithtt There isn't a specific issue; it's more that this blocks the resnet50 bringup milestone in tt-forge: https://github.com/tenstorrent/tt-forge-fe/issues/137#issue-2475043862