`repeat_interleave` result is consumed by an elementwise mul. Adding @kpaigwar and @esmalTT. Could it be possible to replace `repeat_interleave + mul` with a broadcast mul?
It is basically a bunch of multiplications of a vector by a scalar. `A` is a vector of size 32. You can think of `delta` as 5120 scalars. `A` needs to be multiplied by each of these scalars.
Let's sync offline.
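For reference, a minimal torch sketch of the equivalence being proposed (the flat 1-D layout here is purely illustrative, not the model's actual tensor layout):

```python
import torch

torch.manual_seed(0)
A = torch.randn(32)        # the vector of size 32
delta = torch.randn(5120)  # think of delta as 5120 independent scalars

# Current formulation: repeat_interleave delta so each scalar lines up
# with one copy of A, then elementwise mul.
out_ri = torch.repeat_interleave(delta, 32) * A.repeat(5120)

# Proposed formulation: a broadcast mul (an outer product), with no
# materialized repeat.
out_bc = (delta[:, None] * A[None, :]).reshape(-1)

assert torch.equal(out_ri, out_bc)
```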
I would like this issue to be cross-referenced with https://github.com/tenstorrent-metal/tt-metal/issues/6939 because I think these two problems are linked and need to be solved together.
Adding more details requested by @TT-BrianLiu @tarafdarTT
`repeat_interleave` works on device only for dim=0, 1. @mywoodstock has suggested two workarounds before, which we have tried out (a torch-level sketch of the upsample equivalence follows the list):

1. `upsample`
2. `permute, repeat_interleave, permute`
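At the torch level, the upsample workaround relies on nearest-neighbor upsampling with an integer scale factor being exactly a repeat_interleave along that dim; a minimal sketch (shapes taken from the repro later in this thread):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 32, 5120)

# Direct repeat_interleave on the last dim.
ref = torch.repeat_interleave(x, 32, dim=3)

# Nearest-neighbor upsample by 32 on W yields the identical layout:
# output element j is a copy of input element j // 32.
up = F.interpolate(x, scale_factor=(1, 32), mode="nearest")

assert torch.equal(ref, up)
```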
Here is the fresh perf data you were looking for (with program_caching enabled):

| Workaround | Perf | Unit Test |
| --- | --- | --- |
| permute, upsample, permute | 34 ms | upsample.xlsx |
| permute, repeat_interleave, permute | 1.7 sec | repeat_interleave_hack.xlsx |
Good job @kpaigwar. Does this mean the upsample is good enough at 34 ms?
Not really, given our target of 1 ms for the full decoder.
Where is the `tilize, transpose, untilize` from? For 1, it looks like the `tilize, transpose, untilize` is coming from `permute`.
Yeah correct, program caching should not matter here.
Also, upsample itself takes 13ms.
Does upsample require permute as well, or just repeat_interleave?
Currently the upsample API doesn't allow scaling in the C dim, hence it has to use permute.
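A torch-level sketch of that permute wrapper (the layout here is assumed for illustration: the dim to be interleaved sits in C, so it is moved to W, upsampled, and moved back):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 5120, 32, 1)     # dim to interleave is C
ref = torch.repeat_interleave(x, 32, dim=1)

y = x.permute(0, 3, 2, 1)           # C -> W: (1, 1, 32, 5120)
y = F.interpolate(y, scale_factor=(1, 32), mode="nearest")
y = y.permute(0, 3, 2, 1)           # W -> C: (1, 5120 * 32, 32, 1)

assert torch.equal(ref, y)
```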
@tarafdarTT @TT-BrianLiu wanted to check if there is any update on this issue. Are you planning to add support for this op?
Have you explored ROW_MAJOR eltwise_mul with bcast on W?
Yes, I tried that. I get this error:

```
RuntimeError: TT_FATAL @ tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp:84: input_tensor_b.get_legacy_shape()[3] == 1 || input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH
```

when running this:
```python
import torch

import ttnn
import tt_lib as ttl
from tests.ttnn.utils_for_testing import assert_with_pcc

device_id = 0
device = ttnn.open_device(device_id=device_id)

torch.manual_seed(0)
input = torch.randn((1, 1, 32, 5120), dtype=torch.bfloat16)

# Reference: repeat each element of the last dim 32 times on host.
torch_result = torch.repeat_interleave(input, 32, dim=3)

# Prepare inputs on the tt device.
bcast_mask = torch.ones((1, 1, 32, 32 * 5120), dtype=torch.bfloat16)
tt_input = ttnn.from_torch(input, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
tt_bcast_mask = ttnn.from_torch(bcast_mask, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)

# Run tt_lib.tensor.bcast on dim=W -- this trips the TT_FATAL above,
# because input_tensor_b's W (5120) is neither 1 nor TILE_WIDTH.
tt_output = ttl.tensor.bcast(tt_bcast_mask, tt_input, math_op=ttl.tensor.BcastOpMath.MUL, dim=ttl.tensor.BcastOpDim.W)

repeat_interleaved_output = ttnn.to_torch(tt_output)
assert_with_pcc(torch_result, repeat_interleaved_output)
print("PASS1")
```
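For context, a torch-level model of the constraint in that TT_FATAL (an illustration of the error message, not the device kernel): bcast on W expects `input_tensor_b` to have W == 1 (or one tile) and stretches it across `input_tensor_a`'s W, which is a different access pattern from repeat-interleaving by 32:

```python
import torch

# What bcast-on-W expresses: one scalar per row of b, broadcast across
# all of a's W. Note b's W is 1, as the TT_FATAL requires.
a = torch.randn(1, 1, 32, 5120 * 32)
b = torch.randn(1, 1, 32, 1)
out = a * b

# repeat_interleave-by-32 needs a different scalar every 32 elements of W.
# As a pure broadcast, that requires first folding W into (5120, 32):
a_folded = a.reshape(1, 1, 32, 5120, 32)
delta = torch.randn(1, 1, 32, 5120, 1)
out2 = (a_folded * delta).reshape(1, 1, 32, 5120 * 32)
```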
I'm out of office April 10; can we sync on this April 11? To clarify: we have confirmed that it no longer falls back, but @kpaigwar is looking at implementing this with bcast?
Sure, we can sync on April 11. I actually already tried implementing it using eltwise bcast mul, but it looks like it does not support multicasting.
Update: we confirm that tt_lib.repeat_interleave on dim=3 works on device; however, executing the same from ttnn does fall back. We need to update the ttnn op to take care of it. Reassigning to ttnn developers: https://github.com/tenstorrent/tt-metal/blob/06828e32f0583392a99de71fb56128ee5af3ceff/ttnn/ttnn/operations/data_movement.py#L443
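A rough sketch of the kind of change requested (module path and signature are assumptions based on the linked file, not the actual diff): make the ttnn wrapper dispatch to the device op instead of converting to torch on host.

```python
import tt_lib as ttl

def repeat_interleave(input_tensor, repeats: int, dim: int):
    # tt_lib's device implementation is confirmed above to work for dim=3,
    # so route to it directly rather than falling back to torch.
    return ttl.tensor.repeat_interleave(input_tensor, repeats, dim)
```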
@tarafdarTT can you please make this change? And add a ttnn unit test?
This issue is potentially not needed as well if the custom-op idea from @TT-BrianLiu and @rtawfik01 works.
We can keep it in P2 or close it. Up to you.
`repeat_interleave` is one of our biggest ops in Mamba. We need it to work on device. Use cases: