tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

[Llama2] support DRAM sharded matmuls #9642

Open cglagovichTT opened 4 days ago

cglagovichTT commented 4 days ago

Decode passes with DRAM-sharded FF2 and FF3 weights. We are leaving FF1 as-is because DRAM-sharded FF1 is slower than before due to the SILU activation.
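For readers unfamiliar with the term, here is a rough sketch of what width-sharding a weight across DRAM banks implies for shapes. The bank count (12) and tile width (32) are assumptions for illustration only, they are not stated in this thread, and `dram_shard_width` is a hypothetical helper, not a tt-metal API.

```python
def dram_shard_width(n, num_banks=12, tile=32):
    """Return (padded_n, shard_width) for splitting an N-wide weight by columns.

    num_banks and tile are illustrative assumptions, not values from this issue.
    """
    # Pad N so that every bank receives the same whole number of tiles.
    granule = tile * num_banks
    padded_n = -(-n // granule) * granule  # round up to a multiple of granule
    return padded_n, padded_n // num_banks


# Example: an N=4096 weight under the assumed 12 banks and 32-wide tiles.
print(dram_shard_width(4096))
```

Under these assumptions the weight is padded so each bank holds an equal, tile-aligned column slice; the real padding rule in the op may differ.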

There is an issue with prefill matmul2D with dram sharded weights when in0 is batched and interleaved. @yugaoTT is looking into it.

branch: cglagovich/9642

repro:

pytest -svv tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py::test_matmul_2d_in1_dram_sharded[1024-8192-4096-None-no_bias-LoFi-fp32-no_pack_l1]

cglagovichTT commented 4 days ago

fyi @avoraTT @mikevin920

yugaoTT commented 4 days ago

@cglagovichTT for FF1, should we modify Eltwise-binary op to take Silu as input activation?

cglagovichTT commented 4 days ago

Yes @yugaoTT that would help a lot. I'd like to be able to do mul with SILU applied to in0.
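A minimal reference sketch of the requested behavior, assuming the usual Llama MLP form FF2(SiLU(FF1(x)) * FF3(x)): if the eltwise mul applies SiLU to in0 itself, the FF1 matmul no longer needs a fused activation. The function names below are hypothetical and plain Python lists stand in for tensors; this is not the TT-NN API.

```python
import math


def silu(x):
    # SiLU (swish): x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def mul_unfused(ff1_out, ff3_out):
    # Current shape of the computation: a separate activation pass, then the multiply.
    activated = [silu(a) for a in ff1_out]
    return [a * b for a, b in zip(activated, ff3_out)]


def mul_with_in0_silu(in0, in1):
    # Hypothetical fused op: SiLU is applied to in0 inside the eltwise multiply.
    return [silu(a) * b for a, b in zip(in0, in1)]


ff1_out = [0.0, 1.0, -2.0]
ff3_out = [2.0, 3.0, 0.5]
assert mul_unfused(ff1_out, ff3_out) == mul_with_in0_silu(ff1_out, ff3_out)
```

Both paths produce the same values; the fused form only changes where the activation is computed.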

avoraTT commented 1 day ago

@kevinmiTT11 please update weight caches on CI machines. Then we should be good to merge dram sharded decode in.

cglagovichTT commented 1 day ago

Might be good to change the name of these weights to avoid conflicts.