tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

Splitting and Merging feature map on device [Feature Request] #7566

Open athul-bos-semi opened 4 months ago

athul-bos-semi commented 4 months ago

When I tried to run ttnn.MaxPool2d on a high-resolution feature map, L1 memory ran out. As a workaround I split the feature map into four quarters, ran ttnn.MaxPool2d on each quarter separately, and then merged the pooled results. However, the splitting and merging cannot currently be done on device, so I had to perform them on the host.
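For context on the window sizes used below: each quarter's input slice has to extend past the quarter boundary by kernel_size - stride rows/columns, so that the pooling windows straddling a split line see the same data as in the unsplit case. A quick standalone sanity check of that arithmetic with the numbers from this example (it just mirrors the formulas in the script under "Alternative solution" and is not part of the repro):

kernel_size, stride, split = 3, 2, 2
padded_h, padded_w = 640 + 2, 480 + 2              # input after the manual 1-pixel pad
out_h = (padded_h - kernel_size) // stride + 1     # 320
out_w = (padded_w - kernel_size) // stride + 1     # 240

# Rows of the padded input needed by each row-quarter; adjacent quarters
# overlap by kernel_size - stride = 1 row.
for i in range(split):
    start = i * (out_h // split) * stride
    end = (i + 1) * (out_h // split) * stride + (kernel_size - stride)
    print(f"row quarter {i}: padded rows [{start}, {end})  ({end - start} rows)")
# row quarter 0: padded rows [0, 321)  (321 rows)
# row quarter 1: padded rows [320, 641)  (321 rows)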

Alternative solution

import torch
import ttnn

from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_allclose_and_pcc

batch_size = 1
in_channels = 64
input_height, input_width = 640, 480  # output of the first convolution layer (+BatchNorm2d+ReLU) of ResNet
image = torch.rand((batch_size, in_channels, input_height, input_width))

def max_pool_2d(x, dtype, device):
    pad = 1
    stride = 2
    kernel_size = 3
    MAX_POOL_SPLIT_CONSTANT = 2  # split factor per spatial dimension (2x2 = 4 quarters)

    # Pad on the host so each quarter can be pooled with padding=(0, 0) on device.
    x = torch.nn.functional.pad(x, (pad, pad, pad, pad), "constant", 0)
    new_shape = [x.shape[-2], x.shape[-1]]

    # Output size of the pooling over the full (padded) input.
    output_shape = [(new_shape[0] - kernel_size) // stride + 1, (new_shape[1] - kernel_size) // stride + 1]
    output = torch.zeros((x.shape[0], x.shape[1], output_shape[0], output_shape[1])).bfloat16()

    # One MaxPool2d op sized for a single quarter; the extra (kernel_size - stride)
    # rows/columns account for the overlap needed at the split boundaries.
    max_pool = ttnn.MaxPool2d(
        kernel_size=(kernel_size, kernel_size),
        stride=(stride, stride),
        padding=(0, 0),  # padding was already applied on the host
        dilation=(1, 1),
        dtype=dtype,
        device=device,
        batch_size=1,
        input_height=(output_shape[0] // MAX_POOL_SPLIT_CONSTANT * stride) + (kernel_size - stride),
        input_width=(output_shape[1] // MAX_POOL_SPLIT_CONSTANT * stride) + (kernel_size - stride),
        reader_patterns_cache={},
    )

    for i in range(MAX_POOL_SPLIT_CONSTANT):
        for j in range(MAX_POOL_SPLIT_CONSTANT):
            # Input window (in padded coordinates) for quarter (i, j), including the overlap.
            in_window = [[i * output_shape[0] // MAX_POOL_SPLIT_CONSTANT * stride,
                          (i + 1) * (output_shape[0] // MAX_POOL_SPLIT_CONSTANT * stride) + (kernel_size - stride)],
                         [j * output_shape[1] // MAX_POOL_SPLIT_CONSTANT * stride,
                          (j + 1) * (output_shape[1] // MAX_POOL_SPLIT_CONSTANT * stride) + (kernel_size - stride)]]
            # NCHW -> NHWC, as expected by ttnn.MaxPool2d.
            x_split = x[:, :, in_window[0][0]:in_window[0][1], in_window[1][0]:in_window[1][1]].permute(0, 2, 3, 1)

            x_split = ttnn.from_torch(x_split, ttnn.bfloat16)
            x_split = max_pool.copy_input_to_device(x_split)
            x_split = max_pool(x_split)
            x_split = max_pool.copy_output_from_device(x_split)

            # Back to NCHW and to the per-quarter output shape.
            x_split = x_split.to_torch().permute(0, 3, 1, 2)
            x_split = x_split.reshape(x.shape[0], x.shape[1],
                                      output_shape[0] // MAX_POOL_SPLIT_CONSTANT,
                                      output_shape[1] // MAX_POOL_SPLIT_CONSTANT)
            # Destination window of this quarter in the merged output.
            out_window = [[i * output_shape[0] // MAX_POOL_SPLIT_CONSTANT,
                           (i + 1) * output_shape[0] // MAX_POOL_SPLIT_CONSTANT],
                          [j * output_shape[1] // MAX_POOL_SPLIT_CONSTANT,
                           (j + 1) * output_shape[1] // MAX_POOL_SPLIT_CONSTANT]]
            output[:, :, out_window[0][0]:out_window[0][1], out_window[1][0]:out_window[1][1]] = x_split

    return output

if __name__ == "__main__":

    inputs = image.bfloat16()

    # Reference weights from NVIDIA's ResNet50 (not used in this repro; only handy for inspecting layer names).
    Nvidia_ResNet50 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
    parameters = Nvidia_ResNet50.state_dict()
    # for key, value in parameters.items():
    #     print(key)

    device_id = 2
    device = ttnn.open_device(device_id)
    dtype = ttnn.float32
    # dtype = ttnn.bfloat8_b
    # dtype = ttnn.bfloat16

    output = max_pool_2d(inputs, dtype, device)

    # Golden reference: torch MaxPool2d on the full, unsplit input.
    golden_maxpool = torch.nn.MaxPool2d(3, 2, 1)
    golden_output = golden_maxpool(inputs)

    pcc = 0.998
    passing_pcc, info = comp_pcc(output, golden_output, pcc=pcc)
    print("Passing = ", passing_pcc)
    print("Info = ", info)

    ttnn.close_device(device)

Alternatives considered

Copying from and into a tensor slice would solve this issue, as described in Feature Request #7045. Copying from a tensor can already be accomplished using slicing.
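To make the ask concrete, here is a minimal host-side torch illustration of the two primitives (slice-out and copy-into-slice); the request is for equivalents that operate on device-resident ttnn tensors, so the merge step no longer needs a round trip through the host. The shapes are just the ones from this example:

import torch

merged = torch.zeros(1, 64, 320, 240)   # full pooled output buffer
quarter = torch.rand(1, 64, 160, 120)   # pooled result of one quadrant

# "Copying from a tensor" (slice-out): already expressible with slicing.
top_left = merged[:, :, 0:160, 0:120]

# "Copying into a tensor slice" (the operation this issue asks for on device):
merged[:, :, 0:160, 0:120] = quarter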

davorchap commented 4 months ago

@jliangTT for the TM team to triage; this should be part of the interleave2sharded / sharded2interleave work.

davorchap commented 4 months ago

fyi @tarafdarTT @yan-zaretskiy

jliangTT commented 4 months ago

@davorchap, please advise on priority.

mbahnasTT commented 4 months ago

@davorchap , should @athul-bos-semi use the partial sharding code like in https://github.com/tenstorrent/tt-metal/blob/main/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py#L764?