## What does this PR do?

This PR drafts the implementation of deferred initialisation.
Currently, we hit the CPU RAM limits on some machines, leading to RAM overflow with larger models. As a solution, PyTorch (including 2.4) proposes `torchdistx.deferred_init`, which initialises the model only on a fake device and records all operations performed on the respective tensors. After sharding the model, the operations are replayed on the newly instantiated tensors. Thus, the model is never materialised on the CPU, fixing the RAM overflow issue.
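For illustration, here is a minimal sketch of this flow with FSDP1, along the lines of the PyTorch docs (`MyLargeModel` is a placeholder and an initialised process group is assumed):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchdistx import deferred_init

class MyLargeModel(nn.Module):  # placeholder for our actual model
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(8192, 8192)

# Instantiated on the fake device: no real storage is allocated and all
# tensor operations are only recorded for later replay.
model = deferred_init.deferred_init(MyLargeModel)

# FSDP shards the fake parameters and then materializes each submodule
# via param_init_fn, replaying the recorded operations on real tensors.
sharded_model = FSDP(
    model,
    param_init_fn=lambda module: deferred_init.materialize_module(module),
)
```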
Even though the official PyTorch 2.4 documentation proposes using torchdistx in conjunction with FSDP, the torchdistx repo seems to be no longer actively maintained and the package is only available for a limited number of CUDA and Python versions: https://github.com/pytorch/torchdistx?tab=readme-ov-file#dependencies

In contrast, the FSDP2 release documentation proposes meta-device initialisation. However, it is not immediately clear whether this also reduces the CPU RAM footprint:

> FSDP2 supports a new meta-device initialization flow that does not require materializing a module on GPU before sharding it, removing the need for `param_init_fn`. See Meta-Device Initialization for more details.
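For comparison, a rough sketch of the FSDP2 meta-device flow as described in its docs (again, `MyLargeModel` and the per-submodule sharding loop are placeholders; in PyTorch 2.4 `fully_shard` lives under `torch.distributed._composable.fsdp`):

```python
import torch
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 in PyTorch 2.4

# The meta device stores only shapes/dtypes, so neither CPU nor GPU
# memory is allocated at construction time.
with torch.device("meta"):
    model = MyLargeModel()

# Shard while still on meta (assumes torch.distributed is initialized);
# typically applied per submodule/block first, then to the root module.
for submodule in model.children():
    fully_shard(submodule)
fully_shard(model)

# Allocate only this rank's shards on GPU, then re-run the init logic
# on the real tensors (assuming the init ops are supported on DTensor).
model.to_empty(device="cuda")
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
```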
Given the version issues with torchdistx, I suggest we look further into FSDP2 and, if it supports some form of deferred initialisation, upgrade to FSDP2 directly.
## General Changes
- Refactored hierarchical instantiation into multiple modules.
- Implemented `DeferredInitModelBuilder`, which builds the deferred modules. Note that the implementation asserts (and raises an error if violated) that the model config contains a single model, i.e., no top-level references such as `model -> model_raw` are allowed. The reason is that deferred init must record the tensor operations to replay them later, which in the case of references would be out of scope. See the sketch after this list.
- With `ComponentBuilderIF` we now have a generic interface that allows registering custom component-building strategies in the `HierarchicalInstantiation`.
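To make the new extension point concrete, here is a hypothetical sketch of the interface and the builder's guard; the config keys (`target_class`, `config`), the reference check, and `register_builder` are illustrative placeholders, not the actual implementation:

```python
from abc import ABC, abstractmethod
from typing import Any

import torch.nn as nn
from torchdistx import deferred_init

class ComponentBuilderIF(ABC):
    """Generic interface for custom component-building strategies."""

    @abstractmethod
    def build(self, config: dict[str, Any]) -> Any:
        raise NotImplementedError

class DeferredInitModelBuilder(ComponentBuilderIF):
    def build(self, config: dict[str, Any]) -> nn.Module:
        # Deferred init records tensor operations for later replay. A
        # top-level reference (e.g. model -> model_raw) would let those
        # operations escape the recording scope, hence we require a
        # single, self-contained model definition.
        if "reference" in config:  # hypothetical reference marker
            raise ValueError(
                "DeferredInitModelBuilder requires a single model config "
                "without top-level references."
            )
        model_cls = config["target_class"]    # hypothetical config keys
        model_kwargs = config.get("config", {})
        return deferred_init.deferred_init(model_cls, **model_kwargs)

# Hypothetical registration with the hierarchical instantiation:
# hierarchical_instantiation.register_builder("model", DeferredInitModelBuilder())
```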
## Breaking Changes
..
## Checklist before submitting final PR
- [ ] My PR is minimal and addresses one issue in isolation
- [ ] I have merged the latest version of the target branch into this feature branch
- [ ] I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
- [ ] I have run a sample config for model training
- [ ] I have checked that all tests run through (`python tests/tests.py`)