epfLLM / Megatron-LLM

distributed trainer for LLMs

Fix missing position_ids argument when recompute_granularity == full #86

Open xingyaoww opened 9 months ago

xingyaoww commented 9 months ago

When setting `--recompute_granularity full` for finetuning, we see a traceback like this:

```
  File "/workspace/Megatron-LLM/megatron/model/transformer.py", line 757, in forward
    attention_output, attention_bias = self.self_attention(layernorm_output,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Megatron-LLM/megatron/model/transformer.py", line 502, in forward
    query_layer, key_layer = apply_rotary_emb(query_layer, key_layer, self.freqs_cis, position_ids=position_ids)
  File "/workspace/Megatron-LLM/megatron/model/positional_embeddings.py", line 36, in apply_rotary_emb
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
  File "/workspace/Megatron-LLM/megatron/model/positional_embeddings.py", line 19, in reshape_for_broadcast
    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
AssertionError
```

Tracing back, we find that `reshape_for_broadcast` is only called when `position_ids` is `None`, which means `position_ids` was NOT passed to each transformer layer when `--recompute_granularity full` was set (finetuning did work with `--recompute_granularity selective`). The sketch below illustrates this control flow.
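A minimal sketch of the branch in `apply_rotary_emb`, not the repo's verbatim code: only the assertion in `reshape_for_broadcast` is taken from the traceback; the gather path and the `(seq, batch, heads, head_dim)` tensor layout with `(seq, batch)` `position_ids` are assumptions for illustration. It shows why dropping `position_ids` routes finetuning onto the broadcast path, where the `AssertionError` fires:

```python
import torch

def reshape_for_broadcast(freqs_cis, x):
    # Broadcast path from the traceback: only valid when the precomputed
    # table covers exactly positions 0..seq_len-1 of this batch.
    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
    return freqs_cis.view(x.shape[0], 1, 1, x.shape[-1])

def apply_rotary_emb(xq, xk, freqs_cis, position_ids=None):
    # Assumed shapes: xq, xk (seq, batch, heads, head_dim);
    # freqs_cis complex (max_seq, head_dim // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    if position_ids is None:
        # Fallback hit when position_ids is dropped upstream.
        freqs = reshape_for_broadcast(freqs_cis, xq_)
    else:
        # Assumed gather path: pick each token's own rotary frequencies
        # (position_ids assumed (seq, batch) here), broadcast over heads.
        freqs = freqs_cis[position_ids].unsqueeze(2)
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```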

I chased the error further down into `megatron/model/transformer.py`; it turns out some arguments were missing when calling the checkpoint function through `_checkpointed_forward`, which this PR fixes. A sketch of the shape of the fix follows.
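For illustration only, a minimal sketch of threading `position_ids` through activation checkpointing. It uses `torch.utils.checkpoint` in place of Megatron's own checkpoint helper, and the layer call signature, `_checkpointed_forward_sketch` name, and one-layer chunking are simplified assumptions, not the repo's actual code:

```python
import torch
from torch.utils.checkpoint import checkpoint

def _checkpointed_forward_sketch(layers, hidden_states, attention_mask, position_ids):
    # The inner closure must accept and forward position_ids, and the outer
    # checkpoint(...) call must pass it; otherwise every recomputed layer
    # sees position_ids=None under --recompute_granularity full.
    def custom(start, end):
        def custom_forward(hidden, mask, pos_ids):
            for layer in layers[start:end]:
                hidden = layer(hidden, attention_mask=mask, position_ids=pos_ids)
            return hidden
        return custom_forward

    chunk = 1  # recompute one layer at a time for this sketch
    for start in range(0, len(layers), chunk):
        hidden_states = checkpoint(custom(start, start + chunk),
                                   hidden_states, attention_mask, position_ids,
                                   use_reentrant=False)
    return hidden_states
```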