Open · Cerrix opened this issue 4 months ago
Same here with 2x A100 80GB:
# assumes model/tokenizer are already loaded (Transformer.from_folder / MistralTokenizer)
import torch
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_inference.generate import generate

model.eval()
completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
tokens = tokenizer.encode_chat_completion(completion_request).tokens
with torch.no_grad():
    out_tokens, _ = generate([tokens], model, max_tokens=1024 * 2, temperature=0.35,
                             eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
answer = tokenizer.decode(out_tokens[0])
where 'prompt' contains roughly 50k tokens, resulting in the same CUDA out-of-memory error. My environment:
Python 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] on linux
accelerate==0.32.1
adal==1.2.7
aiobotocore==2.13.1
aiohttp==3.9.5
aioitertools==0.11.0
aiosignal==1.3.1
annotated-types==0.7.0
argcomplete==3.4.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
attrs==23.2.0
azure-ai-ml==1.18.0
azure-common==1.1.28
azure-core==1.30.2
azure-graphrbac==0.61.1
azure-identity==1.17.1
azure-mgmt-authorization==4.0.0
azure-mgmt-containerregistry==10.3.0
azure-mgmt-core==1.4.0
azure-mgmt-keyvault==10.3.1
azure-mgmt-network==25.4.0
azure-mgmt-resource==23.1.1
azure-mgmt-storage==21.2.1
azure-storage-blob==12.21.0
azure-storage-file-datalake==12.16.0
azure-storage-file-share==12.17.0
azureml-core==1.56.0
azureml-dataprep==5.1.6
azureml-dataprep-native==41.0.0
azureml-dataprep-rslex==2.22.2
azureml-dataset-runtime==1.56.0
azureml-defaults==1.56.0.post1
azureml-inference-server-http==1.2.2
backports.tempfile==1.0
backports.weakref==1.0.post1
bcrypt==4.2.0
bitsandbytes==0.43.1
blinker==1.8.2
botocore==1.34.131
cachetools==5.4.0
certifi==2024.7.4
cffi==1.16.0
charset-normalizer==3.3.2
ciso8601==2.3.1
click==8.1.7
cloudpickle==2.2.1
colorama==0.4.6
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
contextlib2==21.6.0
contourpy==1.2.1
cryptography==43.0.0
cycler==0.12.1
darwin-rests==7.1.2+moreni
datasets==2.20.0
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1719378645730/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
devolab2igam==0.0.101
dill==0.3.8
docker==7.1.0
docstring_parser==0.16
docx2txt==0.8
ecb-certifi==4.3.0+moreni
einops==0.8.0
et-xmlfile==1.1.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1720869315914/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
fasteners==0.19
filelock==3.15.4
fire==0.6.0
flash_attn==2.6.1
Flask==2.3.2
Flask-Cors==3.0.10
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.1
fusepy==3.0.1
google-api-core==2.19.1
google-auth==2.32.0
googleapis-common-protos==1.63.2
gssapi==1.8.3
gunicorn==22.0.0
huggingface-hub==0.24.0
humanfriendly==10.0
idna==3.7
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1719361860083/work
inference-schema==1.7.2
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1719582526268/work
ipywidgets @ file:///home/conda/feedstock_root/build_artifacts/ipywidgets_1716897651763/work
isodate==0.6.1
itsdangerous==2.2.0
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
jeepney==0.8.0
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonpickle==3.2.2
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1716472197302/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257359434/work
jupyterlab_widgets @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_widgets_1716891641122/work
kiwisolver==1.4.5
knack==0.11.0
krb5==0.6.0
lxml==5.2.2
MarkupSafe==2.1.5
marshmallow==3.21.3
matplotlib==3.9.1
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
mistral_common==1.3.3
mistral_finetune==0.0.0
mistral_inference==1.3.1
mpmath==1.3.0
msal==1.30.0
msal-extensions==1.2.0
msrest==0.7.1
msrestazure==0.6.4
multidict==6.0.5
multiprocess==0.70.16
ndg-httpsclient==0.5.1
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
networkx==3.3
ninja==1.11.1.1
numpy==1.23.5
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
opencensus==0.11.4
opencensus-context==0.1.3
opencensus-ext-azure==1.1.13
opencensus-ext-logging==0.1.1
openpyxl==3.1.5
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
pandas==2.1.4
paramiko==3.4.0
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
pathspec==0.12.1
patsy==0.5.6
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
pillow==10.4.0
pkginfo==1.11.1
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
plotly==5.22.0
plotly-express==0.4.1
portalocker==2.10.1
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work
proto-plus==1.24.0
protobuf==5.27.2
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1719274586160/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1721585709575/work
pyarrow==17.0.0
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==2.8.2
pydantic-settings==2.3.4
pydantic_core==2.20.1
pydash==8.0.3
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
PyJWT==2.8.0
PyNaCl==1.5.0
pynvml==11.5.3
pyodbc==5.0.0
pyOpenSSL==24.2.1
pyparsing==3.1.2
PySocks==1.7.1
pyspnego==0.11.0
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
python-docx==1.1.2
python-dotenv==1.0.1
python-gitlab==4.8.0
python-slugify==8.0.4
pytz==2024.1
PyYAML==6.0.1
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1715024370414/work
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
requests-kerberos==0.15.0
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
rpds-py==0.19.0
rsa==4.9
s3fs==2024.6.1
safetensors==0.4.3
scikit-learn==1.5.1
scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy-split_1720323007424/work/dist/scipy-1.14.0-cp311-cp311-linux_x86_64.whl#sha256=1555805d3d22eadcd79d8bbf4de2865c7ad881feceb57d3c2d91ec2469d4acf7
SecretStorage==3.3.3
sentencepiece==0.2.0
simple_parsing==0.1.5
simply-rest==4.0
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
statsmodels==0.14.2
strictyaml==1.7.3
sympy==1.13.1
tabulate==0.9.0
tenacity==8.5.0
termcolor==2.4.0
text-unidecode==1.3
threadpoolctl==3.5.0
tiktoken==0.7.0
tokenizers==0.19.1
torch==2.3.1
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1717722848697/work
tqdm==4.66.4
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
transformers==4.42.4
triton==2.3.1
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work
tzdata==2024.1
tzlocal==5.2
urllib3==2.2.2
vl_connect==0.1.41
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
Werkzeug==3.0.3
widgetsnbextension @ file:///home/conda/feedstock_root/build_artifacts/widgetsnbextension_1716891659446/work
wrapt==1.16.0
xformers==0.0.27
xxhash==3.4.1
yarl==1.9.4
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1718013267051/work
Same issue on 4x NVIDIA A10: only one A10 is used and the remaining GPUs stay empty, yet an "Out of Memory" error occurs!
You might want to try the vLLM library. I used it to deploy the Mistral-NeMo model in a multi-node, multi-GPU setting. Reference: https://docs.mistral.ai/deployment/self-deployment/vllm/
I could be wrong, but I think vLLM also has CPU-offload capability for single-GPU setups.
It's slower than mistral-inference for obvious reasons, but it's better than nothing.
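For reference, a minimal sketch of that approach, assuming vLLM is installed; the model id, GPU count, and cpu_offload_gb value are placeholders to adapt:

from vllm import LLM, SamplingParams

# shard the weights across 2 GPUs; cpu_offload_gb is only needed when the
# weights still don't fit in VRAM (knob available in recent vLLM versions)
llm = LLM(
    model="mistralai/Mistral-Nemo-Instruct-2407",  # illustrative model id
    tensor_parallel_size=2,
    cpu_offload_gb=4,
)
sampling = SamplingParams(temperature=0.35, max_tokens=2048)
outputs = llm.generate(["<your prompt>"], sampling)
print(outputs[0].outputs[0].text)

Tensor parallelism splits every layer across the GPUs, so all cards stay busy on every token instead of one GPU holding the whole model.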
Hey @Cerrix, you need to load the model with pipeline parallelism enabled, e.g. see
https://github.com/mistralai/mistral-inference?tab=readme-ov-file#cli — specifically:
torchrun --nproc-per-node 2 (your script)
Also make sure to set up pipeline parallelism as shown here: https://github.com/mistralai/mistral-inference/blob/fffa5dac372280e5810d8008e54f70b1a5c40bde/src/mistral_inference/main.py#L124
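In script form, the gist is roughly the following sketch, adapted from the linked main.py; it assumes the script is launched via torchrun --nproc-per-node 2 and that the checkpoint path is adjusted:

import torch
import torch.distributed as dist
from mistral_inference.transformer import Transformer

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

model = Transformer.from_folder(
    "/path/to/mistral-nemo-instruct",          # adjust to your checkpoint folder
    num_pipeline_ranks=dist.get_world_size(),  # one pipeline stage per GPU
)

With num_pipeline_ranks equal to the torchrun world size, each rank holds only its slice of the layers instead of the full model.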
Python -VV
Pip Freeze
Reproduction Steps
Running the following code with a model such as Nemo Instruct (which cannot fit on a single GPU)
leads to the following error: "OutOfMemoryError: CUDA out of memory. Tried to allocate 140.00 MiB. GPU"
This is because, as you can see in the attached screenshot, the model is loaded onto a single GPU.
I looked for a parameter in the Transformer Python module, but I found nothing that enables multi-GPU inference.
Thank you so much
Expected Behavior
I would expect the model to be loaded onto multiple GPUs automatically, as in the screenshot.
Additional Context
No response
Suggested Solutions
I would recommend adding a parameter such as the device_map parameter of the Hugging Face Transformers library: https://huggingface.co/docs/transformers/main_classes/pipelines. Or distributing the model across the available GPUs automatically.
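For comparison, this is roughly what the Transformers behaviour being referred to looks like; the model id is illustrative, and device_map="auto" requires the accelerate package:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-Nemo-Instruct-2407",  # illustrative model id
    device_map="auto",  # shard layers across all visible GPUs automatically
)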