sanghyuk-choi opened this issue 7 months ago.
I am also facing a similar issue, but mine is with the default DDP strategy, not DeepSpeed.
Hey @isaacbmiller, can you post a repro? There are a few factors interacting here, and it would help to take a direct look.
There's a discussion here (although it's about FSDP, it still contains good insights) that describes what may be happening: https://github.com/TimDettmers/bitsandbytes/issues/89
That is, decisions taken under the hood at the bitsandbytes level may contribute.
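A workaround often suggested for this kind of device mismatch (an assumption here; it is not taken from the linked discussion) is to keep each process's copy of the quantized model entirely on that process's own GPU, instead of letting `device_map="auto"` spread it over every visible GPU. The sketch below only builds that mapping from the `LOCAL_RANK` environment variable; `single_device_map` is a hypothetical helper name, and the `from_pretrained` call in the comment is illustrative and not executed.

```python
import os
from typing import Dict, Mapping, Optional


def single_device_map(env: Optional[Mapping[str, str]] = None) -> Dict[str, int]:
    """Return a device_map that puts the whole model on this process's GPU.

    Assumes the launcher (e.g. torchrun or Lightning's DDP launcher) exports
    LOCAL_RANK; falls back to GPU 0 when it is not set.
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", "0"))
    # The empty-string key refers to the root module, i.e. the entire model.
    return {"": local_rank}


# Hypothetical usage in every rank (needs transformers + bitsandbytes; not run here):
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,                      # placeholder
#     quantization_config=bnb_config,  # your BitsAndBytesConfig
#     device_map=single_device_map(),
# )

if __name__ == "__main__":
    print(single_device_map({"LOCAL_RANK": "1"}))  # {'': 1}
```

Inside a LightningModule, `self.trainer.local_rank` or `torch.cuda.current_device()` could stand in for the environment lookup; which one is appropriate depends on when and where the model is loaded, so treat this as a starting point rather than a confirmed fix.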
Thanks for the help @lantiga
Going to repost this in that other thread because I think that is the actual problem; I doubt that PL is at fault, since the issue shows up across DeepSpeed, FSDP, and DDP. I haven't figured out the answer, but here is a semi-minimal repro (the smallest I could get it). It still breaks on 2xA100.
Repro gist: https://gist.github.com/isaacbmiller/fc871d732d4d6a6b7ede3190a6979f40
nvidia-smi

```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:4B:00.0 Off |                    0 |
| N/A   46C    P0    63W / 500W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:E3:00.0 Off |                    0 |
| N/A   48C    P0    65W / 500W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
```
Dependencies (`conda list`):
accelerate 0.28.0 pypi_0 pypi
aiohttp 3.9.3 pypi_0 pypi
aiosignal 1.3.1 pypi_0 pypi
annotated-types 0.6.0 pypi_0 pypi
antlr4-python3-runtime 4.9.3 pypi_0 pypi
anyio 4.3.0 pypi_0 pypi
appdirs 1.4.4 pypi_0 pypi
argon2-cffi 23.1.0 pypi_0 pypi
argon2-cffi-bindings 21.2.0 pypi_0 pypi
arrow 1.3.0 pypi_0 pypi
asttokens 2.4.1 pypi_0 pypi
async-lru 2.0.4 pypi_0 pypi
attrs 23.2.0 pypi_0 pypi
babel 2.14.0 pypi_0 pypi
beautifulsoup4 4.12.3 pypi_0 pypi
bitsandbytes 0.42.0 pypi_0 pypi
blas 1.0 mkl
bleach 6.1.0 pypi_0 pypi
blis 0.7.11 pypi_0 pypi
bzip2 1.0.8 h5eee18b_5
ca-certificates 2024.3.11 h06a4308_0
catalogue 2.0.10 pypi_0 pypi
certifi 2024.2.2 pypi_0 pypi
cffi 1.16.0 pypi_0 pypi
charset-normalizer 3.3.2 pypi_0 pypi
click 8.1.7 pypi_0 pypi
cloudpathlib 0.16.0 pypi_0 pypi
comm 0.2.2 pypi_0 pypi
confection 0.1.4 pypi_0 pypi
contourpy 1.2.0 pypi_0 pypi
cuda-cudart 12.1.105 0 nvidia
cuda-cupti 12.1.105 0 nvidia
cuda-libraries 12.1.0 0 nvidia
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvtx 12.1.105 0 nvidia
cuda-opencl 12.4.127 0 nvidia
cuda-runtime 12.1.0 0 nvidia
cycler 0.12.1 pypi_0 pypi
cymem 2.0.8 pypi_0 pypi
datasets 2.14.7 pypi_0 pypi
debugpy 1.8.1 pypi_0 pypi
decorator 5.1.1 pypi_0 pypi
deepspeed 0.14.0 pypi_0 pypi
defusedxml 0.7.1 pypi_0 pypi
dill 0.3.7 pypi_0 pypi
docker-pycreds 0.4.0 pypi_0 pypi
editdistance 0.6.2 pypi_0 pypi
einops 0.8.0 pypi_0 pypi
executing 2.0.1 pypi_0 pypi
fastjsonschema 2.19.1 pypi_0 pypi
filelock 3.13.1 py311h06a4308_0
fonttools 4.50.0 pypi_0 pypi
fqdn 1.5.1 pypi_0 pypi
frozenlist 1.4.1 pypi_0 pypi
fsspec 2023.10.0 pypi_0 pypi
gcc 5.4.0 0 https://anaconda.org/brown-data-science/gcc/5.4.0/download
gitdb 4.0.11 pypi_0 pypi
gitpython 3.1.42 pypi_0 pypi
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py311hc9b5ff0_0
h11 0.14.0 pypi_0 pypi
hjson 3.1.0 pypi_0 pypi
httpcore 1.0.5 pypi_0 pypi
httpx 0.27.0 pypi_0 pypi
huggingface-hub 0.21.4 pypi_0 pypi
hydra-core 1.3.2 pypi_0 pypi
idna 3.6 pypi_0 pypi
iniconfig 2.0.0 pypi_0 pypi
intel-openmp 2023.1.0 hdb19cb5_46306
ipykernel 6.25.2 pypi_0 pypi
ipython 8.22.2 pypi_0 pypi
ipywidgets 8.1.2 pypi_0 pypi
isoduration 20.11.0 pypi_0 pypi
jedi 0.19.1 pypi_0 pypi
jinja2 3.1.3 py311h06a4308_0
joblib 1.3.2 pypi_0 pypi
json5 0.9.24 pypi_0 pypi
jsonpointer 2.4 pypi_0 pypi
jsonschema 4.21.1 pypi_0 pypi
jsonschema-specifications 2023.12.1 pypi_0 pypi
jupyter 1.0.0 pypi_0 pypi
jupyter-client 8.6.1 pypi_0 pypi
jupyter-console 6.6.3 pypi_0 pypi
jupyter-core 5.7.2 pypi_0 pypi
jupyter-events 0.10.0 pypi_0 pypi
jupyter-lsp 2.2.4 pypi_0 pypi
jupyter-server 2.13.0 pypi_0 pypi
jupyter-server-terminals 0.5.3 pypi_0 pypi
jupyterlab 4.1.5 pypi_0 pypi
jupyterlab-pygments 0.3.0 pypi_0 pypi
jupyterlab-server 2.25.4 pypi_0 pypi
jupyterlab-widgets 3.0.10 pypi_0 pypi
kiwisolver 1.4.5 pypi_0 pypi
langcodes 3.3.0 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.9.1.3 0 nvidia
libcurand 10.3.5.147 0 nvidia
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libffi 3.4.4 h6a678d5_0
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libnpp 12.0.2.50 0 nvidia
libnvjitlink 12.1.105 0 nvidia
libnvjpeg 12.1.1.14 0 nvidia
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
lightning 2.2.1 pypi_0 pypi
lightning-utilities 0.11.2 pypi_0 pypi
lion-pytorch 0.1.4 pypi_0 pypi
llvm-openmp 14.0.6 h9e868ea_0
loralib 0.1.2 pypi_0 pypi
markdown-it-py 3.0.0 pypi_0 pypi
markupsafe 2.1.5 pypi_0 pypi
matplotlib 3.8.4 pypi_0 pypi
matplotlib-inline 0.1.6 pypi_0 pypi
mdurl 0.1.2 pypi_0 pypi
mistune 3.0.2 pypi_0 pypi
mkl 2023.1.0 h213fc3f_46344
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.3.0 py311h06a4308_0
multidict 6.0.5 pypi_0 pypi
multiprocess 0.70.15 pypi_0 pypi
murmurhash 1.0.10 pypi_0 pypi
nbclient 0.10.0 pypi_0 pypi
nbconvert 7.16.3 pypi_0 pypi
nbformat 5.10.3 pypi_0 pypi
ncurses 6.4 h6a678d5_0
nest-asyncio 1.6.0 pypi_0 pypi
networkx 3.2.1 pypi_0 pypi
ninja 1.11.1.1 pypi_0 pypi
nltk 3.8.1 pypi_0 pypi
notebook 7.1.2 pypi_0 pypi
notebook-shim 0.2.4 pypi_0 pypi
numpy 1.26.0 pypi_0 pypi
nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
nvidia-nccl-cu12 2.18.1 pypi_0 pypi
nvidia-nvjitlink-cu12 12.4.99 pypi_0 pypi
nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
omegaconf 2.3.0 pypi_0 pypi
openssl 3.0.13 h7f8727e_0
overrides 7.7.0 pypi_0 pypi
packaging 24.0 pypi_0 pypi
pandas 2.2.2 pypi_0 pypi
pandocfilters 1.5.1 pypi_0 pypi
parso 0.8.3 pypi_0 pypi
pathtools 0.1.2 pypi_0 pypi
peft 0.5.0 pypi_0 pypi
pexpect 4.9.0 pypi_0 pypi
pillow 10.3.0 pypi_0 pypi
pip 24.0 pypi_0 pypi
platformdirs 4.2.0 pypi_0 pypi
pluggy 1.5.0 pypi_0 pypi
preshed 3.0.9 pypi_0 pypi
prometheus-client 0.20.0 pypi_0 pypi
prompt-toolkit 3.0.43 pypi_0 pypi
protobuf 4.25.3 pypi_0 pypi
psutil 5.9.8 pypi_0 pypi
ptyprocess 0.7.0 pypi_0 pypi
pure-eval 0.2.2 pypi_0 pypi
py-cpuinfo 9.0.0 pypi_0 pypi
pyarrow 15.0.2 pypi_0 pypi
pyarrow-hotfix 0.6 pypi_0 pypi
pycparser 2.22 pypi_0 pypi
pydantic 2.6.4 pypi_0 pypi
pydantic-core 2.16.3 pypi_0 pypi
pygments 2.17.2 pypi_0 pypi
pynvml 11.5.0 pypi_0 pypi
pyparsing 3.1.2 pypi_0 pypi
pytest 8.2.0 pypi_0 pypi
python 3.11.8 h955ad1f_0
python-dateutil 2.9.0.post0 pypi_0 pypi
python-json-logger 2.0.7 pypi_0 pypi
pytorch 2.2.2 py3.11_cuda12.1_cudnn8.9.2_0 pytorch
pytorch-cuda 12.1 ha16c6d3_5 pytorch
pytorch-lightning 2.2.1 pypi_0 pypi
pytorch-mutex 1.0 cuda pytorch
pytz 2024.1 pypi_0 pypi
pyyaml 6.0.1 py311h5eee18b_0
pyzmq 25.1.2 pypi_0 pypi
qtconsole 5.5.1 pypi_0 pypi
qtpy 2.4.1 pypi_0 pypi
readline 8.2 h5eee18b_0
referencing 0.34.0 pypi_0 pypi
regex 2023.12.25 pypi_0 pypi
requests 2.31.0 pypi_0 pypi
rfc3339-validator 0.1.4 pypi_0 pypi
rfc3986-validator 0.1.1 pypi_0 pypi
rich 13.6.0 pypi_0 pypi
rpds-py 0.18.0 pypi_0 pypi
safetensors 0.4.2 pypi_0 pypi
scikit-learn 1.4.1.post1 pypi_0 pypi
scipy 1.13.0 pypi_0 pypi
seaborn 0.13.2 pypi_0 pypi
send2trash 1.8.2 pypi_0 pypi
sentence-transformers 2.6.1 pypi_0 pypi
sentry-sdk 1.43.0 pypi_0 pypi
setproctitle 1.3.3 pypi_0 pypi
setuptools 68.2.2 py311h06a4308_0
six 1.16.0 pypi_0 pypi
smart-open 6.4.0 pypi_0 pypi
smmap 5.0.1 pypi_0 pypi
sniffio 1.3.1 pypi_0 pypi
soupsieve 2.5 pypi_0 pypi
spacy 3.7.4 pypi_0 pypi
spacy-legacy 3.0.12 pypi_0 pypi
spacy-loggers 1.0.5 pypi_0 pypi
sqlite 3.41.2 h5eee18b_0
srsly 2.4.8 pypi_0 pypi
stack-data 0.6.3 pypi_0 pypi
sympy 1.12 py311h06a4308_0
tbb 2021.8.0 hdb19cb5_0
terminado 0.18.1 pypi_0 pypi
thinc 8.2.3 pypi_0 pypi
threadpoolctl 3.4.0 pypi_0 pypi
tinycss2 1.2.1 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
tokenizers 0.15.2 pypi_0 pypi
torch 2.1.0 pypi_0 pypi
torchdata 0.7.1 pypi_0 pypi
torchmetrics 1.3.2 pypi_0 pypi
torchtriton 2.2.0 py311 pytorch
tornado 6.4 pypi_0 pypi
tqdm 4.66.1 pypi_0 pypi
traitlets 5.14.2 pypi_0 pypi
transformers 4.40.0.dev0 pypi_0 pypi
triton 2.0.0.dev20221202 pypi_0 pypi
trl 0.7.1 pypi_0 pypi
typer 0.9.4 pypi_0 pypi
types-python-dateutil 2.9.0.20240316 pypi_0 pypi
typing-extensions 4.10.0 pypi_0 pypi
typing_extensions 4.9.0 py311h06a4308_1
tzdata 2024.1 pypi_0 pypi
uri-template 1.3.0 pypi_0 pypi
urllib3 2.2.1 pypi_0 pypi
wandb 0.15.12 pypi_0 pypi
wasabi 1.1.2 pypi_0 pypi
wcwidth 0.2.13 pypi_0 pypi
weasel 0.3.4 pypi_0 pypi
webcolors 1.13 pypi_0 pypi
webencodings 0.5.1 pypi_0 pypi
websocket-client 1.7.0 pypi_0 pypi
wheel 0.43.0 pypi_0 pypi
widgetsnbextension 4.0.10 pypi_0 pypi
xxhash 3.4.1 pypi_0 pypi
xz 5.4.6 h5eee18b_0
yaml 0.2.5 h7b6447c_0
yarl 1.9.4 pypi_0 pypi
zlib 1.2.13 h5eee18b_0
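When comparing environments like the one above, it can help to ask the interpreter that actually runs training which versions it resolves, since a `conda list` dump can show the same library installed through both conda and pip. A minimal sketch using only the standard library's `importlib.metadata`; the package list is just an example (note that the PyPI distribution name for PyTorch is `torch`), and `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    report: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the environment this interpreter is using.
            report[name] = None
    return report


if __name__ == "__main__":
    example_packages = [
        "torch", "pytorch-lightning", "lightning",
        "bitsandbytes", "transformers", "deepspeed",
    ]
    for pkg, ver in installed_versions(example_packages).items():
        print(f"{pkg}=={ver}" if ver else f"{pkg}: not installed")
```

Printing `torch.__version__` from the training script itself is a complementary check, since it reports the module that was actually imported.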
Exactly the same problem here. Is there any update on this issue?
This works for me: https://github.com/Lightning-AI/pytorch-lightning/discussions/17878
Maybe you can also try the following: https://github.com/huggingface/transformers/issues/28770
Bug description
Got the following error:

```
  File "/home/root/miniforge3/envs/torch2/lib/python3.9/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
```
This happens only when I use a model quantized with `BitsAndBytesConfig` together with DeepSpeed.
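The error means the embedding weight and the input ids live on different GPUs. If the model was loaded with a `device_map`, one way to check whether load-time placement is the cause (a diagnostic sketch, not something from the original report) is to inspect the module-to-device mapping that transformers records as `model.hf_device_map`. The helper names are hypothetical and the example map is made up.

```python
from typing import List, Mapping, Union

Device = Union[int, str]


def _normalize(device: Device) -> str:
    # Treat 1, "1" and "cuda:1" as the same GPU; leave "cpu"/"disk" untouched.
    text = str(device)
    return text.split(":", 1)[1] if text.startswith("cuda:") else text


def devices_in_map(device_map: Mapping[str, Device]) -> List[str]:
    """Distinct devices that a module-name -> device mapping spreads the model over."""
    return sorted({_normalize(d) for d in device_map.values()})


def all_on_rank_device(device_map: Mapping[str, Device], local_rank: int) -> bool:
    # True only if every module (embeddings included) sits on this rank's GPU,
    # i.e. the device the training strategy would move this rank's batches to.
    return devices_in_map(device_map) == [str(local_rank)]


# Hypothetical usage after loading (not run here):
# print(all_on_rank_device(model.hf_device_map, trainer.local_rank))

if __name__ == "__main__":
    # Made-up map resembling a model spread over two GPUs.
    example = {"model.embed_tokens": 0, "model.layers.0": 0, "lm_head": 1}
    print(devices_in_map(example))         # ['0', '1']
    print(all_on_rank_device(example, 1))  # False
```

A `False` on any rank would be consistent with the `cuda:1` / `cuda:0` mismatch above, but it does not rule out other causes inside bitsandbytes or the strategy.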
What version are you seeing the problem on?
v2.2
How to reproduce the bug
Error messages and logs
Environment
Current environment
```
pytorch-lightning==2.2.1
pytorch=2.1.2
bitsandbytes=0.43.0
transformers=4.34.1
```
More info
A similar issue was reported on the Hugging Face forums:
https://discuss.huggingface.co/t/runtimeerror-expected-all-tensors-to-be-on-the-same-device-but-found-at-least-two-devices-cuda-2-and-cuda-0-when-checking-argument-for-argument-index-in-method-wrapper-cuda-index-select/48991