XinDongol opened 4 months ago
The advantage of meta-device init is that it is as fast as possible: the sharded parameters are directly initialized on GPU.
Any other flow requires something more, e.g. (1) initializing unsharded parameters on GPU and then sharding, or (2) initializing unsharded parameters on CPU, copying to GPU, and then sharding. For (1), you need to interleave your sharding calls with module construction, or else you will use too much GPU memory (e.g. you must construct one transformer block on GPU, shard it, construct the next transformer block, shard it, etc.). For (2), initializing parameters on CPU is slow (if you run initialization kernels), and the largest model you can support is bottlenecked by CPU RAM size.
In some sense, the current `model.init_weights()` meta-device approach is a compromise: we require the user to define this method to initialize all model parameters/buffers, but in turn the initialization is as fast as possible. So to answer your question, if efficiency is not an issue (and CPU RAM size is not a bottleneck), then yes, you could remove the `with torch.device("meta")` (and instead do either (1) or (2)).
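To make the contrast concrete, here is a minimal, hedged sketch (assuming only that `torch` is installed; the `fully_shard` call is left as a comment since it needs a distributed setup and a device mesh, which this single-process snippet does not have):

```python
import torch
import torch.nn as nn

# Flow (2): initialize unsharded parameters on CPU; in a real run you would
# then apply FSDP, which shards and moves them to the mesh's device.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # init kernels run on CPU here
assert all(p.device.type == "cpu" for p in model.parameters())
# In a distributed run:
#   fully_shard(model, mesh=mesh)  # mesh.device_type == "cuda"

# Contrast with meta-device construction: parameters are allocated without
# storage, so construction is nearly free, but any values computed during
# __init__ are lost and must be re-created later (hence init_weights()).
with torch.device("meta"):
    meta_model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
assert all(p.is_meta for p in meta_model.parameters())
```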
Thanks for clarifying. Really helpful!
If my understanding is correct and CPU RAM is NOT a problem, and I want to do (2),
```python
model = model_cls.from_model_args(model_config)
model.init_weights()
model = fully_shard(model, **fsdp_config)
model.to(device="cuda")
### training loop ###
```
Is this a correct way to do (2)? @awgu
@XinDongol A few clarifications:

```python
model = model_cls.from_model_args(model_config)
# (1) If the `model_cls.__init__` did not already call `init_weights()` or similar
model.init_weights()
# (2) Apply FSDP with multiple FSDP calls, e.g. on each transformer block
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module, **fsdp_config)
fully_shard(model, **fsdp_config)  # always call on root
# (3) No need to move to CUDA explicitly
```
Regarding (1), the torchtitan Llama definition already calls `init_weights()` in `Transformer.__init__()`, so there should be no need to call it again separately via `model.init_weights()` if we are doing CPU init.
https://github.com/pytorch/torchtitan/blob/f72a2a0da0bdfc394faaab9b3c0f35d0b6f5be50/torchtitan/models/llama/model.py#L373
Regarding (2), in case you were not already aware of the FSDP design, you should apply `fully_shard` to some submodules (generally transformer blocks for transformer architectures) in addition to the root module, both to achieve communication/computation overlap and to avoid excessive peak memory. Concretely, calling `fully_shard(module)` constructs one parameter group, communicated together (e.g. all-gathering parameters, reduce-scattering gradients), from `module.parameters()`, excluding those already assigned to a nested `fully_shard(submodule)` call.
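The "excluding those assigned to a nested call" rule can be illustrated without any actual FSDP machinery (a sketch assuming `torch` only; `Block` is a hypothetical stand-in for a transformer block):

```python
import torch.nn as nn

class Block(nn.Module):  # hypothetical stand-in for a transformer block
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)                 # falls to the root group
        self.blocks = nn.ModuleList(Block() for _ in range(2))
        self.head = nn.Linear(4, 10)                     # falls to the root group

model = Model()

# Each fully_shard(block) call would own that block's parameters...
per_block = [sum(p.numel() for p in b.parameters()) for b in model.blocks]

# ...and fully_shard(model) on the root owns only what is left over,
# i.e. model.parameters() minus everything claimed by nested calls.
block_param_ids = {id(p) for b in model.blocks for p in b.parameters()}
root_numel = sum(p.numel() for p in model.parameters()
                 if id(p) not in block_param_ids)

print(per_block, root_numel)  # two block groups plus a smaller root group
```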
Regarding (3), each time you call `fully_shard(module)`, the managed parameters/buffers are moved to the mesh's corresponding device, and in our case `mesh.device_type == "cuda"`. This means we do not need to call `model.to(device="cuda")` explicitly.
I was wondering if you could explain more about why you want to do CPU init. You may notice the init time is quite long, especially for larger models.
@awgu Thanks for your clarifications!
Initialization on meta tensors is painful. For some architectures (e.g., Mamba), some parameters need complicated pre-computed initialization. Assigning values to meta/DTensor tensors is challenging. As a result, I would like to initialize the model before distributing it.
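To illustrate the pain point, here is a toy module with a value-dependent init, loosely modeled on Mamba's `A_log` parameter (a hedged sketch assuming `torch` only; `SSMish` and its init are illustrative, not Mamba's actual code):

```python
import torch
import torch.nn as nn

class SSMish(nn.Module):
    """Toy module with a value-dependent init, loosely modeled on Mamba's A_log."""
    def __init__(self, d_state: int = 4):
        super().__init__()
        # Pre-computed values: this genuinely writes data into the tensor.
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))

cpu_mod = SSMish()                 # values exist and are usable
assert bool(torch.isfinite(cpu_mod.A_log).all())

with torch.device("meta"):
    meta_mod = SSMish()            # constructs fine, but the computed values...
assert meta_mod.A_log.is_meta      # ...have no storage: the log(arange) is gone
# Materializing later means redoing the computation (possibly on DTensor),
# which is why such architectures prefer initializing before distributing.
```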
I just noticed that there is a newly added flag called `create_seed_checkpoint`. If my understanding is correct, it can be used to avoid initialization on meta/DTensor tensors.
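The seed-checkpoint idea can be sketched in a single process (torchtitan's actual `create_seed_checkpoint` path goes through distributed checkpointing; this hedged analogue uses plain `torch.save`/`torch.load` just to show the shape of the flow, with a temporary file path of our own choosing):

```python
import os
import tempfile
import torch
import torch.nn as nn

def build() -> nn.Module:
    return nn.Linear(4, 4)

# Step 1 ("seed checkpoint"): initialize once on CPU and save the weights.
seed = build()
path = os.path.join(tempfile.mkdtemp(), "seed.pt")
torch.save(seed.state_dict(), path)

# Step 2: construct on meta (fast, no memory for parameter data)...
with torch.device("meta"):
    model = build()

# ...then materialize storage and load the saved weights, instead of
# re-running any init code on meta/DTensor tensors.
model.to_empty(device="cpu")
model.load_state_dict(torch.load(path))
assert torch.equal(model.weight, seed.weight)
```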
I would really appreciate some pointers to the complicated initialization to learn more about it.
And yes, I think that the seed checkpoint can be used to avoid the meta device init and instead try to init on CPU.
Here are some examples.
Initializing them as DTensors seems challenging.
cc: @wanchaol @tianyu-l The above two pointers are good examples of real-model init methods that do not fit our current meta-device init flow. As far as I can tell, both would require some custom logic using `DTensor` APIs to make it work (leading to some `if <using distributed>: ... else: ...`).
I noticed that there are two parts of the implementation related to model initialization.

Instantiating the model with meta tensors:
https://github.com/pytorch/torchtitan/blob/f72a2a0da0bdfc394faaab9b3c0f35d0b6f5be50/train.py#L177-L181

Doing explicit model initialization:
https://github.com/pytorch/torchtitan/blob/f72a2a0da0bdfc394faaab9b3c0f35d0b6f5be50/train.py#L209-L210

The issue is that any weight initialization done while instantiating the module is ineffective because of the meta tensors. As a result, we have to do all initialization explicitly in `model.init_weights()`. My question is: why do we want to instantiate the model with meta tensors? If efficiency is not an issue, can we simply remove the `with torch.device("meta"):`?
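The "ineffective init" part is easy to see directly: under the meta device, in-place writes during `__init__` execute but touch no storage, so only an explicit re-init after materialization takes effect. A minimal sketch (assuming `torch` only; `Tiny` is a made-up module following the torchtitan `init_weights()` pattern):

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.empty(3))
        self.init_weights()            # runs here, but on meta it writes nothing

    def init_weights(self):
        with torch.no_grad():
            self.w.fill_(1.0)

with torch.device("meta"):
    m = Tiny()
assert m.w.is_meta                     # the fill_ in __init__ left no data behind

m.to_empty(device="cpu")               # allocate real (uninitialized) storage
m.init_weights()                       # now the explicit init actually takes effect
assert torch.equal(m.w, torch.ones(3))
```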