pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

[RFC]: Torchserve Large Model Inference #1579

Open HamidShojanazeri opened 2 years ago

HamidShojanazeri commented 2 years ago

Authors : Hamid Shojanazeri, Shen Li

Problem statement

Currently, Torchserve does not have a general solution for serving large models for inference. The only available support is the HuggingFace (HF) example for serving GPT2-style models using the parallelize feature from HF. Serving large models relies on model parallel solutions: the model is partitioned and placed on different devices, and the inference input passes through these partitions to complete the forward pass. You can read more about it here.
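For intuition, this is a minimal sketch of what such a placement looks like in plain PyTorch; the two-stage split, layer sizes, and device names below are illustrative assumptions, not part of the proposal:

import torch
import torch.nn as nn

# Illustrative two-way model parallel split: each stage lives on its own GPU
# and activations move between devices as the forward pass proceeds.
class TwoStageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage1 = nn.Linear(1024, 1024).to("cuda:0")
        self.stage2 = nn.Linear(1024, 1024).to("cuda:1")

    def forward(self, x):
        x = self.stage1(x.to("cuda:0"))
        return self.stage2(x.to("cuda:1"))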

Asking users to manually handle model partitioning imposes a lot of complexity on users and complicates the Torchserve backend/frontend configs in terms of number of workers and GPU assignments. Ideally, we need an auto partitioner that takes a checkpoint from the user, partitions it, and places it on the available devices.

User API

The user API would be similar to the following; it provides the flexibility to be used with custom handlers or to integrate into base_handler (a custom handler sketch follows the snippet below).

# in config.properties
parallelize = true
number_of_gpu = x

# in handler
from XXX import auto_partitioner

model = auto_partitioner("model_chkpoints", number_of_gpus=x)
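As a rough sketch of how this could plug into a custom handler built on base_handler; the XXX import path is the placeholder from this RFC (no such module exists yet), and the property lookups are assumptions about the final integration:

import torch
from ts.torch_handler.base_handler import BaseHandler
# Placeholder import taken from this RFC; the real module path is not decided.
from XXX import auto_partitioner

class LargeModelHandler(BaseHandler):
    def initialize(self, context):
        # Read the checkpoint location from the frontend-provided properties and
        # let the proposed auto partitioner shard it across the visible GPUs.
        model_dir = context.system_properties.get("model_dir")
        self.model = auto_partitioner(model_dir,
                                      number_of_gpus=torch.cuda.device_count())
        self.model.eval()
        self.initialized = True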

When would an Auto partitioner be available?

The Pytorch Distributed team is working on this feature; the ETA is around July 2022.

Available solutions

So far, most serving solutions for large-scale models are tailored to a specific architecture, which means that to partition a model across multiple devices, the full architecture of the model needs to be known beforehand. Examples of these solutions can be found in Triton for FasterTransformer models, and in HuggingFace for GPT2 and T5.

Also, approaches like DeepSpeed have users provide partial checkpoints, and the library takes care of placing them on multiple devices.

Manual model parallelization, either by loading partial checkpoints or by parallelizing a specific model through its known configs, can already be done in a Torchserve custom handler, as we will show later in this doc.

Missing solution

An auto partitioner is not available in Pytorch yet. Ideally, we need a solution where users bring an arbitrary checkpoint and the serving solution can load the model, automatically partition it, and place it across different devices.

An auto partitioner is supported in the Sagemaker model parallel library; however, it has so far been used for training, and an inference solution is not yet available. This feature does not exist in Pytorch core at the moment. Not having visibility into the sub-modules of the user-defined model is the main issue.

There are two potential solutions to this problem

The Pytorch distributed team is working to support this feature for Torchserve.

How does the Sagemaker model parallel handle auto partitioning?

During the first training step, the model parallel library internally runs a tracing step that is meant to construct the model graph and determine the tensor and parameter shapes. After this tracing step, the library constructs a tree that consists of the nested nn.Module objects in the model, as well as additional data gathered from tracing, such as the number of stored nn.Parameters and the execution time for each nn.Module.

Next, the library traverses this tree from the root and runs a partitioning algorithm that assigns each nn.Module to a device, which balances computational load (measured by module execution time) and memory use (measured by the total stored nn.Parameter size and activations). If multiple nn.Modules share the same nn.Parameter, then these modules are placed on the same device to avoid maintaining multiple versions of the same parameter. Once the partitioning decision is made, the assigned modules and weights are loaded to their devices.
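For intuition, here is a toy sketch of that kind of greedy placement over the module tree, using only parameter memory as the cost signal; execution time, activation sizes, and the shared-parameter colocation described above are left out, and all names here are illustrative:

import torch.nn as nn

def param_bytes(module):
    # Count only the parameters owned directly by this module.
    return sum(p.numel() * p.element_size() for p in module.parameters(recurse=False))

def greedy_partition(model, num_devices):
    # Assign each leaf nn.Module to the currently least-loaded device.
    # A real partitioner would also use traced execution time, activation
    # sizes, and would colocate modules that share nn.Parameters.
    load = [0] * num_devices
    assignment = {}
    for name, module in model.named_modules():
        if list(module.children()):   # skip containers, place only leaves
            continue
        device = min(range(num_devices), key=lambda d: load[d])
        load[device] += param_bytes(module)
        assignment[name] = device
    return assignment

For example, greedy_partition(model, num_devices=4) returns a mapping from leaf-module names to device indices, which could then drive the actual .to("cuda:<idx>") placement.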

How does Torchserve work?

Torchserve is a model serving library that uses REST APIs to handle the connection between client and server. The frontend is implemented in Java and the backend is in Python.

The backend is responsible for initializing/loading the model, running inference, and preparing the response (a minimal handler sketch is shown under Device assignment below).

Device assignment:
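Roughly how device assignment works with the current base handler: the frontend hands each worker a gpu_id through context.system_properties and the handler pins the whole model to that single device. A minimal sketch, in which the model.pt file name and the shape of the pre/post-processing are assumptions:

import os
import torch
from ts.torch_handler.base_handler import BaseHandler

class MinimalHandler(BaseHandler):
    def initialize(self, context):
        properties = context.system_properties
        gpu_id = properties.get("gpu_id")
        # Single-device assignment: the entire model goes on the worker's GPU.
        self.device = torch.device(
            f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id is not None else "cpu")
        model_path = os.path.join(properties.get("model_dir"), "model.pt")  # assumed file name
        self.model = torch.load(model_path, map_location=self.device)
        self.model.eval()
        self.initialized = True

    def inference(self, data):
        # `data` is assumed to be a batched tensor produced by preprocess().
        with torch.no_grad():
            return self.model(data.to(self.device))

    def postprocess(self, output):
        # Torchserve expects one response entry per request in the batch.
        return output.tolist()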

Possible scenarios for large model inference in Torchserve

Here, we list all the scenarios for serving large models on Torchserve. With large model inference, we are targeting models that would not fit on one GPU. Hence, the model parallel paradigm is what we are looking for in this context.

  1. Load partial checkpoints on the available devices (similar to [DeepSpeed](https://github.com/microsoft/DeepSpeed/blob/b6f0ac97ae03e8bc71f75991eb4a8a7f28d1fd9b/deepspeed/inference/engine.py#L36) , [load_state_dict](https://github.com/microsoft/DeepSpeed/blob/b6f0ac97ae03e8bc71f75991eb4a8a7f28d1fd9b/deepspeed/runtime/state_dict_factory.py#L117)):
import torch

def load_partial_checkpoints(ckpt_list, world_size, rank):
    # Each rank loads its assigned checkpoint shard onto its own GPU.
    num_ckpt = len(ckpt_list)
    assert world_size % num_ckpt == 0, 'Invalid checkpoints and world size for sd split'
    ckpt_index = rank % num_ckpt
    return torch.load(ckpt_list[ckpt_index],
                      map_location=f'cuda:{rank}')

Pros

Cons

  2. Model-specific support (similar to Triton, HF), where model parallelism is implemented using the known architecture of the model:
from transformers import AutoModelForCausalLM
# known configs for the model, such as an HF model
model = AutoModelForCausalLM.from_pretrained(model_dir)
model.parallelize()

Pros

Cons

  3. Auto model partitioner (auto partitioner) that would take a full checkpoint, shard it, and load it on the assigned devices:
# in config.properties
parallelize = true
number_of_gpu = x
# in handler 
from XXX import auto_partitioner
model = auto_partitioner("model_chkpoints", number_of_gpus=x)

Pros

Cons

Desired general solution

An auto model partitioner that supports arbitrary checkpoints, from discussions with

Performance considerations

Open questions

  1. Should Torchserve support multi-model serving when model parallel inference is in use? It might still be possible if not all GPUs are used for parallelizing inference.
  2. How should batch scheduling be done, and is pipeline parallelism necessary? (Measure the performance hits.) Pipelining can be useful when dealing with a large batch of inputs, which might be more suitable for non-real-time applications; see the sketch below.
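To make the pipeline question concrete, here is a toy micro-batching sketch over two hypothetical model stages already placed on separate GPUs; stage1, stage2, and the device names are assumptions, not an existing Torchserve feature:

import torch

def pipelined_forward(stage1, stage2, batch, num_microbatches=4):
    # stage1 lives on cuda:0 and stage2 on cuda:1; splitting the batch lets the
    # two GPUs work on different micro-batches, since CUDA work is queued
    # asynchronously from the host.
    outputs = []
    for micro in torch.chunk(batch, num_microbatches):
        hidden = stage1(micro.to("cuda:0"))
        outputs.append(stage2(hidden.to("cuda:1")))
    return torch.cat(outputs)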

Acknowledgements

We would like to thank @msaroufim, @chauhang, @lxning, @jamesr66a, @pbelevich, @kwen2501 and @cbalioglu for their great support, insightful inputs and comments in authoring this RFC.

rahul003 commented 2 years ago

Is there also any proposed standardization of APIs around partial checkpoints?

msaroufim commented 2 years ago

Is there also any proposed standardization of APIs around partial checkpoints?

@yifuwang @ananthsub @HamidShojanazeri can comment more