pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

[RFC] TorchServe Open Platform for Large Distributed Model Inference #2188

Closed lxning closed 1 year ago

lxning commented 1 year ago

### 🚀 The feature

TorchServe Open Platform for Large Distributed Model Inference

Authors: Li Ning, Hamid Shojanazeri, Ke Wen

As the size of machine learning models grows, distributed inference becomes increasingly necessary. To address this need, we previously published an RFC calling for a PyTorch native solution. That solution has been developed as PiPPy (Pipeline Parallelism for PyTorch); please see the example for more details. It will be officially released soon. This design document proposes a new open platform for large distributed model inference in TorchServe. The platform will support popular libraries such as PyTorch native PiPPy, Microsoft DeepSpeed, and HuggingFace Accelerate, making TorchServe an even more powerful framework for serving PyTorch models.

### Goals

The primary goal of this project is to extend TorchServe's large model inference support to the PyTorch native solution (PiPPy) and other popular solutions. The platform should provide an easy-to-use interface for users and continue to support the existing handler APIs. The platform's other goals are captured by the requirements below.

### Requirements

#### Distributed inference support

The platform should support distributed inference using popular libraries such as PyTorch native PiPPy, Microsoft DeepSpeed, and HuggingFace Accelerate. This will require changes to the TorchServe core to support these libraries as plugins.

#### Multi-GPU and cluster support

The platform should provide seamless support for distributed inference across multiple GPU devices and clusters. To achieve this, changes will be required in the TorchServe core to manage multiple devices and clusters.

#### Existing handler API support

The platform should ensure backward compatibility with existing TorchServe handler APIs, providing users with the option to either utilize the default distributed inference handler or easily customize it to their specific needs.

#### Flexibility and extensibility

The platform should be flexible and extensible to allow for integration with other libraries and technologies. TorchServe fosters an inclusive and welcoming environment that encourages community contribution by providing an open platform for collaboration and development.

#### Consistent user experience

TorchServe ensures a consistent user experience for both large distributed model inference and non-distributed model inference. A single TorchServe instance can run one or more workers for a large distributed model and can host multiple models at once, mixing one or more large distributed models with one or more non-distributed models.

### User Experience Description

#### TorchServe model MAR file
* Model config YAML file (model-config.yaml)

```yaml
# torchrun parameters
torchrun:
  max_restarts: 3

# TS backend parameters
pippy:
  rpc_timeout: 1800
  pp_group_size: 4  # pipeline parallel size, tp_group_size = world_size / pp_group_size
```

TorchServe now uses a priority order to determine the final value of a model's parameters: the config.properties file has the lowest priority, followed by the model configuration YAML file, and the REST or gRPC model management API has the highest priority.
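
To make this precedence concrete, here is a minimal, hypothetical sketch; the function and parameter names are illustrative only and are not TorchServe internals.

```python
# Illustrative only: merge the three configuration sources so that later
# (higher-priority) sources override earlier ones. Names are hypothetical.
def resolve_model_params(config_properties: dict,
                         model_yaml_config: dict,
                         api_overrides: dict) -> dict:
    """Priority: config.properties < model-config.yaml < management API."""
    resolved: dict = {}
    for source in (config_properties, model_yaml_config, api_overrides):
        resolved.update(source)  # later sources win on key conflicts
    return resolved

params = resolve_model_params(
    config_properties={"batch_size": 1, "max_batch_delay": 100},
    model_yaml_config={"batch_size": 4},   # from model-config.yaml
    api_overrides={"batch_size": 8},       # from the REST/gRPC register call
)
print(params["batch_size"])  # 8 -- the management API value wins
```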
* Handler
The new platform offers base handlers for popular libraries, such as `pippy_pipeline` for PyTorch PiPPy pipeline parallelism and `pippy_pptp` for PyTorch PiPPy pipeline + tensor parallelism. Users can use these base handlers directly or customize them to suit their specific needs.
* torch-model-archiver
The process of creating a model archive (MAR) file with torch-model-archiver remains the same as before. However, users can now also provide the model configuration YAML file as input when generating the MAR file. For example:

```
torch-model-archiver --model-name bloom --version 1.0 --handler pippy_pipeline.py --extra-files model.zip,setup_config.json -r requirements.txt -c model-config.yaml
```

```
$ ll bloom.mar
-rw-rw-r-- 1 ubuntu ubuntu 3952930456 Mar 06 19:35 bloom.mar

$ cat MAR-INF/MANIFEST.json
{
  "createdOn": "06/03/2023 03:04:14",
  "runtime": "python",
  "model": {
    "modelName": "bloom",
    "handler": "pippy_pipeline.py",
    "modelVersion": "1.0",
    "requirementsFile": "requirements.txt",
    "configFile": "model-config.yaml"
  },
  "archiverVersion": "0.7.0"
}
```

Other popular distributed libraries such as DeepSpeed and Accelerate can be seamlessly integrated into TorchServe by listing them in the requirements.txt file, just as in the existing TorchServe workflow. TorchServe installs the listed libraries during model registration. For example:

```
$ cat requirements.txt
transformers
accelerate
```

#### Model management and Inference API
The management and inference APIs for large distributed models remain unchanged.
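
As an illustration of that unchanged experience, the sketch below registers and queries the bloom model from the example above through plain HTTP calls against TorchServe's default management (8081) and inference (8080) ports; the host, ports, and sample payload are assumptions for this example.

```python
# Minimal sketch: a large distributed model is registered and queried through
# the same management and inference APIs as any other model. Host, ports and
# the sample payload are assumptions based on TorchServe defaults.
import requests

# Register the MAR file from the model store and wait for the worker to start.
requests.post(
    "http://localhost:8081/models",
    params={"url": "bloom.mar", "initial_workers": 1, "synchronous": "true"},
).raise_for_status()

# Send an inference request exactly as for a non-distributed model.
resp = requests.post(
    "http://localhost:8080/predictions/bloom",
    data="My name is Mark and I".encode("utf-8"),
)
print(resp.status_code, resp.text)
```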
### Internal Design
#### TorchServe Model Archiver
The model archiver is enhanced to support a model config YAML file, which lets users define model parameters for various components, including TorchServe, torchrun, and popular libraries such as PyTorch PiPPy, Microsoft DeepSpeed, and HuggingFace Accelerate. This allows for greater flexibility in configuring and deploying large distributed models on the platform.
The torch-model-archiver command now accepts a model config YAML file, which is used to specify the model's runtime configuration parameters.

```
usage: torch-model-archiver [-h] --model-name MODEL_NAME
                            [--serialized-file SERIALIZED_FILE]
                            [--model-file MODEL_FILE] --handler HANDLER
                            [--extra-files EXTRA_FILES]
                            [--runtime {python,python3}]
                            [--export-path EXPORT_PATH]
                            [--archive-format {tgz,no-archive,default}] [-f]
                            -v VERSION [-r REQUIREMENTS_FILE] [-c CONFIG_FILE]
```

#### TorchServe Frontend
* TorchRun integration
The TorchServe frontend is extended to support torchrun, allowing it to launch a model's worker as multiple RPC jobs in the distribution cluster. TorchServe seamlessly leverages the functionality of Torch Distributed Elastic, including RPC job management, retries, metrics, and log collection, to provide a reliable and scalable distributed inference platform. A sample execution command from the log:

```
Worker cmdline: [torchrun, --nnodes=1, --nproc_per_node=4, --max_restarts=3, --log_dir=/tmp/torchelastic_ts, --rdzv_backend=c10d, --rdzv_endpoint=localhost:29500, --rdzv_id=bloom_29500, /opt/conda/envs/py38/lib/python3.8/site-packages/ts/model_service_worker.py, --sock-type, unix, --sock-name, /tmp/.ts.sock.29500, --metrics-config, /opt/conda/envs/py38/lib/python3.8/site-packages/ts/configs/metrics.yaml]
```

* Multi-node Distribution Cluster
TorchServe provides two modes for supporting multiple nodes in a distribution cluster. The static mode is used in SageMaker inference endpoints, where SageMaker provides the list of nodes as cluster parameters and broadcasts the model registration and inference requests to each node. TorchServe selects the first node in the list as the master node and port.
Alternatively, the dynamic mode is used in Kubernetes with an etcd installation. TorchServe provides a Kubernetes controller that dynamically picks the master node and updates the master node and port in etcd, allowing each worker to retrieve this information from etcd.
A more detailed explanation of this feature will be covered in a separate design document.
* GPU assignment
The TorchServe frontend checks the available GPU devices to determine whether there are enough to accommodate the requested number of workers. The frontend then sends the first GPU device index of a worker's device set to the backend. The backend automatically detects the full set of GPU IDs for this node and sets them as CUDA_VISIBLE_DEVICES for the cluster.
* Communication with backend
Previously in TorchServe, each worker thread in the frontend was mapped to a single Python worker process in the backend via 1:1 socket communication. For large distributed models, this has been extended so that one frontend worker thread communicates with a set of RPC jobs in the backend, with each job assigned a dedicated socket.
The frontend worker thread broadcasts the model loading request to all the RPC jobs in the backend and waits for their responses. For inference, in the case of pipeline parallelism the frontend sends the request only to the master node of the cluster and waits for its response, whereas in the case of tensor parallelism or tensor + pipeline parallelism the frontend broadcasts the inference request to all the jobs in the cluster and waits for their responses.
* Socket assignment for backend
To simplify port management for distributed jobs, TorchServe introduces a new environment variable, TS_INIT_DISTRIBUTION_PORT. It is defined in the config.properties file as 'initial_distribution_port' and specifies the starting port number for distribution jobs. Each worker of a model occupies 'x' ports on a single TorchServe instance, where 'x' equals the number of processes per node (nproc-per-node), and each distribution job uses one port. When a backend job starts up, it automatically calculates its listening port from the input socket name and its local rank: the socket number plus the local rank equals the listening port (see the sketch at the end of this section).
* Monitoring and logging
One of the features of the platform is its ability to collect logs and metrics from Torch Elastic and from each job within the distribution cluster. This provides valuable information on the performance and behavior of the distributed inference process, allowing for better monitoring and optimization.
By default, torch.distributed.elastic sets the log level to 'warning', but TorchServe changes this setting to 'info' for torch.distributed.elastic so that all TorchServe backend log information can be collected.
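
To make the socket assignment described above concrete, here is a minimal sketch of the port derivation for one backend job. It assumes torchrun's standard LOCAL_RANK environment variable; the helper name and the socket-name parsing are illustrative, not TorchServe code.

```python
# Illustrative sketch of "socket number plus local rank equals the listening
# port". LOCAL_RANK is set by torchrun for every launched process; the helper
# name and socket-name parsing are assumptions for this example.
import os

def derive_listening_port(sock_name: str) -> int:
    socket_number = int(sock_name.rsplit(".", 1)[-1])     # /tmp/.ts.sock.29500 -> 29500
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))   # provided by torchrun
    return socket_number + local_rank

# With --sock-name /tmp/.ts.sock.29500 and --nproc_per_node=4, the four backend
# jobs would listen on ports 29500, 29501, 29502 and 29503.
print(derive_listening_port("/tmp/.ts.sock.29500"))
```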
#### TorchServe Backend
* Model Config Loading
The TorchServe backend is enhanced to parse the model configuration YAML file contained in the MAR file. This enables TorchServe to extract the configuration parameters required by TorchServe itself, torchrun, and distributed libraries such as PyTorch PiPPy, Microsoft DeepSpeed, and HuggingFace Accelerate. All of these parameters can be retrieved through the 'model_yaml_config' property of the context; for example, the PiPPy RPC timeout can be accessed as `ctx.model_yaml_config["pippy"]["rpc_timeout"]` (see the handler sketch at the end of this section).

* Default Distributed Model Handler
TorchServe provides default handlers for popular libraries while still maintaining the existing handler API. Users can use these handlers directly or customize them to their specific needs.
Existing handler APIs:
1. Model initializer: `def initialize(self, context)`
2. Inference: `def handle(self, data, context)`

A new directory, ts/torch_handler/distributed, is created to store all the base handlers for the PyTorch native PiPPy library and other popular libraries. For example:

```
ts/torch_handler
├── __init__.py
├── base_handler.py
├── contractions.py
├── densenet_handler.py
├── distributed
│   ├── base_accelerate.py
│   ├── base_deepspeed.py
│   ├── base_pippy_pipeline.py
│   └── base_pippy_pptp.py
├── image_classifier.py
├── image_segmenter.py
├── object_detector.py
```


These handlers implement `def initialize(self, context)`, which retrieves the user-defined parameters for the libraries (i.e. from the model config YAML file), initializes the libraries, and deploys the model to multiple devices (e.g. GPUs). Users can customize the pre- and post-processing functions in the existing style.
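
As a sketch of how such a customization could look, the example below subclasses a hypothetical base class from the proposed distributed handler directory. The base class name and its behavior are assumptions for illustration; only the handler signatures and the `ctx.model_yaml_config` access come from this RFC.

```python
# Hypothetical sketch of a customized distributed handler. Only the handler
# signatures and ctx.model_yaml_config are defined by this RFC; the base class
# and its attributes are assumed for illustration.
from ts.torch_handler.distributed.base_pippy_pipeline import (  # hypothetical module path
    BasePippyPipelineHandler,
)


class CustomBloomHandler(BasePippyPipelineHandler):
    def initialize(self, context):
        # Read user-defined parameters from model-config.yaml.
        pippy_cfg = context.model_yaml_config["pippy"]
        self.rpc_timeout = pippy_cfg["rpc_timeout"]        # e.g. 1800
        self.pp_group_size = pippy_cfg["pp_group_size"]    # e.g. 4
        # Let the base handler initialize PiPPy and place the model on devices.
        super().initialize(context)

    def preprocess(self, data):
        # Custom pre-processing in the existing handler style.
        return [row.get("data") or row.get("body") for row in data]

    def postprocess(self, inference_output):
        # Custom post-processing in the existing handler style.
        return [str(out) for out in inference_output]
```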

### Motivation, pitch

As the size of machine learning models grows, distributed inference becomes increasingly necessary. To address this need, we previously published an [RFC](https://github.com/pytorch/serve/issues/1579) calling for a PyTorch native solution. That solution has been developed as [PiPPy](https://github.com/pytorch/tau/tree/main/pippy) (Pipeline Parallelism for PyTorch); please see the [example](https://github.com/pytorch/tau/tree/main/examples/inference) for more details. It will be officially released soon. This design document proposes a new open platform for large distributed model inference in TorchServe. The platform will support popular libraries such as PyTorch native PiPPy, Microsoft DeepSpeed, and HuggingFace Accelerate, making TorchServe an even more powerful framework for serving PyTorch models.

### Alternatives

_No response_

### Additional context

- [x] #2192 
- [x] #2207 
chauhang commented 1 year ago

Great to see the progress on large model inference. A few questions:

1. For DeepSpeed and Accelerate, are there more details on how distributed inference will be handled / simplified in the new solution? TorchServe already has examples for these (HF large model and DeepSpeed MII Stable Diffusion). For distributed inference, is there anything else being proposed as part of this RFC?

2. What additional benefit is there from adding the torchrun integration? How will this work in the case of Kubernetes clusters, e.g. for the HF Bloom model using KServe?

lxning commented 1 year ago

@chauhang

1. For DeepSpeed and Accelerate, are there more details on how distributed inference will be handled / simplified in the new solution?

The new solution streamlines the process of initializing the distributed environment for each worker by pre-checking whether a host has enough GPUs to meet the model configuration requirements and by configuring a set of device IDs as CUDA_VISIBLE_DEVICES for each worker. This approach not only simplifies the stack and reduces overhead, but also makes debugging easier. For example:
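
A minimal, hypothetical sketch of that device-ID assignment (the function name and the contiguous partitioning scheme are assumptions for illustration):

```python
# Illustrative only: partition a host's GPUs across the workers of one
# distributed model and expose each worker's slice via CUDA_VISIBLE_DEVICES.
import os

def assign_devices(num_gpus_on_host: int, nproc_per_node: int, worker_index: int) -> str:
    if (worker_index + 1) * nproc_per_node > num_gpus_on_host:
        raise RuntimeError("Not enough GPUs on this host for the requested workers")
    start = worker_index * nproc_per_node
    return ",".join(str(i) for i in range(start, start + nproc_per_node))

# On an 8-GPU host serving a model configured with nproc-per-node=4:
print(assign_devices(8, 4, 0))  # worker 0 -> "0,1,2,3"
print(assign_devices(8, 4, 1))  # worker 1 -> "4,5,6,7"

# What worker 0's backend process group would see:
os.environ["CUDA_VISIBLE_DEVICES"] = assign_devices(8, 4, 0)
```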

2. What additional benefit is there from adding the torchrun integration? How will this work in the case of Kubernetes clusters?

The new solution calls torchrun in the frontend to launch a model's worker with multiple RPC jobs in the distribution cluster. It also paves the way to supporting multi-node distributed inference. There is no difference between running on a single KServe host and on an EC2 host. HF Accelerate is very similar to PiPPy: it calls Python multiprocessing if the distributed environment is not set up; otherwise it gets the rank of the current job (see accelerate.py and launcher.py).

msaroufim commented 1 year ago

Closing since this work was completed