BlackSamorez / tensor_parallel

Automatically split your PyTorch models on multiple GPUs for training & inference

Meta devices support #54

Closed BlackSamorez closed 1 year ago

BlackSamorez commented 1 year ago

Support for meta devices for deployment with little RAM.

BlackSamorez commented 1 year ago

Memory-efficient dispatch

Normally, to create and dispatch a tensor_parallel model one needs the whole model in memory. This can be troublesome, but there is another way.

It's possible to create a tensor_parallel model on a machine with enough RAM, save it while preserving its distributed state, and then reuse the saved files to dispatch model shards straight to GPUs on any other machine.

The code to save the distributed state should look like this:

import transformers
import tensor_parallel as tp
RAM_LIMIT = "6GB" # we want to deploy the model on machines with as little as 6GB of RAM
model = tp.TensorParallelPreTrainedModel(
    transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-13b"), 
    device_ids=["cpu", "cpu"] # split model but load into RAM
)
model.save_pretrained("opt-13b-tensor-parallel", max_shard_size=RAM_LIMIT) # save the model's distributed state
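
Saving this way normally produces several weight shard files plus a model index. As a minimal sketch (assuming the standard transformers sharded-checkpoint layout that save_pretrained writes), the index can be inspected to confirm how the weights were split:

import json

# read the index written by save_pretrained (file name follows the usual transformers convention)
with open("opt-13b-tensor-parallel/pytorch_model.bin.index.json") as f:
    index = json.load(f)

shard_files = sorted(set(index["weight_map"].values()))
total_gib = index["metadata"]["total_size"] / 2**30
print(f"{len(shard_files)} weight shards, {total_gib:.1f} GiB of parameters in total")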

The shard files and the index can then be put on a machine used for training or inference and loaded as follows:

import transformers
import tensor_parallel as tp
from accelerate import init_empty_weights, load_checkpoint_in_model
# Initialize a weightless model
with init_empty_weights():
    model = tp.TensorParallelPreTrainedModel(
        transformers.AutoModelForCausalLM.from_config(
            transformers.AutoConfig.from_pretrained("facebook/opt-13b")
        ),
        device_ids=[0, 1] # and prepare it to be put on GPUs 0 and 1
    )
device_map = tp.infer_sharded_device_map(model) # assign each parameter to a device
# Incrementally load the weights using your favorite loader
load_checkpoint_in_model(
    model,
    checkpoint="opt-13b-tensor-parallel/pytorch_model.bin.index.json",
    device_map=device_map,
)

The maximum RAM consumption during such loading is max_shard_size, which in this example was set to 6GB.
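
Once the weights are dispatched, the model can be used as usual. A minimal sketch of running generation on it (the tokenizer setup and prompt here are illustrative assumptions, not part of the original snippet):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-13b")

# inputs go to the first of the model's devices; the tensor_parallel wrapper handles the rest
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"].to("cuda:0")
outputs = model.generate(inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))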