BlackSamorez / tensor_parallel

Automatically split your PyTorch models on multiple GPUs for training & inference
MIT License

tensor_parallel


🚀  Try the new 40B LLMs demo on Kaggle

Run large PyTorch models on multiple GPUs in one line of code with potentially linear speedup.

import transformers
import tensor_parallel as tp
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-13b")
model = transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-13b")  # use opt-125m for testing

model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])  # <- each GPU has half the weights

inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"].to("cuda:0")
outputs = model.generate(inputs, num_beams=5)
print(tokenizer.decode(outputs[0])) # A cat sat on my lap for a few minutes ...

model(input_ids=inputs, labels=inputs).loss.backward()  # training works as usual

Installation

Latest stable version (recommended):

pip install tensor_parallel

Bleeding edge version:

pip install https://github.com/BlackSamorez/tensor_parallel/archive/main.zip

Usage

Simply wrap your PyTorch model with tp.tensor_parallel and use it normally. For best memory efficiency, call tp.tensor_parallel while the model is still on CPU.


Saving the model

To save a model so that it can be loaded in a non-tensor_parallel context, use the save_tensor_parallel context manager.

import torch
import transformers
import tensor_parallel as tp

model = tp.tensor_parallel(
    transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-13b"), 
)

# A whole lot of training...

with tp.save_tensor_parallel(model):
    torch.save(model.state_dict(), "/tmp/model.pt")
    # or 
    model.save_pretrained("/tmp/")

This saves the model as if it had never been split: the sharded parameters are gathered back together while the state_dict is being created.
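Conceptually, the gathering can be pictured with plain tensors. This is a toy sketch, not the library's internals; the actual split axes for each layer are defined by the model's tensor_parallel_config:

```python
import torch

# Toy sketch of the gathering that save_tensor_parallel performs:
# a weight that was split row-wise across two GPUs is concatenated
# back into the full matrix while the state_dict is being built.
full_weight = torch.arange(12.0).reshape(4, 3)   # pretend original weight
shard_0, shard_1 = full_weight.chunk(2, dim=0)   # the halves living on cuda:0 / cuda:1
gathered = torch.cat([shard_0, shard_1], dim=0)  # what lands in the saved state_dict
assert torch.equal(gathered, full_weight)        # identical to the unsplit weight
```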

Memory efficient dispatch

Normally, creating and dispatching a tensor_parallel model requires the whole model to be in memory. This can be troublesome, but there is another way.

It's possible to convert a regular model's state_dict into the corresponding tensor_parallel state_dict with the helper function convert_state_dict. The converted state dict can then be dispatched and loaded into the model:

import accelerate
import torch
import transformers

import tensor_parallel as tp

# Initialize a weightless tensor_parallel model from MyModel
with accelerate.init_empty_weights():
    model = tp.TensorParallel(
        MyModel(),
        device_ids=[0, 1] # and prepare it to be put on GPUs 0 and 1
    )

# Load partial state_dict for MyModel
state_dict = torch.load("my_model_part_1_of_5.bin")

# Convert it into a tensor_parallel state_dict
tensor_parallel_state_dict = tp.convert_state_dict(
    state_dict,
    tensor_parallel_config=model.tensor_parallel_config,
    world_size=len(model.devices),
)

# Dispatch the converted state_dict (load_state_dict doesn't work with meta tensors, so we use accelerate)
device_map = tp.infer_sharded_device_map(model)
for param_name, param in tensor_parallel_state_dict.items():
    module_name = param_name
    while len(module_name) > 0 and module_name not in device_map:
        module_name = ".".join(module_name.split(".")[:-1])
    param_device = device_map[module_name]
    accelerate.utils.set_module_tensor_to_device(model, param_name, param_device, value=param)

With this approach, no more than one part of the model needs to be loaded into memory at a time.
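The prefix-matching lookup in the dispatch loop above can be isolated into a small helper for clarity. This is a standalone sketch: the device_map contents below are made up, whereas real ones come from tp.infer_sharded_device_map(model):

```python
# Standalone sketch of the prefix lookup used in the dispatch loop:
# a parameter name is shortened module-by-module until it matches a
# key in the device map, which tells us where to place the tensor.
def resolve_device(param_name, device_map):
    module_name = param_name
    while len(module_name) > 0 and module_name not in device_map:
        module_name = ".".join(module_name.split(".")[:-1])
    return device_map[module_name]

# Hypothetical device map for illustration only
device_map = {"decoder.layers.0": "cuda:0", "decoder.layers.1": "cuda:1"}
print(resolve_device("decoder.layers.0.self_attn.q_proj.weight", device_map))  # cuda:0
print(resolve_device("decoder.layers.1.fc1.bias", device_map))                 # cuda:1
```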

FAQ

Why use tensor_parallel ...

In short, use tensor_parallel for quick prototyping on a single machine. Use DeepSpeed+Megatron or alpa for million-dollar training runs.

Troubleshooting

If you experience NCCL errors or random hangs, there may be underlying code errors that are not being displayed properly. To debug them, we recommend re-running with export TENSOR_PARALLEL_USE_NATIVE=1 or on a single device.
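Setting the variable for a debugging session might look like this (a sketch; `your_script.py` is a placeholder for your own launch command):

```shell
# Re-run with the flag suggested above so hidden errors surface properly
export TENSOR_PARALLEL_USE_NATIVE=1
echo "TENSOR_PARALLEL_USE_NATIVE=$TENSOR_PARALLEL_USE_NATIVE"
# then launch from this same shell, e.g.: python your_script.py
```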

If you find a bug or encounter a problem, please report it to our issue tracker. We will do our best to help, but it may take some time before we get to it. Please open an issue only if your problem is specifically with tensor_parallel; for example, if you need help installing transformers or optimizing your code, please seek it elsewhere.

Code style

We use black and isort for all pull requests. Before committing your code, simply run black . && isort . and you will be fine.