Closed: ad8e closed this 2 months ago
I am not sure I follow. `self.init_weights` does not allocate any tensors except for the `freqs_cis` buffer, so whether we allocate it or not does not really affect the memory usage.
You are right; calling `init_weights` here will not affect the memory usage. It would only add some time overhead when loading a checkpoint. We do not have the code example here, but I would recommend a meta-device init flow for loading too: FSDP2 supports meta-device init + `load_state_dict(assign=True)`.
Interesting, so the model is only ever created on GPU/CPU if not loading from checkpoint. So there are no issues after all. Thanks for the info.
https://github.com/pytorch/torchtitan/blob/4e5ffafb6e1ebc159ca57625c875d0d44e5a654a/torchtitan/models/llama/model.py#L374
Luckily, this line is never activated, since the model is always constructed on the meta device in torchtitan, which makes it a no-op. In the general case, if the user tries to init the model on GPU or CPU, this line runs before the TP and FSDP sharding, so it will OOM for a moderately-sized model. If loading a checkpoint from disk, it will add overhead.
Deleting the line should be fine and cause no issues.
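A small sketch of why the call is harmless under torchtitan's flow: on the meta device, in-place init kernels run but allocate and compute nothing (the layer size below is illustrative):

```python
import torch
import torch.nn as nn

# Constructed on the meta device, the parameters have no backing storage.
with torch.device("meta"):
    linear = nn.Linear(4096, 4096)

assert linear.weight.is_meta   # metadata only, no memory allocated
linear.reset_parameters()      # in-place init runs, but is a no-op here
assert linear.weight.is_meta   # still just metadata
```

On GPU or CPU, by contrast, the same init would materialize and fill the full unsharded parameters, which is the OOM risk described above.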