Yard1 / Ray-DeepSpeed-Inference


Ray-DeepSpeed-Inference

EXPERIMENTAL AND NOT PRODUCTION READY! Many rough edges.

Based on https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/huggingface/text-generation

How to run

Runs OPT-66b inference on a cluster of g4dn nodes (in my tests, 3 x g4dn.12xlarge with 4 GPUs each, giving a total of 12 GPUs). You can also run it on 12 x g4dn.4xlarge, which have 1 GPU each.

First, download the resharded checkpoint to ~/model on every node:

python run_on_every_node.py download_model "s3://large-dl-models-mirror/models--anyscale--opt-66b-resharded/main/" "~/model"

Start inference with a single worker group spanning all 12 GPUs:

python deepspeed_inference_actors.py --name "facebook/opt-66b" --checkpoint_path "~/model" --batch_size 1 --ds_inference --use_kernel --use_meta_tensor --num_worker_groups 1 --num_gpus_per_worker_group 12

How it works

This repository demonstrates how to use DeepSpeed Inference with Ray for scalable batch inference. The combination of these two tools allows for efficient generation of text with large language models, including models as large as OPT-66b.

DeepSpeed Inference utilizes automatic model parallelism to distribute the model across multiple GPUs. Ray handles the scheduling and orchestration of the workload.
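To make "model parallelism" concrete, here is a toy, pure-Python illustration of column-wise tensor parallelism for a single linear layer, assuming a Megatron-style split: each rank stores only a slice of the weight's output columns, computes its part of the output independently, and the slices are gathered back together. This is not DeepSpeed code; DeepSpeed Inference runs each shard on its own GPU with fused kernels and NCCL collectives, and real implementations typically pair this split with row-wise splits and an all-reduce for other layers.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain (n x k) @ (k x m) matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def split_columns(w: Matrix, world_size: int) -> List[Matrix]:
    """Give each rank an equal, contiguous slice of the weight's output columns."""
    m = len(w[0])
    if m % world_size != 0:
        raise ValueError("output dimension must be divisible by world_size")
    size = m // world_size
    return [[row[r * size:(r + 1) * size] for row in w] for r in range(world_size)]


def column_parallel_matmul(x: Matrix, w: Matrix, world_size: int) -> Matrix:
    shards = split_columns(w, world_size)      # rank r stores only shards[r]
    partials = [matmul(x, s) for s in shards]  # each rank would run this on its own GPU
    # Stand-in for an all-gather: stitch every rank's output columns back together.
    return [sum((p[i] for p in partials), []) for i in range(len(x))]


x = [[1, 2], [3, 4]]
w = [[1, 0, 2, 1], [0, 1, 1, 3]]
print(column_parallel_matmul(x, w, world_size=2))  # [[1, 2, 4, 7], [3, 4, 10, 15]]
```

The sharded result is identical to the unsharded `matmul(x, w)`, while each rank holds only `1/world_size` of the weights, which is what lets a model too large for one GPU fit across several.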

There are three key parts to the code:

  1. deepspeed_inference_actors.py (the entrypoint) generates a sample Ray Dataset and uses ray.train.batch_predictor.BatchPredictor with a custom DeepSpeedPredictor. The BatchPredictor spawns num_worker_groups DeepSpeedPredictor actors, each receiving a share of the data.
  2. deepspeed_predictor.py contains the code for the DeepSpeedPredictor. Each DeepSpeedPredictor actor spawns num_gpus_per_worker_group worker actors (PredictionWorker), connected together via a torch.distributed backend, as required by DeepSpeed. Once initialized, the DeepSpeed model is ready for prediction.
  3. deepspeed_utils.py contains code based on a DeepSpeed example that is used by PredictionWorkers.
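Before DeepSpeed can shard the model, the PredictionWorkers in a group have to join one torch.distributed process group. A minimal sketch, assuming the default `env://` initialization: the hypothetical helper `worker_envs` builds the standard rendezvous variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) plus LOCAL_RANK, which the DeepSpeed text-generation example reads from the environment. How this repository actually wires up its workers may differ.

```python
from typing import Dict, List


def worker_envs(master_addr: str, master_port: int,
                node_of_worker: List[str]) -> List[Dict[str, str]]:
    """Build torch.distributed env:// variables for every worker in one group.

    node_of_worker[i] is the node (e.g. its IP) that worker i runs on; the
    worker's position in the list becomes its global rank.
    """
    world_size = len(node_of_worker)
    next_local_rank: Dict[str, int] = {}
    envs = []
    for rank, node in enumerate(node_of_worker):
        local_rank = next_local_rank.get(node, 0)  # index of this worker on its node
        next_local_rank[node] = local_rank + 1
        envs.append({
            "MASTER_ADDR": master_addr,  # rank 0's address, shared by the whole group
            "MASTER_PORT": str(master_port),
            "RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_RANK": str(local_rank),
        })
    return envs


# Inside each worker (not executed here), roughly:
#   os.environ.update(env)
#   torch.distributed.init_process_group(backend="nccl")  # init_method defaults to env://
for env in worker_envs("10.0.0.1", 29500, ["10.0.0.1", "10.0.0.1", "10.0.0.2", "10.0.0.2"]):
    print(env["RANK"], env["LOCAL_RANK"], env["WORLD_SIZE"])
```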

In other words, a DeepSpeedPredictor creates a worker group of PredictionWorker actors, which together hold a single copy of the model. A worker group is inelastic: if one worker fails, the entire group fails. This is similar to how Ray Train works (in fact, the same logic could be implemented with Ray Train private APIs instead of PredictionWorker).
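The inelastic, all-or-nothing behaviour of a worker group can be sketched with threads standing in for Ray actors. This is a minimal sketch, assuming (as in DeepSpeed's tensor-parallel examples) that every rank receives the same batch and produces the same output, so rank 0's copy is kept; `run_group` and `WorkerGroupError` are hypothetical names, not this repository's API.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence


class WorkerGroupError(RuntimeError):
    """Raised when any worker in an inelastic group fails."""


def run_group(workers: Sequence[Callable[[List[str]], List[str]]],
              batch: List[str]) -> List[str]:
    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        # Every rank gets the same batch: the model is sharded, not the data.
        futures = [pool.submit(worker, batch) for worker in workers]
        results = []
        for rank, future in enumerate(futures):
            try:
                results.append(future.result())
            except Exception as exc:
                # One failure poisons the whole group; there is no partial result.
                raise WorkerGroupError(f"worker {rank} failed") from exc
    # All ranks compute the same output, so rank 0's copy is enough.
    return results[0]


def healthy(batch: List[str]) -> List[str]:
    return [prompt.upper() for prompt in batch]


def crashing(batch: List[str]) -> List[str]:
    raise RuntimeError("CUDA error")


print(run_group([healthy, healthy], ["hello"]))  # ['HELLO']
try:
    run_group([healthy, crashing], ["hello"])
except WorkerGroupError as err:
    print(err)  # worker 1 failed
```

Recovering from a failure therefore means restarting the whole group and reloading the model, rather than replacing a single worker.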

Known issues

  1. If multiple worker groups are scheduled on the same node, their workers will try to use the same CUDA devices, which leads to a crash. Therefore, either use single-GPU nodes, or make sure that the number of workers in a group divided by the number of nodes equals the number of GPUs per node.
  2. Certain models obtained from the Hugging Face Hub cause exceptions due to a bug in DeepSpeed. The workaround is to reshard the checkpoints of those models so that all layers are stored in contiguous files. The relevant code is included in huggingface_utils.py.
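Hugging Face sharded checkpoints ship an index file (for example pytorch_model.bin.index.json) whose `weight_map` maps every parameter name to the shard file that stores it. The sketch below is a hypothetical diagnostic, not the resharding code in huggingface_utils.py: assuming transformer blocks are named like `model.decoder.layers.<N>.` (as in OPT), it lists the layers whose parameters are spread over more than one shard, which is one plausible reading of a checkpoint whose layers are not stored contiguously.

```python
import json
import re
from collections import defaultdict
from typing import Dict, Set

# Assumed naming scheme: "<prefix>.layers.<N>.<param>", e.g. OPT's
# "model.decoder.layers.3.self_attn.k_proj.weight". Other architectures differ.
LAYER_RE = re.compile(r"^(.*\.layers\.\d+)\.")


def layers_split_across_files(weight_map: Dict[str, str]) -> Dict[str, Set[str]]:
    """Return {layer prefix: shard files} for layers stored in more than one file."""
    files_per_layer: Dict[str, Set[str]] = defaultdict(set)
    for param_name, shard_file in weight_map.items():
        match = LAYER_RE.match(param_name)
        if match:  # embeddings, final norms, etc. are ignored
            files_per_layer[match.group(1)].add(shard_file)
    return {layer: files for layer, files in files_per_layer.items() if len(files) > 1}


def check_index(index_path: str) -> Dict[str, Set[str]]:
    """Load a *.bin.index.json file and report layers split across shards."""
    with open(index_path) as f:
        index = json.load(f)
    return layers_split_across_files(index["weight_map"])
```

An empty result from `check_index("~/model/pytorch_model.bin.index.json")` (after expanding the path) would suggest every layer lives in a single shard.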

Environment

Key packages:

accelerate==0.17.1
deepspeed==0.8.3
ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
torch==2.0.0
transformers==4.27.2
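Because the key packages are pinned to exact versions, a quick check on each node can catch drift before a long model load. A minimal sketch using the standard library's `importlib.metadata`; the `KEY_PINS` dict simply copies the pins above (the Ray nightly wheel is left out because its dev version string does not identify a specific build), and `mismatched` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Key pins copied from this README (Ray nightly wheel intentionally omitted).
KEY_PINS = {
    "accelerate": "0.17.1",
    "deepspeed": "0.8.3",
    "torch": "2.0.0",
    "transformers": "4.27.2",
}


def mismatched(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs to (expected, installed)."""
    problems = {}
    for name, expected in pins.items():
        try:
            # Drop local version suffixes such as "+cu117" before comparing.
            installed: Optional[str] = version(name).split("+")[0]
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            problems[name] = (expected, installed)
    return problems


if __name__ == "__main__":
    for name, (want, got) in mismatched(KEY_PINS).items():
        print(f"{name}: expected {want}, found {got or 'not installed'}")
```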

All packages:

absl-py==1.4.0
accelerate==0.17.1
adal==1.2.7
aim==3.16.1
aim-ui==3.16.1
aimrecords==0.0.7
aimrocks==0.3.1
aiofiles==22.1.0
aiohttp==3.8.4
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
aiosqlite==0.18.0
ale-py==0.8.1
alembic==1.10.2
anyio==3.6.2
anyscale @ file:///home/ray/anyscale-0.0.0.dev0.tar.gz
anyscale-node-provider @ file:///home/ray/anyscale_node_provider-0.0.1.tar.gz
applicationinsights==0.11.10
argcomplete==1.12.3
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.2.0
autocfg==0.0.8
autogluon.common==0.7.0
autogluon.core==0.7.0
autograd==1.5
autopage==0.5.1
AutoROM==0.6.0
AutoROM.accept-rom-license==0.6.0
awscli==1.25.6
awscliv2==2.2.0
ax-platform==0.3.1
azure-cli-core==2.40.0
azure-cli-telemetry==1.0.8
azure-common==1.1.28
azure-core==1.26.3
azure-identity==1.10.0
azure-mgmt-compute==23.1.0
azure-mgmt-core==1.3.2
azure-mgmt-network==19.0.0
azure-mgmt-resource==20.0.0
Babel==2.12.1
backcall==0.2.0
backoff==1.10.0
backports.zoneinfo==0.2.1
base58==2.0.1
bayesian-optimization==1.2.0
bcrypt==4.0.1
beautifulsoup4==4.12.0
bitsandbytes==0.37.2
black==23.1.0
bleach==6.0.0
blessed==1.20.0
blobfile==2.0.1
boto3==1.26.95
botocore==1.29.95
botorch==0.8.3
cached-property==1.5.2
cachetools==5.3.0
catboost==1.1.1
certifi==2022.12.7
cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work
chardet==5.1.0
charset-normalizer==3.1.0
chess==1.7.0
chex==0.1.6
click==8.1.3
cliff==4.2.0
cloudpickle==2.2.1
cma==2.7.0
cmaes==0.9.1
cmake==3.26.0
cmd2==2.4.3
colorama==0.4.6
coloredlogs==15.0.1
colorful==0.5.5
colorlog==6.7.0
comet-ml==3.31.9
comm==0.1.2
commonmark==0.9.1
conda==23.1.0
conda-content-trust @ file:///tmp/abs_5952f1c8-355c-4855-ad2e-538535021ba5h26t22e5/croots/recipe/conda-content-trust_1658126371814/work
conda-package-handling @ file:///croot/conda-package-handling_1666940373510/work
configobj==5.0.8
ConfigSpace==0.4.18
contourpy==1.0.7
coolname==2.2.0
cryptography @ file:///croot/cryptography_1673298753778/work
cycler==0.11.0
Cython==0.29.32
databricks-cli==0.17.5
DataProperty==0.55.0
datasets==2.10.1
debugpy==1.6.6
decorator==5.1.1
decord==0.6.0
deepspeed==0.8.3
defusedxml==0.7.1
Deprecated==1.2.13
diffusers @ git+https://github.com/huggingface/diffusers.git@7fe88613fa15d230d59482889c440c7befa17c25
dill==0.3.6
distlib==0.3.6
dm-tree==0.1.8
docker==6.0.1
docker-pycreds==0.4.0
docutils==0.16
dopamine-rl==4.0.5
dragonfly-opt==0.1.6
dulwich==0.21.3
einops==0.3.0
entrypoints==0.4
etils==1.1.1
evaluate==0.4.0
everett==3.1.0
exceptiongroup==1.1.1
executing==1.2.0
executor==23.2
expiringdict==1.2.2
fastapi==0.95.0
fasteners==0.18
fastjsonschema==2.16.3
filelock==3.10.0
FLAML==1.1.1
Flask==2.2.3
flatbuffers==2.0.7
flax==0.6.7
fonttools==4.39.2
fqdn==1.5.1
freezegun==1.1.0
frozenlist==1.3.3
fsspec==2023.3.0
ftfy==6.1.1
future==0.18.3
gast==0.4.0
gin-config==0.5.0
gitdb==4.0.10
GitPython==3.1.31
glfw==2.5.7
gluoncv==0.10.1.post0
google-api-core==2.11.0
google-api-python-client==1.7.8
google-auth==2.16.2
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-cloud-compute==1.10.1
google-cloud-core==2.3.2
google-cloud-resource-manager==1.9.0
google-cloud-secret-manager==2.16.0
google-cloud-storage==2.7.0
google-crc32c==1.5.0
google-oauth==1.0.1
google-pasta==0.2.0
google-resumable-media==2.4.1
googleapis-common-protos==1.58.0
gpustat==1.0.0
GPy==1.10.0
gpytorch==1.9.1
graphviz==0.8.4
greenlet==2.0.2
grpc-google-iam-v1==0.12.6
grpcio==1.51.3
grpcio-status==1.48.2
grpcio-tools==1.51.3
gunicorn==20.1.0
gym==0.26.2
gym-notices==0.0.8
Gymnasium==0.26.3
gymnasium-notices==0.0.1
h11==0.14.0
h5py==3.7.0
halo==0.0.31
HEBO==0.3.2
higher==0.2.1
hjson==3.1.0
hpbandster==0.7.4
httplib2==0.21.0
huggingface-hub==0.13.3
humanfriendly==10.0
humanize==4.6.0
hyperopt==0.2.5
idna==3.4
imageio==2.26.1
imageio-ffmpeg==0.4.5
importlib-metadata==6.1.0
importlib-resources==5.12.0
iniconfig==2.0.0
ipykernel==6.22.0
ipython==8.11.0
ipython-genutils==0.2.0
ipywidgets==8.0.4
isodate==0.6.1
isoduration==20.11.0
isort==5.12.0
itsdangerous==2.1.2
jax==0.4.6
jaxlib==0.4.6
jedi==0.18.2
Jinja2==3.1.2
jmespath==0.10.0
joblib==1.2.0
json5==0.9.11
jsonlines==3.1.0
jsonpatch==1.32
jsonpointer==2.3
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter-ydoc==0.2.3
jupyter_client==8.1.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_fileid==0.8.0
jupyter_server_terminals==0.4.4
jupyter_server_ydoc==0.6.1
jupyterlab==3.6.1
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.5
jupyterlab_server==2.20.0
kaggle-environments==1.7.11
keras==2.11.0
kiwisolver==1.4.4
knack==0.10.1
kubernetes==26.1.0
lazy_loader==0.1
libclang==15.0.6.1
libtorrent==2.0.7
lightgbm==3.3.5
lightgbm-ray==0.1.8
lightning-bolts==0.4.0
lightning-utilities==0.8.0
linear-operator==0.3.0
lit==16.0.0
lm-dataformat @ git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
lm-eval==0.3.0
log-symbols==0.0.14
lxml==4.9.2
lz4==4.3.2
Mako==1.2.4
Markdown==3.4.1
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mbstrdecoder==1.1.2
mdurl==0.1.2
minigrid==2.1.1
mistune==2.0.5
mlagents-envs==0.28.0
mlflow==1.30.0
modin==0.18.1
monotonic==1.6
mosaicml==0.12.1
mpmath==1.3.0
msal==1.18.0b1
msal-extensions==1.0.0
msgpack==1.0.5
msrest==0.7.1
msrestazure==0.6.4
mujoco==2.2.0
mujoco-py==2.1.2.14
multidict==6.0.4
multipledispatch==0.6.0
multiprocess==0.70.14
mxnet==1.8.0.post0
mypy-extensions==1.0.0
nbclassic==0.5.3
nbclient==0.7.2
nbconvert==7.2.10
nbformat==5.8.0
nest-asyncio==1.5.6
netifaces==0.11.0
networkx==3.0
nevergrad==0.4.3.post7
ninja==1.11.1
nltk==3.8.1
notebook==6.5.3
notebook_shim==0.2.2
numexpr==2.8.4
numpy==1.23.5
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-ml-py==11.495.46
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
oauth2client==4.1.3
oauthlib==3.2.2
onnx==1.12.0
onnxruntime==1.14.1
open-spiel==1.2
openai==0.27.2
opencensus==0.11.2
opencensus-context==0.1.3
opencv-python==4.7.0.72
opentelemetry-api==1.1.0
opentelemetry-exporter-otlp==1.1.0
opentelemetry-exporter-otlp-proto-grpc==1.1.0
opentelemetry-exporter-otlp-proto-http==1.16.0
opentelemetry-proto==1.1.0
opentelemetry-sdk==1.1.0
opentelemetry-semantic-conventions==0.20b0
opt-einsum==3.3.0
optax==0.1.4
optuna==2.10.0
orbax==0.1.5
packaging==23.0
pandas==1.5.3
pandocfilters==1.5.0
paramiko==2.12.0
paramz==0.9.5
parso==0.8.3
pathspec==0.11.1
pathtools==0.1.2
pathvalidate==2.5.2
patsy==0.5.3
pbr==5.11.1
PettingZoo==1.22.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.4.0
pkginfo==1.9.6
pkgutil_resolve_name==1.3.10
platformdirs==3.1.1
plotly==5.13.1
pluggy @ file:///tmp/build/80754af9/pluggy_1648042571233/work
portalocker==2.7.0
prettytable==3.6.0
prometheus-client==0.13.1
prometheus-flask-exporter==0.22.3
promise==2.3
prompt-toolkit==3.0.38
property-manager==3.0
proto-plus==1.22.2
protobuf==3.20.3
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
py-spy==0.3.14
py3nvml==0.2.7
pyaml==21.10.1
pyarrow==11.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.6.2
pycosat @ file:///croot/pycosat_1666805502580/work
pycountry==22.3.5
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pycryptodomex==3.17
pydantic==1.10.6
pyDeprecate==0.3.2
pygame==2.1.2
pyglet==1.5.15
Pygments==2.14.0
PyJWT==2.6.0
pymoo==0.5.0
pymunk==6.2.1
PyNaCl==1.5.0
PyOpenGL==3.1.6
pyOpenSSL==23.0.0
pyparsing==3.0.9
pyperclip==1.8.2
pypng==0.20220715.0
pyro-api==0.1.2
pyro-ppl==1.8.4
Pyro4==4.82
pyrsistent==0.19.3
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
pytablewriter==0.64.2
pytest==7.2.2
pytest-remotedata==0.3.2
python-dateutil==2.8.2
python-json-logger==2.0.7
pytorch-lightning==2.0.0
pytorch-ranger==0.1.1
pytz==2022.7.1
pytz-deprecation-shim==0.1.0.post0
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==25.0.2
querystring-parser==1.2.4
ray @ file:///home/ray/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
ray-lightning==0.3.0
recsim==0.2.4
redis==3.5.3
regex==2022.10.31
requests==2.28.2
requests-oauthlib==1.3.1
requests-toolbelt==0.10.1
responses==0.18.0
RestrictedPython==6.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==12.0.1
rouge-score==0.1.2
rsa==4.9
ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work
ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work
s3transfer==0.6.0
sacrebleu==1.5.0
scikit-image==0.20.0
scikit-learn==1.2.2
scikit-optimize==0.9.0
scipy==1.10.1
segment-analytics-python==2.2.2
semantic-version==2.10.0
Send2Trash==1.8.0
sentencepiece==0.1.96
sentry-sdk==1.17.0
serpent==1.41
setproctitle==1.3.2
shortuuid==1.0.1
sigopt==7.5.0
six==1.16.0
smart-open==6.3.0
smmap==5.0.0
sniffio==1.3.0
soupsieve==2.4
spinners==0.0.24
SQLAlchemy==1.4.47
sqlitedict==2.1.0
sqlparse==0.4.3
stack-data==0.6.2
starlette==0.26.1
statsmodels==0.13.5
stevedore==5.0.0
SuperSuit==3.7.0
sympy==1.11.1
tabledata==1.3.1
tabulate==0.9.0
tblib==1.7.0
tcolorpy==0.1.2
tenacity==8.2.2
tensorboard==2.12.0
tensorboard-data-server==0.7.0
tensorboard-plugin-wit==1.8.1
tensorboardX==2.4.1
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.31.0
tensorflow-probability==0.19.0
tensorstore==0.1.33
termcolor==2.2.0
terminado==0.10.1
tf-slim==1.1.0
tf2onnx==1.13.0
threadpoolctl==3.1.0
tifffile==2023.3.15
tiktoken==0.1.2
timm==0.4.5
tinycss2==1.2.1
tinyscaler==1.2.5
tokenizers==0.13.2
tomli==2.0.1
toolz @ file:///croot/toolz_1667464077321/work
torch==2.0.0
torch-optimizer==0.3.0
torchaudio==2.0.1
torchmetrics==0.11.4
torchvision==0.15.1
tornado==6.2
tqdm==4.65.0
tqdm-multiprocess==0.0.11
traitlets==5.9.0
transformers==4.27.2
triton==2.0.0
tune-sklearn==0.4.4
typeguard==2.13.3
typepy==1.3.0
typer==0.6.1
typing_extensions==4.5.0
tzdata==2022.7
tzlocal==4.3
ujson==5.7.0
uri-template==1.2.0
uritemplate==3.0.1
urllib3==1.26.15
uvicorn==0.21.1
verboselogs==1.7
virtualenv==20.21.0
wandb==0.13.4
wcwidth==0.2.6
webcolors==1.12
webencodings==0.5.1
websocket-client==1.5.1
Werkzeug==2.2.3
widgetsnbextension==4.0.5
wrapt==1.15.0
wurlitzer==3.0.3
xgboost==1.7.4
xgboost-ray==0.1.15
xmltodict==0.13.0
xxhash==3.2.0
y-py==0.5.9
yacs==0.1.8
yarl==1.8.2
ypy-websocket==0.8.2
zipp==3.15.0
zoopt==0.4.1
zstandard==0.20.0