Open HamidShojanazeri opened 2 years ago
Is there also any proposed standardization of APIs around partial checkpoints?
@yifuwang @ananthsub @HamidShojanazeri can comment more
Authors: Hamid Shojanazeri, Shen Li
Problem statement
Currently, TorchServe does not have a general solution for serving large models for inference. The only available support is a HuggingFace (HF) example for serving GPT2-style models using HF's parallelize feature. Serving large models relies on model parallel solutions: the model is partitioned and placed on different devices, and the inference input passes through these partitions until the forward pass completes. You can read more about it here.
Asking users to manually handle model partitioning imposes a lot of complexity on them and complicates the TorchServe backend/frontend configs in terms of the number of workers and GPU assignments. Ideally, we need an auto partitioner that takes a checkpoint from the user, partitions it, and places it on the available devices.
User API
The user API would be similar to the following; it provides the flexibility to be used with custom handlers or to integrate into the base_handler.
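Below is a minimal sketch of what such an API could look like inside a custom handler. `auto_partition` is an assumed name for the future auto partitioner (it does not exist in PyTorch or TorchServe today), so a trivial stand-in is defined purely to keep the sketch self-contained; the checkpoint file name is illustrative.

```python
# Hypothetical sketch of the proposed user API inside a custom handler.
# `auto_partition` is an assumed name for the future auto partitioner; a
# trivial stand-in is defined here only so the sketch is self-contained.
import torch
from ts.torch_handler.base_handler import BaseHandler


def auto_partition(model, devices):
    # Stand-in: naively place the whole model on the first device. The real
    # feature would split sub-modules across `devices` and insert the needed
    # cross-device transfers in the forward pass.
    return model.to(devices[0])


class LargeModelHandler(BaseHandler):
    def initialize(self, ctx):
        model_dir = ctx.system_properties.get("model_dir")
        # Load the user's arbitrary checkpoint on CPU first to avoid
        # exhausting a single GPU; "model.pt" is an illustrative file name.
        model = torch.load(f"{model_dir}/model.pt", map_location="cpu")
        devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        self.model = auto_partition(model, devices)  # proposed API call
        self.model.eval()
        self.initialized = True
```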
When will an auto partitioner be available?
The PyTorch Distributed team is working on this feature; the ETA is around July 2022.
Available solutions
So far, most serving solutions for large-scale models are tailored to a specific architecture: to partition a model across multiple devices, the full architecture of the model needs to be known beforehand. Examples of these solutions can be found in Triton for FasterTransformer models and in HuggingFace for GPT2 and T5.
Approaches like DeepSpeed instead ask users to provide partial checkpoints, and the library takes care of placing them on multiple devices.
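For reference, a rough sketch of this DeepSpeed-style flow is shown below, assuming `deepspeed.init_inference`; the model name, `mp_size`, and keyword arguments are illustrative and vary across DeepSpeed versions.

```python
# Rough sketch of the DeepSpeed approach, assuming `deepspeed.init_inference`;
# model name, mp_size, and arguments are illustrative and version-dependent.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# DeepSpeed shards the model across the GPUs in the tensor-parallel group and
# takes care of placing the checkpoint shards itself; a JSON describing partial
# checkpoints can also be passed via the `checkpoint` argument.
engine = deepspeed.init_inference(
    model,
    mp_size=2,                       # number of GPUs to shard across
    dtype=torch.half,
    replace_method="auto",
    replace_with_kernel_inject=True,
)
model = engine.module
```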
Manual model parallelization, either by loading partial checkpoints or by parallelizing a specific model through its known configs, can already be done in a TorchServe custom handler, as we will show later in this doc.
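As a concrete illustration of this manual approach, the HF parallelize path for GPT2 can be driven (for example, from a handler's initialize()) roughly as follows; the model name and device map are illustrative.

```python
# Sketch of the manual HF `parallelize()` approach for GPT2; model name and
# device map are illustrative (gpt2-large has 36 transformer blocks).
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2-large")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")

# Manually spread the transformer blocks across two GPUs.
device_map = {
    0: list(range(0, 18)),   # blocks 0-17 on cuda:0
    1: list(range(18, 36)),  # blocks 18-35 on cuda:1
}
model.parallelize(device_map)  # HF-specific; only works for supported models

inputs = tokenizer("TorchServe large model inference", return_tensors="pt").to("cuda:0")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```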
Missing solution
An auto partitioner is not available in PyTorch yet. Ideally, we need a solution where users bring an arbitrary checkpoint and the serving solution loads the model, automatically partitions it, and places it across the available devices.
Auto partitioning is supported in the SageMaker model parallel library; however, it has been used for training, and an inference solution is not yet available. This feature does not exist in PyTorch core at the moment. The main issue is not having visibility into the sub-modules of the user-defined model.
There are two potential solutions to this problem:
The PyTorch Distributed team is working to support this feature for TorchServe.
How does the SageMaker model parallel library handle auto partitioning?
During the first training step, the model parallel library internally runs a tracing step that constructs the model graph and determines the tensor and parameter shapes. After this tracing step, the library builds a tree consisting of the nested nn.Module objects in the model, together with additional data gathered from tracing, such as the amount of stored nn.Parameter data and the execution time of each nn.Module.
Next, the library traverses this tree from the root and runs a partitioning algorithm that assigns each nn.Module to a device, which balances computational load (measured by module execution time) and memory use (measured by the total stored nn.Parameter size and activations). If multiple nn.Modules share the same nn.Parameter, then these modules are placed on the same device to avoid maintaining multiple versions of the same parameter. Once the partitioning decision is made, the assigned modules and weights are loaded to their devices.
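The following toy sketch illustrates the kind of cost-balancing assignment described above; it is not SageMaker's implementation, and parameter size stands in for the memory/compute costs that would come from the tracing step.

```python
# Toy sketch of a cost-balancing traversal; NOT SageMaker's implementation.
# Parameter size stands in for the costs gathered during tracing.
import torch.nn as nn


def partition_by_cost(model: nn.Module, num_devices: int):
    """Greedily assign top-level children to the currently least-loaded device."""
    loads = [0] * num_devices          # accumulated cost per device
    assignment = {}                    # module name -> device index

    for name, child in model.named_children():
        cost = sum(p.numel() for p in child.parameters())
        device = min(range(num_devices), key=lambda d: loads[d])
        loads[device] += cost
        assignment[name] = device

    return assignment


# Example: split a small stack of linear layers across 2 devices.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])
print(partition_by_cost(model, num_devices=2))
```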
How does TorchServe work?
TorchServe is a model serving library that uses REST APIs to handle the connection between client and server. The frontend is implemented in Java and the backend in Python.
The backend is responsible for initializing/loading the model, running inference, and preparing the response.
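A minimal custom handler sketch of this backend request path is shown below; BaseHandler supplies the default initialize(), and the payload handling is illustrative only.

```python
# Minimal sketch of the backend request path in a TorchServe custom handler;
# BaseHandler provides initialize() (which sets self.model and self.device),
# and the payload handling below is an illustrative placeholder.
import torch
from ts.torch_handler.base_handler import BaseHandler


class SimpleHandler(BaseHandler):
    def preprocess(self, data):
        # The frontend hands the backend a batch of requests; each entry
        # carries the raw payload under "data" or "body".
        rows = [row.get("data") or row.get("body") for row in data]
        return torch.as_tensor(rows, dtype=torch.float32).to(self.device)

    def inference(self, inputs):
        with torch.no_grad():
            return self.model(inputs)

    def postprocess(self, outputs):
        # The response must be a list with one JSON-serializable entry per request.
        return outputs.cpu().tolist()
```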
Device assignment:
Possible scenarios for large model inference in TorchServe
Here, we list the scenarios for serving large models on TorchServe. With large model inference, we are targeting models that would not fit on one GPU; hence, the model parallel paradigm is what we are looking for in this context.
Desired general solution
An auto model partitioner that supports arbitrary checkpoints. From discussions so far, extracting the structure of the model from a checkpoint is not easy; it requires: 1) loading the model, 2) using torch.fx/TorchScript to trace it, 3) partitioning it into stages, and 4) moving the parameters to the target devices and inserting D2D comm ops at the stage boundaries.
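A rough sketch of steps 2) and 3) using torch.fx is shown below; the toy model and the naive half/half stage assignment are illustrative, and step 4) (moving parameters to devices and inserting D2D comm ops) is omitted.

```python
# Rough sketch of steps 2) and 3): trace a model with torch.fx and split it
# into stages. The toy model and the half/half stage split are illustrative;
# moving params to devices and inserting D2D comm ops (step 4) is omitted.
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU())
traced = symbolic_trace(model)

# Naively assign the first half of the call nodes to stage 0, the rest to stage 1.
call_nodes = [n for n in traced.graph.nodes if n.op == "call_module"]
stage_of = {n: (0 if i < len(call_nodes) // 2 else 1) for i, n in enumerate(call_nodes)}

split = split_module(traced, model, lambda node: stage_of.get(node, 0))
print(split)  # contains submod_0 / submod_1, one per stage
```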
Performance considerations
Open questions
Acknowledgements
We would like to thank @msaroufim, @chauhang, @lxning, @jamesr66a, @pbelevich, @kwen2501, and @cbalioglu for their great support, insightful inputs, and comments in authoring this RFC.