epfLLM / Megatron-LLM

distributed trainer for LLMs

Passed position_ids are ignored for `PositionEmbeddingType.rotary` #23

Closed. andreaskoepf closed this issue 1 year ago.

andreaskoepf commented 1 year ago

The current rotary-embedding code in this repo seems to ignore position_ids and instead always assumes that the position ids match the sequence indices, i.e. the position_ids are not passed on to the encoder.forward() function: https://github.com/epfLLM/Megatron-LLM/blob/90061187a357e7960707b26a848df8003fc7c32d/megatron/model/language_model.py#L512-L515

The actual call to apply_rotary_emb() always uses the pre-computed self.freqs_cis without any position_id-dependent lookup, see transformer.py#L499.
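For illustration, here is a minimal sketch of the two behaviours (the helper names, shapes, and complex-valued freqs_cis layout are my assumptions, not the repo's exact code): with the current approach token i always receives the frequencies for absolute position i, whereas a per-token lookup would honour whatever position_ids the data pipeline produced.

```python
import torch

def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Complex rotary frequencies for absolute positions 0 .. max_seq_len-1.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)  # [max_seq_len, dim/2]
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_by_index(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # What happens today: token i always gets freqs_cis[i], so any position_ids
    # computed by the data pipeline have no effect. x: [batch, seq, heads, head_dim].
    b, s, h, d = x.shape
    xc = torch.view_as_complex(x.float().reshape(b, s, h, d // 2, 2))
    out = xc * freqs_cis[:s].view(1, s, 1, -1)
    return torch.view_as_real(out).reshape(b, s, h, d).type_as(x)

def apply_rotary_by_position_ids(x: torch.Tensor, freqs_cis: torch.Tensor,
                                 position_ids: torch.Tensor) -> torch.Tensor:
    # Per-token lookup that would respect packed / restarted position ids.
    # position_ids: [batch, seq] with arbitrary (e.g. restarting) values.
    b, s, h, d = x.shape
    xc = torch.view_as_complex(x.float().reshape(b, s, h, d // 2, 2))
    out = xc * freqs_cis[position_ids].unsqueeze(2)  # [batch, seq, 1, head_dim/2]
    return torch.view_as_real(out).reshape(b, s, h, d).type_as(x)
```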

This is incompatible with the args.reset_position_ids parameter used for get_ltor_masks_and_position_ids() in finetune.py#L76, which, if set to True, generates position ids for batch packing (potentially restarting from 0 multiple times within one sequence).
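For reference, an illustrative sketch of what reset_position_ids=True is supposed to yield for a packed batch (the function name and EOD handling follow the usual Megatron-style masking utility and are not copied from this repo):

```python
import torch

def packed_position_ids(tokens: torch.Tensor, eod_token: int) -> torch.Tensor:
    # tokens: [batch, seq]. Position ids restart from 0 after every EOD token,
    # so each packed document starts counting at 0 again.
    b, s = tokens.shape
    position_ids = torch.arange(s, dtype=torch.long).unsqueeze(0).repeat(b, 1)
    for row in range(b):
        prev_index = 0
        for i in (tokens[row] == eod_token).nonzero().flatten().tolist():
            position_ids[row, i + 1:] -= i + 1 - prev_index
            prev_index = i + 1
    return position_ids

# Example: two documents packed into one row of 8 tokens, EOD token id 0 (assumed):
# tokens       = [5, 6, 7, 0, 8, 9, 0, 3]
# position_ids = [0, 1, 2, 3, 0, 1, 2, 0]
```

These restarted position ids only have an effect if the rotary embedding actually looks them up per token.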

Compare this to the HF transformers Llama implementation, which passes the position ids on to the attention layers (src/transformers/models/llama/modeling_llama.py#L412-L420) and actually uses them in apply_rotary_pos_emb().
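Roughly, the HF code does something like this (paraphrased, not a verbatim copy of the linked file); the essential step is indexing the cos/sin caches with the per-token position_ids:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos_cache, sin_cache, position_ids):
    # q, k: [batch, heads, seq, head_dim]; caches: [max_seq_len, head_dim];
    # position_ids: [batch, seq]. The gather below is what makes packed /
    # restarted positions work.
    cos = cos_cache[position_ids].unsqueeze(1)  # [batch, 1, seq, head_dim]
    sin = sin_cache[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```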

andreaskoepf commented 1 year ago

I just looked at how RoPE was integrated by NVIDIA, and it also does not seem to use the position_ids, see gpt_model.py#L155C35-L155C49. The RotaryEmbedding class supports an offset parameter but not per-token position ids, see rotary_pos_embedding.py#L18-L28.
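For comparison, a sketch of what an offset-only interface can express (illustrative, not NVIDIA's exact code): the offset shifts the whole window uniformly, e.g. for incremental decoding, so position ids that restart from 0 in the middle of a packed sequence cannot be represented through it.

```python
import torch

def rotary_angles(seq_len: int, dim: int, offset: int = 0, theta: float = 10000.0) -> torch.Tensor:
    # Every token gets the angle for (offset + sequence_index); there is no
    # argument through which per-token position ids could be supplied.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(offset, offset + seq_len).float()  # uniform shift only
    return torch.outer(positions, inv_freq)  # [seq_len, dim // 2]
```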