tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

repeat_interleave falls back to cpu #6516

Closed aguhaTT closed 5 months ago

aguhaTT commented 6 months ago

repeat_interleave is one of our biggest ops in Mamba. We need it to work on device.

use cases:

1. repeat_interleave((1,1,32,5120),32,dim=3)
2. repeat_interleave((1,1,1,5120),32,dim=2)
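
For reference, here is a minimal sketch of the two use cases' output shapes. It uses NumPy's `np.repeat` as a CPU stand-in for `torch.repeat_interleave`; for an integer repeat count both produce the same element order, repeating each element in place before moving to the next. The tensor contents are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

# Use case 1: (1, 1, 32, 5120), each element repeated 32 times along dim=3
x1 = rng.standard_normal((1, 1, 32, 5120)).astype(np.float32)
y1 = np.repeat(x1, 32, axis=3)

# Use case 2: (1, 1, 1, 5120), the single row repeated 32 times along dim=2
x2 = rng.standard_normal((1, 1, 1, 5120)).astype(np.float32)
y2 = np.repeat(x2, 32, axis=2)

print(y1.shape)  # (1, 1, 32, 163840)
print(y2.shape)  # (1, 1, 32, 5120)

# Interleaving: output column k of use case 1 comes from input column k // 32
assert np.array_equal(y1[..., 33], x1[..., 1])
# Use case 2 repeats a size-1 dim, so it equals a plain broadcast to 32 rows
assert np.array_equal(y2, np.broadcast_to(x2, (1, 1, 32, 5120)))
```

Because use case 2 repeats a dim of size 1, it is effectively a broadcast rather than a true interleave.
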
aguhaTT commented 6 months ago

The repeat_interleave result is consumed by an elementwise mul. Adding @kpaigwar and @esmalTT.

Would it be possible to replace repeat_interleave + mul with a broadcast mul?

It is basically a set of vector-by-scalar multiplications. A is a vector of size 32, and delta can be thought of as 5120 scalars. A needs to be multiplied by each of these scalars.
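
The proposed equivalence can be checked on CPU. The sketch below assumes a hypothetical layout in which `A` (32 values) is tiled once per delta scalar, so the repeat_interleave + mul result at column `i*32 + j` is `delta[i] * A[j]`. That is an outer product, which a broadcast multiply computes without materializing the 32x larger repeated tensor. NumPy stands in for the device ops; the variable names and the layout of `A` are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, S = 5120, 32  # N delta scalars per row; A has S elements

delta = rng.standard_normal((1, 1, 32, N)).astype(np.float32)
A = rng.standard_normal(S).astype(np.float32)

# Current approach: repeat_interleave delta, then elementwise mul with A tiled N times
via_repeat = np.repeat(delta, S, axis=-1) * np.tile(A, N)

# Broadcast approach: (..., N, 1) * (S,) -> (..., N, S), then flatten the last two dims
via_bcast = (delta[..., :, None] * A).reshape(1, 1, 32, N * S)

assert np.array_equal(via_repeat, via_bcast)
print(via_bcast.shape)  # (1, 1, 32, 163840)
```

Each output element is a single multiply `delta[i] * A[j]` in both paths, so the results match exactly rather than only approximately.
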

TT-BrianLiu commented 6 months ago

Let's sync offline.

aguhaTT commented 6 months ago

I would like this issue to be cross-referenced with https://github.com/tenstorrent-metal/tt-metal/issues/6939 because I think these two problems are linked and need to be solved together.

kpaigwar commented 6 months ago

Adding more details requested by @TT-BrianLiu @tarafdarTT

repeat_interleave works on device only for dim=0 and dim=1. @mywoodstock previously suggested two workarounds, which we have tried:

  1. Use upsample
  2. Do permute, repeat_interleave, permute
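
Both workarounds rely on simple identities. Nearest-neighbor upsampling by an integer factor along one axis yields the same values as repeat_interleave along that axis. Likewise, repeat_interleave on an unsupported dim equals permuting that dim into a supported position, repeating there, and permuting back. Below is a minimal NumPy sketch of the second identity for use case 1. It assumes a swap of dims 0 and 3, though any permutation moving dim 3 to dim 0 or 1 would work; `np.transpose` and `np.repeat` stand in for the device permute and repeat_interleave. On device, each permute is an extra op with its own cost.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 1, 32, 5120)).astype(np.float32)

# Reference: repeat_interleave on dim=3 (the dim not supported on device)
ref = np.repeat(x, 32, axis=3)

# Workaround: permute dim 3 to dim 0, repeat on dim 0, then permute back
perm = (3, 1, 2, 0)  # swaps dims 0 and 3; this permutation is its own inverse
moved = np.transpose(x, perm)            # (5120, 1, 32, 1)
repeated = np.repeat(moved, 32, axis=0)  # (163840, 1, 32, 1)
result = np.transpose(repeated, perm)    # (1, 1, 32, 163840)

assert np.array_equal(result, ref)
```
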

Here is the fresh perf data you were looking for:

[Perf data: Upsample workaround]

[Perf data: Repeat_interleave in dim=0]

ntarafdar commented 6 months ago

Good job @kpaigwar. Does this mean upsample is good enough at 34 ms?

kpaigwar commented 6 months ago

Not really, given our target of 1 ms for the full decoder.

TT-BrianLiu commented 6 months ago
kpaigwar commented 6 months ago

For 1, it looks like the tilize, transpose, and untilize come from permute. Yes, correct: program caching should not matter here. Also, upsample itself takes 13 ms.

TT-BrianLiu commented 6 months ago

Does upsample require permute as well, or just repeat_interleave?

kpaigwar commented 6 months ago

Currently, the upsample API doesn't allow scaling in the C dim, so we have to use permute.

kpaigwar commented 6 months ago

@tarafdarTT @TT-BrianLiu wanted to check if there is any update on this issue. Are you planning to add support for this op?

TT-BrianLiu commented 6 months ago

Have you explored ROW_MAJOR eltwise_mul with bcast on W?

kpaigwar commented 6 months ago

Yes, I tried that. I get this error:

RuntimeError: TT_FATAL @ tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp:84: input_tensor_b.get_legacy_shape()[3] == 1 || input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH

when running this:

import ttnn
import torch
from tests.ttnn.utils_for_testing import assert_with_pcc
import tt_lib as ttl
device_id = 0
device = ttnn.open_device(device_id=device_id)
torch.manual_seed(0)

input = torch.randn((1, 1, 32, 5120), dtype=torch.bfloat16)

# reference result computed on CPU
torch_result = torch.repeat_interleave(input, 32, dim=3)

# prepare inputs on tt_device
bcast_mask = torch.ones((1, 1, 32, 32*5120), dtype=torch.bfloat16)
tt_input = ttnn.from_torch(input, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
tt_bcast_mask = ttnn.from_torch(bcast_mask, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
# run tt_lib.tensor.bcast on dim=W; this raises the TT_FATAL above because tt_input's last dim is 5120, not 1 or TILE_WIDTH
tt_output = ttl.tensor.bcast(tt_bcast_mask, tt_input, math_op=ttl.tensor.BcastOpMath.MUL, dim=ttl.tensor.BcastOpDim.W)
repeat_interleaved_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_result, repeat_interleaved_output)
print("PASS1")
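
As we read the assertion, bcast on W expects `input_tensor_b` to supply one value per row (a last dim of 1, or `TILE_WIDTH`), which is then broadcast across that row's W columns. Here `tt_input` has a last dim of 5120, so a single bcast-W call cannot express the interleave. The NumPy sketch below is only a shape-level illustration of that constraint. It assumes the data could be viewed with the 5120 values as rows of width 1; in that case a W-broadcast mul against a ones mask of shape (1, 32, 5120, 32) reproduces repeat_interleave. Whether such a reshape is cheap on device is not discussed in the thread.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 1, 32, 5120)).astype(np.float32)
ref = np.repeat(x, 32, axis=3)  # same as torch.repeat_interleave(x, 32, dim=3)

# As in the repro: b has W=5120, which cannot broadcast against W=163840
mask = np.ones((1, 1, 32, 32 * 5120), dtype=np.float32)
try:
    mask * x
    broadcast_ok = True
except ValueError:
    broadcast_ok = False
print(broadcast_ok)  # False

# Reshaped view (assumption): 5120 rows with W=1 against a mask of 5120 rows with W=32
b = x.reshape(1, 32, 5120, 1)
mask_rows = np.ones((1, 32, 5120, 32), dtype=np.float32)
out = (mask_rows * b).reshape(1, 1, 32, 32 * 5120)

assert np.array_equal(out, ref)
```
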
ntarafdar commented 6 months ago

I'm out of office on April 10; can we sync on this on April 11? To clarify: we have confirmed that it no longer falls back, but @kpaigwar is looking at implementing this with bcast?

kpaigwar commented 6 months ago

Sure, we can sync on April 11. I actually already tried implementing it using eltwise bcast mul, but it looks like it does not support multicasting.

kpaigwar commented 5 months ago

Update: we confirmed that tt_lib.repeat_interleave on dim=3 works on device; however, executing the same op from ttnn falls back to CPU. We need to update the ttnn op to handle this. Reassigning to ttnn developers: https://github.com/tenstorrent/tt-metal/blob/06828e32f0583392a99de71fb56128ee5af3ceff/ttnn/ttnn/operations/data_movement.py#L443

arakhmati commented 5 months ago

@tarafdarTT can you please make this change and add a ttnn unit test?

aguhaTT commented 5 months ago

This issue may not be needed either if the custom-op idea from @TT-BrianLiu and @rtawfik01 works.

jliangTT commented 5 months ago

We can keep it at P2 or close it; up to you.