tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

[Bug Report] Conv2d allocated buffers clashing with L1 buffers #12790

Closed: LPanosTT closed this issue 15 hours ago

LPanosTT commented 1 day ago

Describe the bug

The convolution is:

To Reproduce

Run the following ttnn pytest:

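# Assumes this runs inside the tt-metal repo's test tree, which provides the
# `device` fixture (configured through `device_params`) and `assert_with_pcc`.
import pytest
import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)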
def test_conv(device):
    batch_size = 1
    output_channels = 1024
    input_channels = 512
    input_height = 28
    input_width = 28
    filter_height = 1
    filter_width = 1
    stride_h = 2
    stride_w = 2
    groups = 1
    pad_h = 0
    pad_w = 0
    dilation = 1
    activations_dtype = ttnn.bfloat16
    weights_dtype = ttnn.bfloat16
    has_bias = False

    input_torch = torch.rand(batch_size, input_channels, input_height, input_width)
    weight_torch = torch.rand(output_channels, input_channels, filter_height, filter_width)

    ground_truth = torch.nn.functional.conv2d(
        input_torch,
        weight_torch,
        bias=None,
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        dilation=(dilation, dilation),
        groups=groups,
    ).transpose(-3, -2).transpose(-2, -1) # For comparison with channel-last conv from ttnn

    # Format input as channel last for ttnn

    tt_input_tensor = ttnn.from_torch(
        input_torch.transpose(-3, -2).transpose(-2, -1), dtype=activations_dtype
    )

    tt_weight_tensor = ttnn.from_torch(weight_torch, dtype=weights_dtype)

    [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
        input_tensor=tt_input_tensor,
        weight_tensor=tt_weight_tensor,
        in_channels=input_channels,
        out_channels=output_channels,
        device=device,
        bias_tensor=None,
        kernel_size=(filter_height, filter_width),
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        dilation=(dilation, dilation),
        batch_size=batch_size,
        input_height=input_height,
        input_width=input_width,
        # conv_config=conv_config,
        # conv_op_cache=reader_patterns_cache,
        # debug=debug,
        groups=groups,
    )

    tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
    torch_output_tensor = ttnn.to_torch(tt_output_tensor).reshape(batch_size, out_height, out_width, output_channels)

    pcc = 0.99
    assert_with_pcc(torch_output_tensor, ground_truth, pcc=pcc)
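The test compares against PyTorch, so the `out_height`/`out_width` values returned by `ttnn.conv2d` should match the standard convolution output-size formula documented for `torch.nn.Conv2d`. The small sketch below applies that formula to this issue's parameters. The flattened row counts assume the device output is laid out as `[N*H*W, C]` rows, which is what the final `reshape` in the test suggests.

```python
def conv_out_size(size: int, kernel: int, stride: int, pad: int, dilation: int = 1) -> int:
    """Standard conv2d output size, as documented for torch.nn.Conv2d."""
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


batch_size, input_channels, output_channels = 1, 512, 1024
out_h = conv_out_size(28, kernel=1, stride=2, pad=0)
out_w = conv_out_size(28, kernel=1, stride=2, pad=0)

# Assumption: channel-last data flattened to [N*H*W, C] rows on device
input_rows = batch_size * 28 * 28
output_rows = batch_size * out_h * out_w
print(out_h, out_w, input_rows, output_rows)  # 14 14 784 196
```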

Additional context

TensorFlow ResNet-50 contains this conv. This is blocking the bring-up through forge-fe --> MLIR --> ttnn runtime.

nsmithtt commented 1 day ago

@LPanosTT, this means that we ran out of L1, so we probably need to change the sharding strategy to block sharded. Dropping precision on the weights might work too. Fragmented memory could also lead to this error. However, we can repro this in isolation as a single op, so I think the sharding strategy is the most likely candidate.
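To put "ran out of L1" into numbers, here is a back-of-the-envelope sketch of the whole-tensor sizes in this conv. It treats the 1x1, stride-2 conv as a matmul. It also makes three assumptions: 32x32 tiles, 2 bytes per bfloat16 element, and 1088 bytes per bfloat8_b tile (1024 mantissa bytes plus 64 bytes of shared exponents). The bfloat8_b figure is worth checking against the tt-metal docs. How much of each tensor lands on a single core depends on the sharding strategy, and the per-core L1 capacity is not stated in this thread.

```python
TILE = 32
BF16_TILE_BYTES = TILE * TILE * 2   # 2048 bytes per bfloat16 tile
BFP8_TILE_BYTES = TILE * TILE + 64  # assumed bfloat8_b tile: 1088 bytes incl. shared exponents


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def tiles(rows: int, cols: int) -> int:
    """Number of 32x32 tiles needed to cover a rows x cols matrix (padded up)."""
    return ceil_div(rows, TILE) * ceil_div(cols, TILE)


# Shapes from the repro, viewed as 2-D matrices
activation = tiles(1 * 28 * 28, 512)  # input   [N*H*W, C_in]
weights = tiles(512 * 1 * 1, 1024)    # weights [C_in*kh*kw, C_out]
output = tiles(1 * 14 * 14, 1024)     # output  [N*out_h*out_w, C_out]

for name, n in [("activation", activation), ("weights", weights), ("output", output)]:
    print(f"{name:<10} {n:4d} tiles, bf16: {n * BF16_TILE_BYTES:,} bytes")
print(f"weights as bfloat8_b: {weights * BFP8_TILE_BYTES:,} bytes")
```

Under these assumptions, storing the weights as bfloat8_b cuts them from 1,048,576 to 557,056 bytes. That roughly halves their footprint, which is the effect the "dropping precision" suggestion relies on.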

I don't think this is a bug with the op, unless we think that it should have automatically picked block sharding.

nsmithtt commented 1 day ago

@LPanosTT I think we want to dump the output memory config that the op picks, to double-check that it is something sane. It might have picked something pretty non-optimal by default.

mywoodstock commented 1 day ago

@LPanosTT Please use block sharding in this case. The height is small (it cannot be split across more than 7 cores), but the tensor is quite wide. It passes for me with BLOCK_SHARDED.

conv_config = ttnn.Conv2dConfig(
    shard_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED,
)
# then pass conv_config=conv_config to ttnn.conv2d

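One way to read "cannot do more than 7 cores" is this. Suppose height sharding splits the flattened output rows across cores in whole 32-row tiles. The 196 output rows of this conv (1 x 14 x 14) then cover only 7 tile rows, so at most 7 cores get work, and each one holds all 1024 output channels. Block sharding also splits the channel dimension across grid columns. The sketch below compares per-core output shard sizes only; weights and intermediate buffers add to these. The 7 x 8 block grid is a hypothetical choice for illustration, not necessarily what the op picks.

```python
import math

TILE = 32
BYTES_BF16 = 2
out_rows, out_channels = 1 * 14 * 14, 1024  # flattened conv output [N*H*W, C_out]

row_tiles = math.ceil(out_rows / TILE)          # 7 tile rows
channel_tiles = math.ceil(out_channels / TILE)  # 32 tile columns


def shard_bytes(grid_rows: int, grid_cols: int) -> int:
    """Per-core bf16 output shard size when tiles are split evenly over a grid."""
    rows_per_core = math.ceil(row_tiles / grid_rows) * TILE
    cols_per_core = math.ceil(channel_tiles / grid_cols) * TILE
    return rows_per_core * cols_per_core * BYTES_BF16


height = shard_bytes(grid_rows=row_tiles, grid_cols=1)  # height sharded: at most 7 cores
block = shard_bytes(grid_rows=row_tiles, grid_cols=8)   # hypothetical 7 x 8 block grid
print(f"height sharded: {row_tiles} cores, {height:,} bytes/core")
print(f"block sharded: {row_tiles * 8} cores, {block:,} bytes/core")
```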
tt-mpantic commented 22 hours ago

@nsmithtt, on your note above: "unless we think that it should have automatically picked block sharding."
Maybe this is a good place to ask: can the op adapt and choose what it needs in order to be functional?

Ideally, I believe the compiler should use op-specific config overrides only to boost perf, with the expectation that the op can always run with a default (potentially low-perf) implementation.

nvukobratTT commented 17 hours ago

@tt-mpantic: ... I believe ideally compiler should play with specific op config overrides just to boost perf (and expectation from op would be that it can run with default/potentially low perf implementation). ...

I agree with this statement! From a generality perspective, it would be good to have an op path (configuration) that just works, without special-casing the sharding depending on weight or activation shapes.
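The "always have a path that works" idea can be sketched as a policy. It walks a list of candidate layouts, ordered from fastest to most conservative, and takes the first one whose estimated per-core footprint fits. Everything here is hypothetical: the candidate list, the byte estimates, and `L1_BUDGET` are made up for illustration and are not ttnn APIs. As nsmithtt points out later in the thread, a static estimate like this cannot see fragmentation, so it can still pick a layout that fails at runtime.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class Candidate:
    layout: str          # label only, e.g. "HEIGHT_SHARDED", "BLOCK_SHARDED"
    per_core_bytes: int  # hypothetical estimate of the per-core L1 footprint


def pick_layout(candidates: Sequence[Candidate], l1_budget: int) -> Optional[Candidate]:
    """Return the first candidate (ordered fastest first) that fits the budget."""
    for cand in candidates:
        if cand.per_core_bytes <= l1_budget:
            return cand
    return None  # nothing fits: the caller must fall back (e.g. DRAM) or report an error


L1_BUDGET = 1_000_000  # hypothetical free L1 per core, in bytes
candidates = [
    Candidate("HEIGHT_SHARDED", 1_300_000),
    Candidate("BLOCK_SHARDED", 300_000),
]
choice = pick_layout(candidates, L1_BUDGET)
print(choice.layout if choice else "no layout fits")  # BLOCK_SHARDED
```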

@nsmithtt: ...this means that we ran out of L1 so we need to change sharding strategy probably to Block sharded. Dropping precision on the weights might work too. ...

Please remind me, @LPanosTT: did you hit any bigger issues when pushing the conv op to work from DRAM? That way, we would certainly avoid memory-management and sharding issues.

LPanosTT commented 17 hours ago

Please remind me, @LPanosTT did you hit some bigger issues when pushing conv op to work on DRAM? That way we'll for sure escape memory management issues and sharding.

So far... no.

@mywoodstock Using block sharding for this conv worked. Thanks!

mywoodstock commented 16 hours ago

Shall we close this, then, @LPanosTT?

LPanosTT commented 15 hours ago

@mywoodstock Yes.

nsmithtt commented 11 hours ago

@nsmithtt on your note above: "unless we think that it should have automatically picked block sharding.". Maybe this a good place to ask this question. Can op adapt and choose what is required to be functional ?

Unfortunately, I don't think this is possible for the exact case this issue covers. Sometimes your memory is fragmented or L1 is very full. In those cases, you don't know that you'll run out of memory until you actually invoke the op, and by then it's too late.
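A size check alone is not enough, and a toy first-fit allocator shows why. This is an illustration only, not ttnn's actual allocator. There can be enough total free memory while a single contiguous allocation still fails, because the free space is split into gaps. Sizes are arbitrary units.

```python
from typing import List, Optional, Tuple


class FirstFitHeap:
    """Toy first-fit allocator over [0, size); tracks allocated (start, length) blocks."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.blocks: List[Tuple[int, int]] = []  # kept sorted by start address

    def free_gaps(self) -> List[Tuple[int, int]]:
        """Return (start, length) for every unallocated gap, in address order."""
        gaps, cursor = [], 0
        for start, length in self.blocks:
            if start > cursor:
                gaps.append((cursor, start - cursor))
            cursor = start + length
        if cursor < self.size:
            gaps.append((cursor, self.size - cursor))
        return gaps

    def alloc(self, length: int) -> Optional[int]:
        """Place the block in the first gap large enough; None if no gap fits."""
        for start, gap in self.free_gaps():
            if gap >= length:
                self.blocks.append((start, length))
                self.blocks.sort()
                return start
        return None

    def free(self, start: int) -> None:
        self.blocks = [b for b in self.blocks if b[0] != start]


heap = FirstFitHeap(100)
starts = [heap.alloc(25) for _ in range(4)]  # fill the heap at 0, 25, 50, 75
heap.free(starts[0])
heap.free(starts[2])  # 50 units free in total, but as two 25-unit gaps
total_free = sum(gap for _, gap in heap.free_gaps())
print(total_free, heap.alloc(40))  # 50 None
```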

LPanosTT commented 11 hours ago

@nsmithtt, can we not block shard in DRAM?

nsmithtt commented 8 hours ago

@mywoodstock, is DRAM sharded supported now?

@LPanosTT, conv cannot stream activations from DRAM, but it can stream the weights.

nvukobratTT commented 18 minutes ago

@nsmithtt @mywoodstock Just to get a bit more clarity on my side.

Is it possible to use the conv2d op without a specific sharding? E.g.:

  1. Activation and weights are pushed from the host to DRAM
  2. conv2d fetches tensors from DRAM and runs
  3. conv2d pushes outputs back to the DRAM
  4. Next op is picking up data from DRAM

My question is: can we run convs without any L1 sharding requirement? If not, is this by design, i.e. intentionally not using DRAM for activations? If it isn't by design, should we treat this as a bug?

Thanks for providing more context! :))