Open Andy666G opened 2 weeks ago
The shape of `vocab_parallel_logits` is `[seq_len, batch_size, vocab_size / tp]`. When the vocabulary is very large (e.g. Llama 3), the out-of-place subtraction allocates a second tensor of that full size; using an in-place subtract instead avoids that allocation and reduces peak memory usage.
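A minimal sketch of the suggested change, assuming a single-device view of the logits (the function name is illustrative; in the real tensor-parallel kernel an all-reduce MAX over the TP group would sit between the local max and the subtract):

```python
import torch

def subtract_max_inplace_(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
    """Subtract the per-token max for numerical stability, in place.

    `sub_` mutates the input instead of materializing a second
    [seq_len, batch_size, vocab_size / tp] tensor, so the peak memory
    of this step is roughly halved for large vocabularies.
    """
    # Local max over the (sharded) vocab dimension, kept broadcastable.
    logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
    # In-place subtract: no new logits-sized allocation.
    vocab_parallel_logits.sub_(logits_max)
    return vocab_parallel_logits

x = torch.randn(4, 2, 8)
y = subtract_max_inplace_(x)
print(y is x)  # the input tensor itself is returned, not a copy
```

The trade-off is that in-place ops are invalid on tensors needed for autograd, so this only works where the original logits values are no longer required for the backward pass (or where the backward is computed manually, as in fused cross-entropy kernels).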