cli99 / llm-analysis

Latency and Memory Analysis of Transformer Models for Training and Inference
Apache License 2.0
356 stars 42 forks source link

[REQUEST]some question about memory and latency analysis #27

Open liu-yx17 opened 1 month ago

liu-yx17 commented 1 month ago

您的开源项目llm-analysis帮助了我很多,但我尚有一些疑问,烦请您拨冗解答。 1.在analysis.py中,"get_memory_optimizer_state_and_gradient_per_layer"函数中,为了得到"memory_optimizer_state_others_per_layer","self.get_num_params_per_layer_layernorm()" 除以了“self.parallelism_config.tp_size”,张量并行tp会对LN层的优化器状态和梯度进行切分吗? 2.在analysis.py的“get_latency_fwd_per_tp_comm”中,这里没有考虑节点间通信传输的效率,但其他地方均考虑了,这是为什么呢? 3.在analysis.py的“get_latency_fwd_per_layer_shared_dp_comm“中,如果dp_size<=8,则使用节点间通信,否则则使用节点内通信,这块是否有误?dp_size与通信选择似乎没有关系? 还请您不吝赐教! 祝好!

cli99 commented 2 weeks ago

谢谢指出的问题。请看下https://github.com/cli99/llm-analysis/pull/28 的改动是不是合理。

liu-yx17 commented 2 weeks ago

谢谢指出的问题。请看下#28 的改动是不是合理。

感谢您的回复,我认为是合理的。还有一些困惑,如果您有时间烦请解惑。 1.在get_activation_memory_per_layer_attn中,使用了flash_attn以后,memory_attn_compute = (2 seq_len batch_size hidden_dim + 4 n_head seq_len batch_size) * bytes_per_activation / tp_size;这个分别代表什么意义呢?如果按照flashattention-v1论文中的分析,似乎只需要存储按行求和以及按行取最值的激活就可以了。 2.fwd_prefetch和bwd_prefetch的含义我也不是很理解,这两部分的预填充该如何理解? 3.element-wise算子的flpos,本项目中是忽略了的是吗? 请不吝赐教!谢谢!

cli99 commented 1 week ago

@liu-yx17

  1. 这个是个大致的估算,measurement-based。可能会跟 flash_attn目前实际的实现有出入。但最关键的一致是memory usage scales linearly with sequence length.
  2. 请看下https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-buffers-sizes。 目前是在memory usage 中加入了对FSDP buffers sizes的一个lower-bound 估算 (2 times unsharded transformer block parameters buffer)。
  3. 是的。