Great to see the progress on large model inference. A few questions:
For DeepSpeed and Accelerate, are there more details on how distributed inference will be handled / simplified in the new solution? TorchServe already has examples for these: HF Large Model and DeepSpeed MII Stable Diffusion. For distributed inference, is anything else being proposed as part of this RFC?
What additional benefit does the torchrun integration add? How will this work for Kubernetes clusters, e.g. for the HF BLOOM model using KServe?
@chauhang
The new solution streamlines the initialization of the distributed environment for each worker: the frontend pre-checks whether a host has enough GPUs to meet the model's configuration requirements and assigns each worker its own set of device IDs via CUDA_VISIBLE_DEVICES. This approach not only simplifies the stack and reduces overhead, but also makes debugging easier.
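For example, the per-worker GPU assignment could be computed along the lines of the sketch below. This is a hypothetical Python illustration only; the actual logic lives in the TorchServe frontend, and the function and parameter names (assign_devices, worker_id, gpus_per_worker) are made up for this example.

```python
# Hypothetical sketch: give each worker a disjoint slice of the host's GPUs via
# CUDA_VISIBLE_DEVICES, after checking the host has enough GPUs for the model.
import os
import torch


def assign_devices(worker_id: int, gpus_per_worker: int) -> str:
    total_gpus = torch.cuda.device_count()
    start = worker_id * gpus_per_worker
    if start + gpus_per_worker > total_gpus:
        raise RuntimeError(
            f"Host has {total_gpus} GPUs, not enough for worker {worker_id} "
            f"which requires {gpus_per_worker}"
        )
    return ",".join(str(d) for d in range(start, start + gpus_per_worker))


# e.g. worker 1 with 4 GPUs per worker sees only GPUs "4,5,6,7"
os.environ["CUDA_VISIBLE_DEVICES"] = assign_devices(worker_id=1, gpus_per_worker=4)
```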
The new solution calls torchrun in the frontend to launch a model's worker with multiple RPC jobs in the distributed cluster. It also paves the way for multi-node distributed inference. There is no difference between running on a single KServe host and on an EC2 host. HF Accelerate is very similar to PiPPy: it calls Python multiprocessing if the distributed environment is not set up; otherwise it gets the rank of the current job (see accelerate.py and launcher.py).
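Roughly, that pattern looks like the sketch below. This is assumed logic for illustration only, not the actual accelerate.py or launcher.py source; run and inference_fn are placeholder names.

```python
# Sketch: reuse the distributed environment set up by torchrun if present,
# otherwise fall back to spawning local processes.
import os
import torch.distributed as dist
import torch.multiprocessing as mp


def run(inference_fn, world_size: int):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # torchrun (or another launcher) already set RANK/WORLD_SIZE/MASTER_ADDR/...
        dist.init_process_group(backend="nccl")
        inference_fn(dist.get_rank(), dist.get_world_size())
    else:
        # no launcher: start one process per rank ourselves
        mp.spawn(inference_fn, args=(world_size,), nprocs=world_size)
```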
Closing since this work was completed
🚀 The feature
TorchServe Open Platform for Large Distributed Model Inference
Authors: Li Ning, Hamid Shojanazeri, Ke Wen
As the size of machine learning models grows, distributed inference becomes increasingly necessary. To address this need, we previously published an RFC calling for a PyTorch-native solution. This has been developed as PiPPy (Pipeline Parallelism for PyTorch); please read more in the example. It will be officially released soon. This design document proposes the development of a new open platform for large distributed model inference in TorchServe. The platform will support popular libraries such as PyTorch-native PiPPy, Microsoft DeepSpeed, and HuggingFace Accelerate, making TorchServe an even more powerful framework for serving PyTorch models.
Goals
The primary goal of this project is to extend large model inference support in TorchServe to the PyTorch-native solution and other popular solutions. The platform should provide an easy-to-use interface for users and support existing handler APIs. The other goals of the platform are captured in the requirements below.
Requirements
Distributed inference support
The platform should support distributed inference using popular libraries such as PyTorch-native PiPPy, Microsoft DeepSpeed, and HuggingFace Accelerate. This will require changes to the TorchServe core to support these libraries as plugins.
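As a rough illustration of what a DeepSpeed-backed handler plugin might do when loading a model, consider the sketch below. It uses DeepSpeed's public inference API; the model name and mp_size are placeholders, and this is not the proposed TorchServe implementation.

```python
# Sketch: wrap a Hugging Face model with DeepSpeed tensor-parallel inference.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-7b1")  # placeholder model
model = deepspeed.init_inference(
    model,
    mp_size=4,                        # tensor-parallel degree across the worker's GPUs
    dtype=torch.float16,
    replace_with_kernel_inject=True,  # use DeepSpeed's optimized inference kernels
)
```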
Multi-GPU and cluster support
The platform should provide seamless support for distributed inference across multiple GPU devices and clusters. To achieve this, changes will be required in the TorchServe core to manage multiple devices and clusters.
Existing handler API support
The platform should ensure backward compatibility with existing TorchServe handler APIs, providing users with the option to either utilize the default distributed inference handler or easily customize it to their specific needs.
Flexibility and extensibility
TorchServe fosters an inclusive and welcoming environment that encourages community contribution by providing an open platform for collaboration and development. The platform should be flexible and extensible to allow for integration with other libraries and technologies.
Consistent user experience
TorchServe ensures a consistent user experience for both large distributed model inference and non-distributed model inference. With TorchServe, a single server can handle 1 or more workers for a large distributed model and can load multiple models, including 1 or more large distributed models and 1 or more non-distributed models.
User Experience Description
TorchServe model mar file
```yaml
# model-config.yaml
torchrun:
  max_restarts: 3

# TS backend parameters
pippy:
  rpc_timeout: 1800
  pp_group_size: 4  # pipeline parallel size, tp_group_size = world size / pp_group_size
```
```bash
torch-model-archiver --model-name bloom --version 1.0 --handler pippy_pipeline.py --extra-files model.zip,setup_config.json -r requirements.txt -c model-config.yaml
```
```bash
$ ll bloom.mar
-rw-rw-r-- 1 ubuntu ubuntu 3952930456 Mar 06 19:35 bloom.mar

$ cat MAR-INF/MANIFEST.json
{
  "createdOn": "06/03/2023 03:04:14",
  "runtime": "python",
  "model": {
    "modelName": "bloom",
    "handler": "pippy_pipeline.py",
    "modelVersion": "1.0",
    "requirementsFile": "requirements.txt",
    "configFile": "model-config.yaml"
  },
  "archiverVersion": "0.7.0"
}

$ cat requirements.txt
transformers
accelerate
```
```
usage: torch-model-archiver [-h] --model-name MODEL_NAME
                            [--serialized-file SERIALIZED_FILE]
                            [--model-file MODEL_FILE] --handler HANDLER
                            [--extra-files EXTRA_FILES]
                            [--runtime {python,python3}]
                            [--export-path EXPORT_PATH]
                            [--archive-format {tgz,no-archive,default}] [-f]
                            -v VERSION [-r REQUIREMENTS_FILE] [-c CONFIG_FILE]
```
Example worker launch command generated by the TorchServe frontend:

```
Worker cmdline: [torchrun, --nnodes=1, --nproc_per_node=4, --max_restarts=3, --log_dir=/tmp/torchelastic_ts, --rdzv_backend=c10d, --rdzv_endpoint=localhost:29500, --rdzv_id=bloom_29500, /opt/conda/envs/py38/lib/python3.8/site-packages/ts/model_service_worker.py, --sock-type, unix, --sock-name, /tmp/.ts.sock.29500, --metrics-config, /opt/conda/envs/py38/lib/python3.8/site-packages/ts/configs/metrics.yaml]
```
The handler can read these settings at runtime from the TorchServe context, e.g. ctx.model_yaml_config["pippy"]["rpc_timeout"].
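For illustration, a custom handler might pick up these values in its initialize method roughly as follows. This is a sketch only; the class name is hypothetical and it is not the shipped distributed base handler.

```python
# Sketch: a custom handler reading model-config.yaml values from the context.
from ts.torch_handler.base_handler import BaseHandler


class CustomPippyHandler(BaseHandler):
    def initialize(self, ctx):
        self.manifest = ctx.manifest
        # values come from the model-config.yaml packaged into the .mar file
        pippy_cfg = ctx.model_yaml_config.get("pippy", {})
        self.rpc_timeout = pippy_cfg.get("rpc_timeout", 1800)
        self.pp_group_size = pippy_cfg.get("pp_group_size", 4)
        self.initialized = True
```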
```
ts/torch_handler
├── __init__.py
├── base_handler.py
├── contractions.py
├── densenet_handler.py
├── distributed
│   ├── base_accelerate.py
│   ├── base_deepspeed.py
│   ├── base_pippy_pipeline.py
│   └── base_pippy_pptp.py
├── image_classifier.py
├── image_segmenter.py
├── object_detector.py
```