When `--recompute_granularity full` is set for finetuning, we see a traceback like this:

```
File "/workspace/Megatron-LLM/megatron/model/transformer.py", line 757, in forward
  attention_output, attention_bias = self.self_attention(layernorm_output,
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/workspace/Megatron-LLM/megatron/model/transformer.py", line 502, in forward
  query_layer, key_layer = apply_rotary_emb(query_layer, key_layer, self.freqs_cis, position_ids=position_ids)
File "/workspace/Megatron-LLM/megatron/model/positional_embeddings.py", line 36, in apply_rotary_emb
  freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
File "/workspace/Megatron-LLM/megatron/model/positional_embeddings.py", line 19, in reshape_for_broadcast
  assert freqs_cis.shape == (x.shape[0], x.shape[-1])
AssertionError
```
Tracing this back, we find that `reshape_for_broadcast` is only called when `position_ids` is `None`, which means `position_ids` was NOT being passed to each transformer layer under `--recompute_granularity full` (finetuning did work with `--recompute_granularity selective`).
I chased the error further down into `megatron/model/transformer.py`: it turns out some arguments were missing when the checkpoint function is called through `_checkpointed_forward`, which this PR fixes.
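The shape of the bug can be sketched in a minimal, standalone way. The function and argument names below are illustrative stand-ins, not the actual Megatron-LLM code: the point is that the closure built for activation checkpointing has `position_ids` in scope but never forwards it, so the layer silently falls back to the `None` default.

```python
# Hypothetical minimal sketch of the bug, assuming a structure roughly like
# _checkpointed_forward: a closure wraps the layer call for recomputation.

def layer_forward(hidden_states, attention_mask, position_ids=None):
    # Stand-in for a transformer layer: with position_ids=None, rotary
    # embeddings take the reshape_for_broadcast path, which asserts when
    # finetuning shapes don't match.
    if position_ids is None:
        return "reshape_for_broadcast path (asserts during finetuning)"
    return "position-aware rotary embedding path"


def checkpointed_forward_buggy(hidden_states, attention_mask, position_ids):
    # BUG: position_ids is in scope but never forwarded to the layer,
    # so the layer sees its default value of None.
    def custom_forward(h, mask):
        return layer_forward(h, mask)
    return custom_forward(hidden_states, attention_mask)


def checkpointed_forward_fixed(hidden_states, attention_mask, position_ids):
    # FIX (the gist of this PR): thread position_ids through the closure
    # so the recomputed forward sees the same arguments as the original.
    def custom_forward(h, mask, pos_ids):
        return layer_forward(h, mask, position_ids=pos_ids)
    return custom_forward(hidden_states, attention_mask, position_ids)


print(checkpointed_forward_buggy("h", "mask", [0, 1, 2]))
print(checkpointed_forward_fixed("h", "mask", [0, 1, 2]))
```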