NVIDIA / Megatron-LM


Fix(memory optimization): inplace subtract vocab_parallel_logits #863

Open · Andy666G opened 2 weeks ago

Andy666G commented 2 weeks ago

The shape of `vocab_parallel_logits` is `[seq_len, batch_size, vocab_size / tp]`. If `vocab_size` is very large, as in Llama 3, performing the subtraction in place avoids materializing a second logits-sized tensor and reduces peak memory usage.
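
For context, here is a minimal sketch of the pattern this PR targets, not the actual diff: Megatron-LM's vocab-parallel cross entropy subtracts the per-token max from the logits for numerical stability, and doing that subtraction out of place allocates a second logits-sized tensor. The function names and the standalone `all_reduce` guard below are illustrative, not Megatron-LM API.

```python
import torch


def _global_max(logits: torch.Tensor) -> torch.Tensor:
    """Per-token max over the local vocab shard, reduced across ranks.

    Under tensor parallelism each rank holds only vocab_size / tp of the
    logits, so the max must be all-reduced (MAX) before subtracting. In
    Megatron-LM this reduction runs over the tensor-parallel group; here
    it is guarded so the sketch also runs in a single process.
    """
    logits_max = torch.max(logits, dim=-1)[0]  # [seq_len, batch_size]
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX)
    return logits_max


def subtract_max_out_of_place(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
    # `a - b` allocates a new result tensor, so peak memory at this step
    # is roughly 2x the (sharded) logits size.
    logits_max = _global_max(vocab_parallel_logits)
    return vocab_parallel_logits - logits_max.unsqueeze(dim=-1)


def subtract_max_in_place(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
    # Same math, but sub_() writes into the existing buffer and avoids
    # the extra logits-sized allocation. This is only safe when nothing
    # else still needs the original values, e.g. inside a custom
    # autograd.Function.forward that saves its own tensors for backward.
    logits_max = _global_max(vocab_parallel_logits)
    vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
    return vocab_parallel_logits


if __name__ == "__main__":
    x = torch.randn(4, 2, 16)  # [seq_len, batch_size, vocab_size / tp]
    assert torch.allclose(subtract_max_out_of_place(x.clone()),
                          subtract_max_in_place(x.clone()))
```

As illustrative arithmetic: with a Llama-3-sized vocabulary (128,256 tokens) sharded over tp=8, an fp32 logits tensor of shape `[8192, 1, 16032]` occupies about 0.5 GiB per rank, so skipping the duplicate saves roughly that much peak memory at this step.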