pytorch / torchtitan

A native PyTorch Library for large model training

Probably shouldn't call `init_weights` in constructor of the model #290

Closed: ad8e closed this issue 2 months ago

ad8e commented 2 months ago

https://github.com/pytorch/torchtitan/blob/4e5ffafb6e1ebc159ca57625c875d0d44e5a654a/torchtitan/models/llama/model.py#L374

Luckily, this line currently has no real effect, since the model is always constructed on the meta device in torchtitan, which makes the call a no-op.

In the general case, if a user initializes the model directly on GPU or CPU, this line runs before TP and FSDP sharding, so a moderately sized model will OOM on a single device. If loading a checkpoint from disk, it only adds redundant initialization overhead.

Deleting the line should be fine and cause no issues.
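
For reference, the flow torchtitan relies on looks roughly like this (a minimal sketch with a toy module, not torchtitan's actual code; the sharding step is elided):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Toy stand-in for torchtitan's Transformer, just to show the pattern."""
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # A self.init_weights() call here would be a no-op on the meta device,
        # but would run eagerly (before any sharding) on a real CPU/GPU device.

    def init_weights(self):
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

# torchtitan-style flow: build on meta, shard, then materialize and initialize.
with torch.device("meta"):
    model = TinyModel()

# ... apply TP / FSDP sharding here while parameters are still on meta ...

model.to_empty(device="cpu")  # allocate real (uninitialized) storage
model.init_weights()          # initialize only after sharding/materialization
```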

awgu commented 2 months ago

I am not sure I follow. `self.init_weights` does not allocate any tensors except for the `freqs_cis` buffer, so whether or not we allocate that there does not really affect memory usage.

ad8e commented 2 months ago

You are right; calling `init_weights` here will not affect memory usage. It would only add some time overhead when loading a checkpoint.

awgu commented 2 months ago

We do not have a code example for this in torchtitan, but I would recommend a meta-device init flow for loading too. FSDP2 can support meta-device init plus `load_state_dict(assign=True)`.
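
Roughly, that flow looks like this (a minimal sketch, not torchtitan's actual checkpoint code; the model class and checkpoint path are placeholders, and the `fully_shard` import path may differ across PyTorch versions):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2

class TinyModel(nn.Module):
    """Placeholder; stands in for torchtitan's Transformer."""
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

# Assumes torch.distributed and the device mesh are already initialized, and
# that "checkpoint.pt" holds a state dict matching the sharded (DTensor)
# layout, e.g. one saved with torch.distributed.checkpoint.
with torch.device("meta"):
    model = TinyModel()          # no parameter memory is allocated

fully_shard(model)               # shard the still-meta parameters

# assign=True swaps the loaded tensors in for the meta tensors directly, so
# no init_weights() call or separate materialization pass is needed for the
# parameters coming from the checkpoint.
state_dict = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, assign=True)
```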

ad8e commented 2 months ago

Interesting, so the model is only ever created directly on GPU/CPU when not loading from a checkpoint. So there are no issues after all. Thanks for the info.