Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
Apache License 2.0

Expected all tensors to be on the same device... when using a quantized model with DeepSpeed ZeRO-3 #19731

Open sanghyuk-choi opened 7 months ago

sanghyuk-choi commented 7 months ago

Bug description

I get `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)`

File "/home/root/miniforge3/envs/torch2/lib/python3.9/site-packages/torch/nn/", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

This happens only when I use a model quantized with `BitsAndBytesConfig` together with DeepSpeed.
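The traceback says the embedding lookup received its index tensor and its weight on different GPUs. To see which side is misplaced on a given rank, one option is to compare every parameter's device with the device the batch arrives on. The helper below is a hypothetical sketch (the name `params_off_device` is not part of any library); it works on plain `(name, device)` string pairs so the logic itself needs neither torch nor a GPU, and the parameter names in the example are made up.

```python
from typing import Iterable, List, Tuple


def params_off_device(named_devices: Iterable[Tuple[str, str]], expected: str) -> List[str]:
    """Return the sorted names of parameters whose device differs from `expected`.

    `named_devices` is an iterable of (parameter name, device string) pairs,
    for example built from a model's named_parameters().
    """
    return sorted(name for name, device in named_devices if device != expected)


# Made-up example: the embedding sits on cuda:0 while the batch is on cuda:1.
pairs = [("model.embed_tokens.weight", "cuda:0"), ("lm_head.weight", "cuda:1")]
print(params_off_device(pairs, expected="cuda:1"))  # ['model.embed_tokens.weight']
```

Inside the LightningModule, a call such as `params_off_device(((n, str(p.device)) for n, p in self.model.named_parameters()), str(batch["input_ids"].device))` at the top of `forward` would list the offending parameters per rank; the `input_ids` key is an assumption about the batch layout, and under ZeRO-3 the reported devices belong to DeepSpeed's partitioned placeholders.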

What version are you seeing the problem on?


How to reproduce the bug

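import torch
from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

MODEL_NAME = "facebook/opt-125m"  # hypothetical placeholder; the model in the report was not shown


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Assumed 4-bit settings; the original BitsAndBytesConfig arguments were not shown
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
        self.model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb_config)

    def forward(self, batch):
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        return self(batch).loss  # assumes each batch dict includes "labels"

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())


# train_data_loader / test_data_loader were not shown in the report
trainer = Trainer(strategy="deepspeed_stage_3")
trainer.fit(BoringModel(), train_data_loader, test_data_loader)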

Error messages and logs

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)


Current environment

```
pytorch-lightning==2.2.1
pytorch=2.1.2
bitsandbytes=0.43.0
transformers=4.34.1
```

More info

A similar issue was reported in a Hugging Face discussion.

isaacbmiller commented 6 months ago

I am facing a similar issue, but with the default DDP strategy rather than DeepSpeed.

lantiga commented 6 months ago

Hey @isaacbmiller, can you post a repro? There are a few factors interacting here, and it would help to take a direct look.

There's a discussion here (although it's about FSDP, it still contains good insights) that describes what may be happening:

In short, some decisions are made at the bitsandbytes level under the hood, and they may contribute to this.
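One plausible reading of the cuda:1 vs cuda:0 mismatch, offered as an assumption rather than something confirmed in this thread: when `from_pretrained` receives a bitsandbytes quantization config and no `device_map`, Transformers versions of this era place the whole model on `torch.cuda.current_device()` at load time, and quantized models may refuse a later `.to(...)`. If the model is built in the LightningModule's `__init__`, before Lightning has switched each process to its own GPU, every rank can end up loading onto cuda:0 while its batches arrive on cuda:1 and beyond. A workaround often suggested for multi-GPU training of quantized models is an explicit per-process `device_map`. The helper below is hypothetical; it assumes the launcher exports `LOCAL_RANK` (torchrun does) and falls back to GPU 0 when the variable is unset. Whether this is sufficient under DeepSpeed ZeRO-3 is not established here.

```python
import os
from typing import Dict, Mapping, Optional


def per_rank_device_map(env: Optional[Mapping[str, str]] = None) -> Dict[str, int]:
    """Build a device_map that places the whole model on this process's GPU.

    Hypothetical helper: reads LOCAL_RANK from `env` (defaults to os.environ)
    and falls back to 0, e.g. for a single-process run.
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", "0"))
    # An empty-string key in a device_map means "every module of the model".
    return {"": local_rank}


# In a process launched with LOCAL_RANK=1 this prints {'': 1}.
print(per_rank_device_map({"LOCAL_RANK": "1"}))
```

It would be passed at load time, e.g. `AutoModelForCausalLM.from_pretrained(..., quantization_config=bnb_config, device_map=per_rank_device_map())`, so that each rank loads its copy directly onto its own GPU.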

isaacbmiller commented 6 months ago

Thanks for the help @lantiga

I'm going to repost this in that other thread because I think that is where the actual problem lies. I doubt PL is at fault, since it seems to apply across DeepSpeed, FSDP, and DDP. I haven't found a fix, but here is a semi-minimal repro (the smallest I could get). It still breaks on 2x A100.

Repro gist:


| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  On   | 00000000:4B:00.0 Off |                    0 |
| N/A   46C    P0    63W / 500W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
|   1  NVIDIA A100-SXM...  On   | 00000000:E3:00.0 Off |                    0 |
| N/A   48C    P0    65W / 500W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |


accelerate                0.28.0                   pypi_0    pypi
aiohttp                   3.9.3                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
annotated-types           0.6.0                    pypi_0    pypi
antlr4-python3-runtime    4.9.3                    pypi_0    pypi
anyio                     4.3.0                    pypi_0    pypi
appdirs                   1.4.4                    pypi_0    pypi
argon2-cffi               23.1.0                   pypi_0    pypi
argon2-cffi-bindings      21.2.0                   pypi_0    pypi
arrow                     1.3.0                    pypi_0    pypi
asttokens                 2.4.1                    pypi_0    pypi
async-lru                 2.0.4                    pypi_0    pypi
attrs                     23.2.0                   pypi_0    pypi
babel                     2.14.0                   pypi_0    pypi
beautifulsoup4            4.12.3                   pypi_0    pypi
bitsandbytes              0.42.0                   pypi_0    pypi
blas                      1.0                         mkl  
bleach                    6.1.0                    pypi_0    pypi
blis                      0.7.11                   pypi_0    pypi
bzip2                     1.0.8                h5eee18b_5  
ca-certificates           2024.3.11            h06a4308_0  
catalogue                 2.0.10                   pypi_0    pypi
certifi                   2024.2.2                 pypi_0    pypi
cffi                      1.16.0                   pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
click                     8.1.7                    pypi_0    pypi
cloudpathlib              0.16.0                   pypi_0    pypi
comm                      0.2.2                    pypi_0    pypi
confection                0.1.4                    pypi_0    pypi
contourpy                 1.2.0                    pypi_0    pypi
cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.4.127                      0    nvidia
cuda-runtime              12.1.0                        0    nvidia
cycler                    0.12.1                   pypi_0    pypi
cymem                     2.0.8                    pypi_0    pypi
datasets                  2.14.7                   pypi_0    pypi
debugpy                   1.8.1                    pypi_0    pypi
decorator                 5.1.1                    pypi_0    pypi
deepspeed                 0.14.0                   pypi_0    pypi
defusedxml                0.7.1                    pypi_0    pypi
dill                      0.3.7                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
editdistance              0.6.2                    pypi_0    pypi
einops                    0.8.0                    pypi_0    pypi
executing                 2.0.1                    pypi_0    pypi
fastjsonschema            2.19.1                   pypi_0    pypi
filelock                  3.13.1          py311h06a4308_0  
fonttools                 4.50.0                   pypi_0    pypi
fqdn                      1.5.1                    pypi_0    pypi
frozenlist                1.4.1                    pypi_0    pypi
fsspec                    2023.10.0                pypi_0    pypi
gcc                       5.4.0                         0
gitdb                     4.0.11                   pypi_0    pypi
gitpython                 3.1.42                   pypi_0    pypi
gmp                       6.2.1                h295c915_3  
gmpy2                     2.1.2           py311hc9b5ff0_0  
h11                       0.14.0                   pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
httpcore                  1.0.5                    pypi_0    pypi
httpx                     0.27.0                   pypi_0    pypi
huggingface-hub           0.21.4                   pypi_0    pypi
hydra-core                1.3.2                    pypi_0    pypi
idna                      3.6                      pypi_0    pypi
iniconfig                 2.0.0                    pypi_0    pypi
intel-openmp              2023.1.0         hdb19cb5_46306  
ipykernel                 6.25.2                   pypi_0    pypi
ipython                   8.22.2                   pypi_0    pypi
ipywidgets                8.1.2                    pypi_0    pypi
isoduration               20.11.0                  pypi_0    pypi
jedi                      0.19.1                   pypi_0    pypi
jinja2                    3.1.3           py311h06a4308_0  
joblib                    1.3.2                    pypi_0    pypi
json5                     0.9.24                   pypi_0    pypi
jsonpointer               2.4                      pypi_0    pypi
jsonschema                4.21.1                   pypi_0    pypi
jsonschema-specifications 2023.12.1                pypi_0    pypi
jupyter                   1.0.0                    pypi_0    pypi
jupyter-client            8.6.1                    pypi_0    pypi
jupyter-console           6.6.3                    pypi_0    pypi
jupyter-core              5.7.2                    pypi_0    pypi
jupyter-events            0.10.0                   pypi_0    pypi
jupyter-lsp               2.2.4                    pypi_0    pypi
jupyter-server            2.13.0                   pypi_0    pypi
jupyter-server-terminals  0.5.3                    pypi_0    pypi
jupyterlab                4.1.5                    pypi_0    pypi
jupyterlab-pygments       0.3.0                    pypi_0    pypi
jupyterlab-server         2.25.4                   pypi_0    pypi
jupyterlab-widgets        3.0.10                   pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
langcodes                 3.3.0                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libcublas                            0    nvidia
libcufft                              0    nvidia
libcufile                              0    nvidia
libcurand                           0    nvidia
libcusolver                          0    nvidia
libcusparse                          0    nvidia
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libnpp                               0    nvidia
libnvjitlink              12.1.105                      0    nvidia
libnvjpeg                            0    nvidia
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
lightning                 2.2.1                    pypi_0    pypi
lightning-utilities       0.11.2                   pypi_0    pypi
lion-pytorch              0.1.4                    pypi_0    pypi
llvm-openmp               14.0.6               h9e868ea_0  
loralib                   0.1.2                    pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
matplotlib                3.8.4                    pypi_0    pypi
matplotlib-inline         0.1.6                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mistune                   3.0.2                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46344  
mpc                       1.1.0                h10f8cd9_1  
mpfr                      4.0.2                hb69a4c5_1  
mpmath                    1.3.0           py311h06a4308_0  
multidict                 6.0.5                    pypi_0    pypi
multiprocess              0.70.15                  pypi_0    pypi
murmurhash                1.0.10                   pypi_0    pypi
nbclient                  0.10.0                   pypi_0    pypi
nbconvert                 7.16.3                   pypi_0    pypi
nbformat                  5.10.3                   pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.6.0                    pypi_0    pypi
networkx                  3.2.1                    pypi_0    pypi
ninja                            pypi_0    pypi
nltk                      3.8.1                    pypi_0    pypi
notebook                  7.1.2                    pypi_0    pypi
notebook-shim             0.2.4                    pypi_0    pypi
numpy                     1.26.0                   pypi_0    pypi
nvidia-cublas-cu12                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12                 pypi_0    pypi
nvidia-cufft-cu12                pypi_0    pypi
nvidia-curand-cu12               pypi_0    pypi
nvidia-cusolver-cu12               pypi_0    pypi
nvidia-cusparse-cu12               pypi_0    pypi
nvidia-nccl-cu12          2.18.1                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.99                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
omegaconf                 2.3.0                    pypi_0    pypi
openssl                   3.0.13               h7f8727e_0  
overrides                 7.7.0                    pypi_0    pypi
packaging                 24.0                     pypi_0    pypi
pandas                    2.2.2                    pypi_0    pypi
pandocfilters             1.5.1                    pypi_0    pypi
parso                     0.8.3                    pypi_0    pypi
pathtools                 0.1.2                    pypi_0    pypi
peft                      0.5.0                    pypi_0    pypi
pexpect                   4.9.0                    pypi_0    pypi
pillow                    10.3.0                   pypi_0    pypi
pip                       24.0                     pypi_0    pypi
platformdirs              4.2.0                    pypi_0    pypi
pluggy                    1.5.0                    pypi_0    pypi
preshed                   3.0.9                    pypi_0    pypi
prometheus-client         0.20.0                   pypi_0    pypi
prompt-toolkit            3.0.43                   pypi_0    pypi
protobuf                  4.25.3                   pypi_0    pypi
psutil                    5.9.8                    pypi_0    pypi
ptyprocess                0.7.0                    pypi_0    pypi
pure-eval                 0.2.2                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
pyarrow                   15.0.2                   pypi_0    pypi
pyarrow-hotfix            0.6                      pypi_0    pypi
pycparser                 2.22                     pypi_0    pypi
pydantic                  2.6.4                    pypi_0    pypi
pydantic-core             2.16.3                   pypi_0    pypi
pygments                  2.17.2                   pypi_0    pypi
pynvml                    11.5.0                   pypi_0    pypi
pyparsing                 3.1.2                    pypi_0    pypi
pytest                    8.2.0                    pypi_0    pypi
python                    3.11.8               h955ad1f_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
python-json-logger        2.0.7                    pypi_0    pypi
pytorch                   2.2.2           py3.11_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-lightning         2.2.1                    pypi_0    pypi
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.1                   pypi_0    pypi
pyyaml                    6.0.1           py311h5eee18b_0  
pyzmq                     25.1.2                   pypi_0    pypi
qtconsole                 5.5.1                    pypi_0    pypi
qtpy                      2.4.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
referencing               0.34.0                   pypi_0    pypi
regex                     2023.12.25               pypi_0    pypi
requests                  2.31.0                   pypi_0    pypi
rfc3339-validator         0.1.4                    pypi_0    pypi
rfc3986-validator         0.1.1                    pypi_0    pypi
rich                      13.6.0                   pypi_0    pypi
rpds-py                   0.18.0                   pypi_0    pypi
safetensors               0.4.2                    pypi_0    pypi
scikit-learn              1.4.1.post1              pypi_0    pypi
scipy                     1.13.0                   pypi_0    pypi
seaborn                   0.13.2                   pypi_0    pypi
send2trash                1.8.2                    pypi_0    pypi
sentence-transformers     2.6.1                    pypi_0    pypi
sentry-sdk                1.43.0                   pypi_0    pypi
setproctitle              1.3.3                    pypi_0    pypi
setuptools                68.2.2          py311h06a4308_0  
six                       1.16.0                   pypi_0    pypi
smart-open                6.4.0                    pypi_0    pypi
smmap                     5.0.1                    pypi_0    pypi
sniffio                   1.3.1                    pypi_0    pypi
soupsieve                 2.5                      pypi_0    pypi
spacy                     3.7.4                    pypi_0    pypi
spacy-legacy              3.0.12                   pypi_0    pypi
spacy-loggers             1.0.5                    pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0  
srsly                     2.4.8                    pypi_0    pypi
stack-data                0.6.3                    pypi_0    pypi
sympy                     1.12            py311h06a4308_0  
tbb                       2021.8.0             hdb19cb5_0  
terminado                 0.18.1                   pypi_0    pypi
thinc                     8.2.3                    pypi_0    pypi
threadpoolctl             3.4.0                    pypi_0    pypi
tinycss2                  1.2.1                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
tokenizers                0.15.2                   pypi_0    pypi
torch                     2.1.0                    pypi_0    pypi
torchdata                 0.7.1                    pypi_0    pypi
torchmetrics              1.3.2                    pypi_0    pypi
torchtriton               2.2.0                     py311    pytorch
tornado                   6.4                      pypi_0    pypi
tqdm                      4.66.1                   pypi_0    pypi
traitlets                 5.14.2                   pypi_0    pypi
transformers              4.40.0.dev0              pypi_0    pypi
triton                    2.0.0.dev20221202          pypi_0    pypi
trl                       0.7.1                    pypi_0    pypi
typer                     0.9.4                    pypi_0    pypi
types-python-dateutil           pypi_0    pypi
typing-extensions         4.10.0                   pypi_0    pypi
typing_extensions         4.9.0           py311h06a4308_1  
tzdata                    2024.1                   pypi_0    pypi
uri-template              1.3.0                    pypi_0    pypi
urllib3                   2.2.1                    pypi_0    pypi
wandb                     0.15.12                  pypi_0    pypi
wasabi                    1.1.2                    pypi_0    pypi
wcwidth                   0.2.13                   pypi_0    pypi
weasel                    0.3.4                    pypi_0    pypi
webcolors                 1.13                     pypi_0    pypi
webencodings              0.5.1                    pypi_0    pypi
websocket-client          1.7.0                    pypi_0    pypi
wheel                     0.43.0                   pypi_0    pypi
widgetsnbextension        4.0.10                   pypi_0    pypi
xxhash                    3.4.1                    pypi_0    pypi
xz                        5.4.6                h5eee18b_0  
yaml                      0.2.5                h7b6447c_0  
yarl                      1.9.4                    pypi_0    pypi
zlib                      1.2.13               h5eee18b_0

FractureSR commented 2 months ago

Exactly the same problem here. Is there any update on this issue?

Crazy-LittleBoy commented 2 months ago

This works for me:

Maybe you can also try the following: