boris-drazic opened this issue 1 year ago
We already have the TT implementation for the Llama model. The bringup of the existing Llama model is at PR #1792.
The main differences between the Llama and Llama2 models are the training data and the context length: Llama2 is trained on 40% more data than Llama and has double the context length.
Should we develop the TT implementation for Llama2 separately, or incorporate it into Llama's TT implementation by invoking the model with the Llama2 state_dict?
Please suggest how best to proceed. Thank you.
Llama-2 uses grouped query attention (GQA) for its larger variants (34B and 70B) and has a longer sequence length (4096).
GQA vs. MHA is a big difference.
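For reference, a minimal PyTorch sketch of the difference: in MHA every query head has its own key/value head, while in GQA several query heads share one key/value head. The head counts and shapes below are illustrative, not the actual Llama-2 configuration.

```python
# Minimal sketch (plain PyTorch, not the TT implementation) contrasting MHA and GQA.
import torch

batch, seq_len, dim = 1, 16, 128
n_heads = 8        # query heads
n_kv_heads = 2     # GQA: fewer key/value heads; MHA would use n_kv_heads == n_heads
head_dim = dim // n_heads

q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)
v = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Each group of query heads shares one KV head: replicate K/V across the group.
group_size = n_heads // n_kv_heads
k = k.repeat_interleave(group_size, dim=1)   # (batch, n_heads, seq_len, head_dim)
v = v.repeat_interleave(group_size, dim=1)

scores = (q @ k.transpose(-2, -1)) / head_dim ** 0.5
out = torch.softmax(scores, dim=-1) @ v      # (batch, n_heads, seq_len, head_dim)
```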
Tested the GS demo after updating the TtSelfAttention module. Running individual tests for SelfAttention and reporting the numbers is in progress.
The Models sheet has been updated with the details of the Llama2 TT implementation.
Waiting for the Llama1 PR #1792 to be merged before updating the Llama2 variant on top of the Llama model.
The Llama1 PR has been updated to #2348. Once it is merged, the Llama2 update will be pushed in a new PR.
As per the latest progress on the ticket, we are working on the Llama2 TT implementation (based on Meta's Llama2 reference implementation).
The CPU demo is tracked under ticket #2615.
@Sudharsan-V, please add any further details about the progress on this task.
The TT implementation of the Llama-2 sub-modules is complete, and the PCC, atol, and rtol values of all sub-modules are updated in the sheet. Since there is not enough memory to load the entire model at once, the model is split into 4 parts (see the sketch below).
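A rough sketch of the chunking idea, assuming the checkpoint stays on host and only one chunk of decoder layers is built on the device at a time. The names `run_in_chunks` and `build_tt_layer` are hypothetical placeholders, not the actual tt-llama2 code, and the exact split used there may differ.

```python
# Run the decoder stack in chunks so only part of the weights is resident at once.
NUM_LAYERS = 32   # Llama-2-7B decoder layers
NUM_CHUNKS = 4

def chunked(num_layers, num_chunks):
    """Yield contiguous ranges of layer indices, one per chunk."""
    per_chunk = num_layers // num_chunks
    for i in range(num_chunks):
        yield range(i * per_chunk, (i + 1) * per_chunk)

def run_in_chunks(hidden_states, state_dict, build_tt_layer):
    # build_tt_layer(layer_idx, state_dict) -> callable; hypothetical factory that
    # moves only that layer's weights onto the device.
    for layer_indices in chunked(NUM_LAYERS, NUM_CHUNKS):
        layers = [build_tt_layer(i, state_dict) for i in layer_indices]
        for layer in layers:
            hidden_states = layer(hidden_states)
        del layers  # free device memory before building the next chunk
    return hidden_states
```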
The commit for the tt-llama2 model is available here.
Note: this commit is not the final version.
@boris-drazic and @uaydonat, are we doing Llama2-7B?
Llama2-7B is not part of project Javelin, so we are not working on it, although it would be quite straightforward to bring it up if we needed it.
Implement Llama-2 (meta-llama/Llama-2-7b-hf) using TT ops and TT fallback ops. (The PyTorch demo is in issue https://github.com/tenstorrent-metal/tt-metal/issues/1846.) Make individual tests for all submodules used in the model, starting from the smaller submodules and working up to the larger submodules that use them.
Each submodule test should use real weights from the trained model and randomly generated inputs of the correct shape.
Each test should assert on the PCC between the golden PyTorch output for the submodule and the TT output from the module (a sketch of such a test follows below).
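A minimal sketch of what such a submodule test could look like. `comp_pcc` here is a simplified stand-in for the comparison helpers used in tt-metal tests, and `reference_module` / `tt_module` are placeholder names for the PyTorch golden submodule and its TT counterpart.

```python
import torch

def comp_pcc(golden: torch.Tensor, calculated: torch.Tensor, pcc_threshold=0.99):
    """Return (passed, pcc) using the Pearson correlation of the flattened tensors."""
    golden = golden.flatten().to(torch.float32)
    calculated = calculated.flatten().to(torch.float32)
    pcc = torch.corrcoef(torch.stack([golden, calculated]))[0, 1].item()
    return pcc >= pcc_threshold, pcc

def test_submodule_pcc(reference_module, tt_module, input_shape, pcc_threshold=0.99):
    # Random input of the correct shape; both modules load the same real weights.
    x = torch.randn(input_shape)
    golden = reference_module(x)   # PyTorch golden output
    tt_out = tt_module(x)          # TT implementation output
    passed, pcc = comp_pcc(golden, tt_out, pcc_threshold)
    assert passed, f"PCC {pcc:.4f} below threshold {pcc_threshold}"
```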