Closed: Force1ess closed this issue 12 months ago.
Hey! Thank you for your question. For long-context training we have used the JAX code, which is more optimized for training with longer input sequences. The PyTorch code was initially intended for inference; however, if we find time, we will happily optimize it further. The two main optimizations in the JAX code are:
- `scan_cross_batch`: the main idea is not to materialize the whole ∇(Query × Keyᵀ) at once.
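*(Editor's note: the repository's `scan_cross_batch` implementation is not shown in this thread, so the following is only a minimal sketch of the chunking idea described above: the keys/values are processed in fixed-size chunks inside `jax.lax.scan` with a streaming softmax, so the full Query × Key score matrix, and hence its gradient, is never materialized at once. The function name `chunked_attention`, the chunk size, and all shapes are illustrative assumptions, not the repository's API.)*

```python
import jax
import jax.numpy as jnp

def chunked_attention(q, k, v, chunk_size=1024):
    """Hypothetical sketch. q: [Lq, d]; k, v: [Lk, d].
    Assumes Lk is divisible by chunk_size and float32 inputs."""
    scale = q.shape[-1] ** -0.5
    k_chunks = k.reshape(-1, chunk_size, k.shape[-1])  # [n_chunks, c, d]
    v_chunks = v.reshape(-1, chunk_size, v.shape[-1])

    def step(carry, chunk):
        acc, denom, m = carry                 # running numerator, denominator, row max
        k_c, v_c = chunk
        s = (q @ k_c.T) * scale               # [Lq, c]: scores for ONE chunk only
        m_new = jnp.maximum(m, s.max(axis=-1))
        corr = jnp.exp(m - m_new)             # rescale old accumulators (streaming softmax)
        p = jnp.exp(s - m_new[:, None])
        acc = acc * corr[:, None] + p @ v_c
        denom = denom * corr + p.sum(axis=-1)
        return (acc, denom, m_new), None

    Lq, d = q.shape
    init = (jnp.zeros((Lq, d)), jnp.zeros((Lq,)), jnp.full((Lq,), -jnp.inf))
    # Wrapping the scan body in jax.checkpoint makes the backward pass
    # recompute the per-chunk scores instead of storing them, which is
    # where the memory saving for the gradient comes from.
    (acc, denom, _), _ = jax.lax.scan(jax.checkpoint(step), init, (k_chunks, v_chunks))
    return acc / denom[:, None]
```

On small inputs this should match full attention, `jax.nn.softmax((q @ k.T) * scale) @ v`, up to numerical error, while peak memory scales with the chunk size rather than the full key length.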
Thank you very much for your patient reply. It helped a lot.
Hi, long_llama is very surprising according to the report in the paper, and thank you for your great work. I'm interested in training Long_Llama-3b on a long-text corpus, but I keep hitting out-of-memory errors on my A100-80G. Are there any solutions for fine-tuning this model on texts of 10k length? Do you have any ideas for reducing memory usage? I noticed in your paper that the model was trained at a length of 8k. Could you share your training script so I can learn from it?
Below is my training script: