Alpha-VLLM / LLaMA2-Accessory

An Open-source Toolkit for LLM Development
https://llama2-accessory.readthedocs.io/

Support lazy model init #60

Open linziyi96 opened 1 year ago

linziyi96 commented 1 year ago

This PR aims to add support for lazy model initialization. This is one of the two steps needed to lower CPU memory usage for quantized models. Quantization is currently implemented by replacing regular linear layers with quantized linear layers. Without lazy init, the full-precision model that exists before the replacement causes a huge peak memory usage, making both training and inference hard to run on commodity hardware even with aggressive quantization. For example, the 4-bit 13B model, which theoretically needs only 6.5GB of memory and fits comfortably in any mainstream PC, currently requires 52GB of memory (full-precision model plus full-precision checkpoint); and the 4-bit 70B model, which theoretically needs 35GB of memory and fits on two 3090s, currently requires 280GB of memory, which is only possible on some expensive HEDT and server platforms.
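For reference, a rough back-of-the-envelope calculation behind those numbers (assuming 2 bytes per parameter for both the in-memory half-precision model and the loaded checkpoint, 0.5 byte per parameter after 4-bit quantization, and ignoring activations and other overheads):

```python
GB = 1e9  # decimal gigabytes, good enough for a rough estimate

def peak_without_lazy_init(n_params: float) -> float:
    # Half-precision model plus half-precision checkpoint, both resident in CPU memory.
    return n_params * (2 + 2) / GB

def quantized_only(n_params: float) -> float:
    # Just the packed 4-bit weights.
    return n_params * 0.5 / GB

for name, n in [("13B", 13e9), ("70B", 70e9)]:
    print(f"{name}: ~{peak_without_lazy_init(n):.0f}GB peak today vs "
          f"~{quantized_only(n):.1f}GB after 4-bit quantization")
# 13B: ~52GB peak today vs ~6.5GB after 4-bit quantization
# 70B: ~280GB peak today vs ~35.0GB after 4-bit quantization
```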

With lazy init, model creation becomes: (1) create a placeholder model without allocating any actual storage, (2) replace layers with their quantized counterparts, and (3) instantiate all tensors. This way, we do not need to manually re-implement a quantized version of each (current or future) model, and only the post-quantization amount of storage is actually allocated.
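A minimal sketch of steps (1) and (2): the names below are illustrative, not the repo's actual API, and a hypothetical FakeQuantLinear stands in for the real quantized layer.

```python
import torch
import torch.nn as nn

class FakeQuantLinear(nn.Module):
    """Hypothetical stand-in for a real quantized linear layer."""
    def __init__(self, in_features: int, out_features: int, device=None):
        super().__init__()
        # A packed 4-bit weight needs roughly in_features * out_features / 2 bytes;
        # on the meta device this allocates no real storage at all.
        self.register_buffer(
            "packed_weight",
            torch.empty(out_features, in_features // 2, dtype=torch.uint8, device=device),
        )

def replace_linears(module: nn.Module) -> None:
    """Swap every nn.Linear for the quantized stand-in, still on the meta device."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, FakeQuantLinear(
                child.in_features, child.out_features, device=child.weight.device))
        else:
            replace_linears(child)

# (1) Build the placeholder model on the meta device: no storage is allocated.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(4096, 11008), nn.Linear(11008, 4096))

# (2) Replace layers with quantized ones while everything is still meta.
replace_linears(model)
print(next(model.buffers()).is_meta)  # True: nothing materialized yet
```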

However, supporting lazy init turns out to be a complicated task, because PyTorch currently provides no good way to decouple model creation from weight initialization. Although tensors can be created on the meta device, there seems to be no reliable way to initialize them afterwards: the fairscale layers initialize their weights in __init__ and do not provide a separate method to initialize the weights after creation; and even though most PyTorch built-in layers do provide reset_parameters methods as of v2.0.1, they usually do not support custom initialization (e.g., LoRA needs zero init, but torch.nn.Linear.reset_parameters always initializes the weights randomly from a uniform distribution).
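To make the LoRA example concrete, a layer that owns its initialization logic would need something like the following (purely illustrative, not the repo's LoRA implementation); this is exactly the kind of custom logic that torch.nn.Linear.reset_parameters cannot express:

```python
import math
import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Illustrative LoRA-style layer whose reset_parameters encodes its own
    (non-default) initialization; the forward pass is omitted."""
    def __init__(self, in_features: int, out_features: int, rank: int = 8):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.lora_a = nn.Parameter(torch.empty(rank, in_features))
        self.lora_b = nn.Parameter(torch.empty(out_features, rank))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # The base weight and the LoRA "A" matrix follow the usual nn.Linear scheme...
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        # ...but the LoRA "B" matrix must start at zero so the adapter is a no-op
        # at initialization; the built-in uniform reset cannot express this.
        nn.init.zeros_(self.lora_b)
```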

Facing such a dilemma, I am trying to follow the lazy-init approach of PyTorch FSDP: rely on the reset_parameters method of each module that directly owns parameters or buffers, with the heavy lifting being to implement reset_parameters for every module we use that does not already have a working one.
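A sketch of what the reset_parameters-driven materialization pass could look like (materialize_remaining and its behavior are assumptions for illustration, not the final API of this PR):

```python
import torch
import torch.nn as nn

def materialize_remaining(model: nn.Module, device: str = "cpu") -> None:
    """Give real storage to every module that still holds meta tensors, then let the
    module's own reset_parameters() fill in the values (FSDP-style lazy init).
    Simplified: reset_parameters() re-initializes all of the module's tensors, not
    only the ones that were still meta."""
    for module in model.modules():
        own = list(module.named_parameters(recurse=False)) + \
              list(module.named_buffers(recurse=False))
        metas = [(name, t) for name, t in own if t.is_meta]
        if not metas:
            continue  # already fully materialized, e.g. by checkpoint loading
        for name, t in metas:
            real = torch.empty_like(t, device=device)
            if isinstance(t, nn.Parameter):
                real = nn.Parameter(real, requires_grad=t.requires_grad)
            setattr(module, name, real)  # replaces the meta tensor with real storage
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
        else:
            raise RuntimeError(
                f"{type(module).__name__} holds meta tensors but defines no reset_parameters()")
```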

The model creation process is supposed to be like the following after the change:

```python
# All weights on the meta device, including quantized layers but excluding the vision encoder.
# Quantized layer replacement happens inside MetaModel.
with default_tensor_type(..., meta=True):
    model = MetaModel(...)

# All tensors present in the checkpoints are materialized. If a quantized layer sees
# full-precision states, quantize them before materializing.
utils.tensor_parallel.load_tensor_parallel_model_list(...)

# Materialize the remaining weights (those unseen in the loaded checkpoints) using the
# reset_parameters method.
model.materialize()
```
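For the "quantize before materialize" path in the loading step above, a quantized layer could consume a full-precision checkpoint tensor roughly like this (a hedged sketch with invented names, using int8 for simplicity; the repo's actual 4-bit layers and packing format differ):

```python
import torch
import torch.nn as nn

class QuantLinearSketch(nn.Module):
    """Hypothetical symmetric per-row int8 layer, for illustration only."""
    def __init__(self, in_features: int, out_features: int, device=None):
        super().__init__()
        self.register_buffer("qweight", torch.empty(
            out_features, in_features, dtype=torch.int8, device=device))
        self.register_buffer("scale", torch.empty(out_features, 1, device=device))

    def load_full_precision_weight(self, weight: torch.Tensor) -> None:
        # Quantize the incoming full-precision tensor first, then materialize only the
        # quantized buffers; the fp16/fp32 copy can be released right away, so peak
        # memory stays close to the quantized size.
        scale = weight.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
        q = (weight / scale).round().clamp(-127, 127).to(torch.int8)
        self.qweight = q      # replaces the meta buffer with real storage
        self.scale = scale
```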

Following this plan, the proposed code change is roughly organized into the following parts:

This PR involves an extensive code refactor and needs thorough testing, so it is marked as a draft for now.