tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

Need bfloat8_b support for maxpool #12926

Closed punithsekar closed 1 month ago

punithsekar commented 1 month ago

Describe the bug: ttnn.max_pool2d returns a bfloat16 tensor even though we pass a bfloat8_b input.

To Reproduce: Run the following code snippet:

import pytest
import torch
import ttnn


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_maxpool(device):
    # Input is laid out as (1, 1, N*H*W, C) = (1, 1, 1*10*10, 512).
    input_a = torch.randn(1, 1, 100, 512)
    input_a = ttnn.from_torch(input_a, device=device, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT)

    pool_2 = ttnn.max_pool2d(
        input_tensor=input_a,
        batch_size=1,
        input_h=10,
        input_w=10,
        channels=512,
        kernel_size=[9, 9],
        stride=[1, 1],
        padding=[4, 4],
        dilation=[1, 1],
        device=device,
    )
    # Prints bfloat16, although the input dtype is bfloat8_b.
    print("pool2", pool_2.dtype)

Expected behavior: The output should have dtype bfloat8_b.


punithsekar commented 1 month ago

fyi @saichandax

mywoodstock commented 1 month ago

@dvartaniansTT maxpool returns bfloat16 by design, since its output is row major. If BFP8_B is needed, we can add a tilize op right after.
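
A minimal sketch of the workaround suggested above, assuming `ttnn.to_layout` can tilize the max pool output and convert its dtype in the same call (the output may need resharding first; this is untested here). The helper name `to_tiled_bfloat8_b` is hypothetical.

```python
def to_tiled_bfloat8_b(row_major_tensor):
    """Tilize a row-major bfloat16 tensor and cast it to bfloat8_b."""
    # Imported inside the function so the sketch can be loaded on machines
    # that do not have ttnn or a Tenstorrent device available.
    import ttnn

    # Assumption: to_layout performs the tilize and the dtype conversion
    # together; bfloat8_b is only used with TILE_LAYOUT.
    return ttnn.to_layout(row_major_tensor, ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b)
```

In the reproduction script this would be called as `pool_2 = to_tiled_bfloat8_b(pool_2)` right after `ttnn.max_pool2d`. The extra op has its own cost, so it may not recover the performance the request was after.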

dvartaniansTT commented 1 month ago

Thanks @mywoodstock. This was filed by MCW. The intention was to use bfp8_b to squeeze out more perf.

mywoodstock commented 1 month ago

OK, since the maxpool implementation works with row-major data, the output will be bfloat16. If tiled data is needed, we will need to implement new kernels for that format, which might be quite tricky. Not sure if it's worth it.
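
One way to see why a row-major output ends up as bfloat16: bfloat8_b is a block floating-point format, so individual values do not stand alone. The sketch below is a simplified model, assuming 16 values share one exponent and each keeps a sign plus 7 mantissa bits; the real hardware packing and rounding may differ. `quantize_block` is a hypothetical helper, not a ttnn API.

```python
import math

BLOCK_SIZE = 16  # assumed number of values sharing one exponent
MANT_BITS = 7    # assumed mantissa bits per value (sign stored separately)


def quantize_block(values):
    """Round a block of floats to a grid set by the block's largest magnitude."""
    max_abs = max(abs(v) for v in values)
    if max_abs == 0.0:
        return [0.0] * len(values)

    # frexp gives max_abs = m * 2**shared_exp with 0.5 <= m < 1.
    _, shared_exp = math.frexp(max_abs)
    step = 2.0 ** (shared_exp - MANT_BITS)  # spacing of representable values
    limit = 2 ** MANT_BITS - 1              # largest mantissa magnitude

    quantized = []
    for v in values:
        q = round(v / step)
        q = max(-limit, min(limit, q))      # clamp to the mantissa range
        quantized.append(q * step)
    return quantized


# One large value sets the exponent for the whole block, so the small
# values in the same block fall below the step size and become zero.
block = [1.0] + [0.001] * (BLOCK_SIZE - 1)
print(quantize_block(block))
```

Because precision depends on which values share a block, the format is tied to how data is grouped in memory, which is consistent with the tiled layout being required for it.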

punithsekar commented 1 month ago

@mywoodstock, the intention is to increase performance, as Dalar mentioned. I also observe a PCC drop for the whole model if we typecast to bfloat8_b after maxpool completes.

I will try to reproduce the pcc drop issue and attach it here.
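
For reference, PCC here is the Pearson correlation coefficient between a reference output and the device output. A minimal pure-Python sketch of such a check follows; the helper `pcc` is hypothetical and is not the comparison utility used in the tt-metal test suite.

```python
import math


def pcc(expected, actual):
    """Pearson correlation coefficient between two equal-length sequences."""
    if len(expected) != len(actual):
        raise ValueError("sequences must have the same length")

    n = len(expected)
    mean_e = sum(expected) / n
    mean_a = sum(actual) / n

    cov = sum((e - mean_e) * (a - mean_a) for e, a in zip(expected, actual))
    var_e = sum((e - mean_e) ** 2 for e in expected)
    var_a = sum((a - mean_a) ** 2 for a in actual)

    # Correlation is undefined for constant inputs; treat identical
    # sequences as a perfect match and anything else as no match.
    if var_e == 0.0 or var_a == 0.0:
        return 1.0 if list(expected) == list(actual) else 0.0

    return cov / math.sqrt(var_e * var_a)


golden = [0.1, 0.5, -0.3, 0.9]
device_out = [0.1, 0.5, -0.25, 0.875]  # made-up values for illustration
print(pcc(golden, device_out))
```

A value close to 1.0 means the two outputs agree up to scale and offset; comparing PCC with and without the typecast after maxpool would show whether the cast is the source of a drop.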

punithsekar commented 1 month ago

I tried to reproduce the PCC drop today, but I no longer see it. We can close the ticket if bfloat8_b support cannot be added to maxpool because it uses row-major layout.

mywoodstock commented 1 month ago

Closing, as maxpool works with row-major (RM) data only.