tenstorrent / tt-metal

:metal: TT-NN operator library and TT-Metalium low-level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

Llama-2 (meta) TT implementation #1847

Open · boris-drazic opened 1 year ago

boris-drazic commented 1 year ago

Implement Llama-2 (meta-llama/Llama-2-7b-hf) using TT ops and TT fallback ops. (PyTorch demo is in issue https://github.com/tenstorrent-metal/tt-metal/issues/1846.)

Write individual tests for all submodules used in the model, starting from the smallest submodules and working up to the larger submodules that compose them.

Each submodule test should use real weights from the trained model and randomly generated inputs of the correct shape.

Each test should assert on the PCC (Pearson correlation coefficient) between the golden PyTorch output for the submodule and the TT output for that module.
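
A minimal sketch of such a test, assuming the TT module is wrapped so that it accepts and returns torch tensors; the helper names and the 0.99 threshold are illustrative, not the repo's actual test API:

```python
import torch

def pcc(golden: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation coefficient between the flattened tensors.
    stacked = torch.stack([golden.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

def test_submodule(torch_module, tt_module, input_shape, threshold=0.99):
    # Random input of the correct shape; both modules carry the same real trained weights.
    x = torch.randn(input_shape)
    golden = torch_module(x)  # golden output from PyTorch
    tt_out = tt_module(x)     # hypothetical TT wrapper returning a torch tensor
    assert pcc(golden, tt_out) >= threshold
```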

saichandax commented 1 year ago

We already have the TT implementation for the Llama model. The bringup of the existing Llama model is at PR #1792.

The only difference between the Llama & Llama2 models is with respect to the training data. Llama2 is trained on 40% more data than Llama and has double the context length.

Should we still develop a separate TT implementation for Llama2, or incorporate it into Llama's TT implementation by invoking the model with the Llama2 state_dict?

Please suggest what would be ideal to proceed with. Thank you.

davorchap commented 1 year ago

For the larger variants (34B and 70B), Llama-2 uses grouped query attention (GQA); all variants have a longer sequence length (4096).

GQA vs. MHA is a big architectural difference.
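
For reference, a minimal sketch of the GQA computation (causal mask omitted for brevity), assuming `n_kv_heads` divides `n_heads`; shapes and names here are illustrative. MHA is recovered as the special case `n_kv_heads == n_heads`:

```python
import torch

def grouped_query_attention(q, k, v, n_heads, n_kv_heads):
    # q: (batch, seq, n_heads, head_dim); k, v: (batch, seq, n_kv_heads, head_dim)
    groups = n_heads // n_kv_heads
    # Each KV head is shared by `groups` query heads (groups == 1 is plain MHA).
    k = k.repeat_interleave(groups, dim=2)
    v = v.repeat_interleave(groups, dim=2)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (batch, heads, seq, head_dim)
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    return (scores.softmax(dim=-1) @ v).transpose(1, 2)
```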

saichandax commented 1 year ago

Tested the GS demo after updating the TtSelfAttention module. Running the individual test for SelfAttention and reporting the numbers is in progress.

saichandax commented 1 year ago

The Models sheet is updated with the details of the Llama2 model TT implementation.

Waiting for the Llama1 PR #1792 to be merged before updating the Llama2 variant on top of the Llama model.

saichandax commented 1 year ago

The Llama1 PR has been updated to #2348. Once it is merged, the Llama2 update will be pushed in a new PR.

saichandax commented 1 year ago

As per the latest progress on the ticket, we are working on the Llama2 TT implementation (based on Meta's Llama2 reference implementation).

The CPU demo is tracked under ticket #2615.

@Sudharsan-V , Please add any further details about the progress on this task.

Sudharsan-V commented 1 year ago

The TT implementation of the Llama-2 submodules is complete, and the PCC, atol, and rtol values of all submodules are updated in the sheet. Since there is not enough memory to load the entire model at once, the model is split into 4 parts.
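
A minimal sketch of that kind of sequential, part-by-part execution, assuming the model can be partitioned into independent chunks of decoder layers; `build_part` is a hypothetical helper, not the actual code in the commit:

```python
def run_split(build_part, x, n_parts=4):
    # build_part(i) constructs part i of the model with its weights loaded.
    # Parts run sequentially and are freed immediately, so the full model
    # never resides in memory at once.
    for i in range(n_parts):
        part = build_part(i)
        x = part(x)
        del part  # release this part's weights before loading the next
    return x
```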

The commit for the tt-llama2 model is available here.

Note: this commit is not the final version.

davorchap commented 8 months ago

@boris-drazic and @uaydonat, are we doing Llama2-7B?

uaydonat commented 8 months ago

Llama2-7B is not part of Project Javelin, so we are not working on it, although it would be quite straightforward to bring it up if we needed it.