
A framework for training multimodal foundation models.

Feat/deferred init #197

Closed. le1nux closed this 1 month ago.

le1nux commented 1 month ago

What does this PR do?

This PR drafts the implementation of deferred initialisation. Currently, we hit the limits of CPU RAM on some machines, leading to RAM overflow with larger models. As a solution, PyTorch (including 2.4) proposes using `torchdistx.deferred_init`, which initialises the model only on a fake device and records all operations performed on the respective tensors. After the model has been sharded, the recorded operations are replayed on the newly instantiated tensors. Thus, the full model is never materialised on the CPU, fixing the RAM overflow issue.
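For context, a minimal sketch of how that flow could look with FSDP, assuming torchdistx is installed and a process group is already initialised; `MyModel` and its sizes are hypothetical stand-ins for one of our model classes:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torchdistx.deferred_init import deferred_init


class MyModel(nn.Module):
    # Hypothetical stand-in for one of our model classes.
    def __init__(self, hidden: int = 4096, num_layers: int = 32):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


# Construct the model on torchdistx's fake device: no real storage is
# allocated; the construction ops are merely recorded for later replay.
model = deferred_init(MyModel)

# With torchdistx installed, FSDP replays the recorded ops and
# materializes each shard directly on the GPU, so the full model never
# occupies CPU RAM. (torchdistx.deferred_init.materialize_module can
# alternatively be invoked via param_init_fn.)
sharded_model = FSDP(model, device_id=torch.cuda.current_device())
```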

Even though the official PyTorch 2.4 documentation suggests using torchdistx in conjunction with FSDP, the torchdistx repo does not appear to be actively maintained, and the package is only available for a limited set of CUDA and Python versions: https://github.com/pytorch/torchdistx?tab=readme-ov-file#dependencies

In contrast, the FSDP2 release documentation proposes meta-device initialisation. However, it is not immediately clear whether this also reduces the CPU RAM footprint:

> "FSDP2 supports a new meta-device initialization flow that does not require materializing a module on GPU before sharding it, removing the need for `param_init_fn`. See Meta-Device Initialization for more details."
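A sketch of that flow, assuming `fully_shard` is imported from `torch.distributed._composable.fsdp` (its location as of PyTorch 2.4) and reusing the hypothetical `MyModel` stand-in from above:

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


class MyModel(nn.Module):
    # Same hypothetical stand-in as above, plus an explicit init method.
    def __init__(self, hidden: int = 4096, num_layers: int = 32):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))

    def init_weights(self) -> None:
        for block in self.blocks:
            nn.init.normal_(block.weight, std=0.02)
            nn.init.zeros_(block.bias)


# 1) Construct on the meta device: parameters carry only shape/dtype
#    metadata, so nothing is allocated in CPU RAM.
with torch.device("meta"):
    model = MyModel()

# 2) Shard the submodules, then the root; parameters remain on meta.
for block in model.blocks:
    fully_shard(block)
fully_shard(model)

# 3) Allocate the already-sharded parameters directly on the GPU and run
#    our own init, since to_empty() leaves the memory uninitialised.
model.to_empty(device="cuda")
model.init_weights()
```

If this behaves as documented, the parameters never receive CPU storage at any point: only the local GPU shards are allocated in the final step, which would answer the footprint question above.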

Given the versioning issues with torchdistx, I suggest we look further into FSDP2 and, if it supports this sort of deferred initialisation, upgrade directly to FSDP2.

General Changes

Breaking Changes

Checklist before submitting final PR

le1nux commented 1 month ago

FYI, @flxst @mali-git @fromm-m