Currently, we put # type: ignore at the top of the train_llama.py file.
This should be removed and all typing errors should be fixed.
Using jaxtyping annotations (e.g. Float[Tensor, "batch seq d_embd"]) is desired, even if they serve purely as documentation and are not used for runtime checking (I've found runtime checking typically makes the documentation a little worse).