pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

`freqs_cis` in llama model should be a non-persistent buffer #316

Open tianyu-l opened 1 month ago

tianyu-l commented 1 month ago

Currently `freqs_cis` is registered as a persistent buffer for two reasons, quoted from the code comment at https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L355:

# TODO persistent should be set to false, since this buffer can be recomputed.
# however, we set it to true for 2 reasons.  (1) due to pytorch/pytorch#123411,
# compile or pipeline-tracer will not correctly handle non-persistent buffers,
# so we need to fix that.  (2) if we initialize pipeline-parallel models from
# a seed checkpoint rather than calling init_weights, we need freqs_cis to be
# initialized by the checkpoint, or we need to add a separate initializer for
# just the non-persistent buffers that is called after loading checkpoints.

This issue tracks progress on making the buffer non-persistent. Once (1) is fixed, and if the separate initializer described in (2) still seems the best solution, we can close this issue.
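For reference, here is a minimal sketch of what the end state could look like, assuming the separate-initializer approach from (2). `TinyModel`, `compute_freqs_cis`, and `init_buffers` are hypothetical names made up for this illustration and are not torchtitan's actual code. The only PyTorch behavior relied on is that a buffer registered with `persistent=False` is left out of `state_dict()` and is not expected by `load_state_dict()`.

```python
import torch
import torch.nn as nn


def compute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: complex rotary frequencies of shape (end, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    # Unit-magnitude complex numbers exp(i * angle).
    return torch.polar(torch.ones_like(angles), angles)


class TinyModel(nn.Module):
    """Hypothetical stand-in for the llama model, reduced to the buffer logic."""

    def __init__(self, head_dim: int = 8, max_seq_len: int = 16):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.proj = nn.Linear(head_dim, head_dim)
        # Non-persistent: recomputable, so it is kept out of checkpoints.
        self.register_buffer(
            "freqs_cis", compute_freqs_cis(head_dim, max_seq_len), persistent=False
        )

    def init_buffers(self) -> None:
        """Hypothetical initializer for non-persistent buffers only.

        Intended to be called after loading a checkpoint, since the
        checkpoint does not carry freqs_cis.
        """
        device = self.freqs_cis.device
        self.freqs_cis = compute_freqs_cis(self.head_dim, self.max_seq_len).to(device)


model = TinyModel()
state = model.state_dict()
print("freqs_cis" in state)                        # buffer is not checkpointed
print("freqs_cis" in dict(model.named_buffers()))  # but it is still a buffer

# Simulate loading from a seed checkpoint into a model whose buffer holds
# placeholder values, then recompute the buffer explicitly.
restored = TinyModel()
restored.freqs_cis = torch.zeros_like(restored.freqs_cis)
restored.load_state_dict(state)  # strict load succeeds without freqs_cis
restored.init_buffers()
print(torch.equal(restored.freqs_cis, model.freqs_cis))
```

The trade-off is that every load path (seed checkpoint or regular resume) has to remember to call the buffer initializer, which is why reason (2) matters independently of the compile/tracer fix in (1).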