Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Expected all tensors to be on the same device... using quantized model with deepspeed zero3 #19731

Open sanghyuk-choi opened 7 months ago

sanghyuk-choi commented 7 months ago

Bug description

Got RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

File "/home/root/miniforge3/envs/torch2/lib/python3.9/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

This happens only when I use a model quantized with BitsAndBytesConfig together with DeepSpeed.
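To narrow down which tensors end up on which GPU, a small diagnostic can help. The following is a minimal sketch, not part of the original report: the helper names `group_tensors_by_device` and `report_device_mismatch` are hypothetical, and the code only assumes that the model exposes `named_parameters()` and `named_buffers()` the way a `torch.nn.Module` does. Under ZeRO-3 the parameters seen by each process may be local shards only, so the output describes the local view.

```python
from collections import defaultdict


def group_tensors_by_device(module):
    """Map each device string to the names of parameters and buffers stored on it.

    `module` is expected to behave like a torch.nn.Module: it exposes
    named_parameters() and named_buffers(), each yielding (name, tensor)
    pairs where the tensor has a `.device` attribute.
    """
    placement = defaultdict(list)
    for name, tensor in module.named_parameters():
        placement[str(tensor.device)].append(name)
    for name, tensor in module.named_buffers():
        placement[str(tensor.device)].append(name)
    return dict(placement)


def report_device_mismatch(module, batch):
    """Return summary lines describing where the weights and the inputs live."""
    lines = []
    for device, names in sorted(group_tensors_by_device(module).items()):
        lines.append(f"{device}: {len(names)} tensors, e.g. {names[0]}")
    for key, value in batch.items():
        device = getattr(value, "device", None)
        if device is not None:  # skip non-tensor entries in the batch
            lines.append(f"input '{key}' on {device}")
    return lines
```

Calling `print(report_device_mismatch(self.model, batch))` at the top of `forward` on every rank should show whether the embedding weights and `input_ids` sit on different devices, which is what the `index_select` error complains about.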

What version are you seeing the problem on?

v2.2

How to reproduce the bug

from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()

        # 4-bit NF4 quantization with double quantization; compute and storage in bfloat16
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="bfloat16",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_storage="bfloat16"
        )
        # model_kwargs is not shown in the report
        self.model = AutoModelForCausalLM.from_pretrained(
                "mistralai/Mistral-7B-v0.1",
                quantization_config=bnb_config,
                **model_kwargs)

    def forward(self, batch):
        outputs = self.model(**batch)
        return outputs

    ...

model = BoringModel()

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    accumulate_grad_batches=1,
    strategy="deepspeed_stage_3")

# train_data_loader and test_data_loader are not shown in the report
trainer.fit(model, train_data_loader, test_data_loader)

Error messages and logs

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Environment

Current environment

```
pytorch-lightning==2.2.1
pytorch=2.1.2
bitsandbytes=0.43.0
transformers=4.34.1
```

More info

A similar issue was reported on the Hugging Face forum:

https://discuss.huggingface.co/t/runtimeerror-expected-all-tensors-to-be-on-the-same-device-but-found-at-least-two-devices-cuda-2-and-cuda-0-when-checking-argument-for-argument-index-in-method-wrapper-cuda-index-select/48991

isaacbmiller commented 6 months ago

I am also facing a similar issue, but mine is with the default DDP strategy, not DeepSpeed.

lantiga commented 6 months ago

Hey @isaacbmiller, can you post a repro? There are a few factors interacting here, and it would help to take a direct look.

There's a discussion here (although it's about FSDP it still contains good insights) that describes what may be happening: https://github.com/TimDettmers/bitsandbytes/issues/89

That is, there are decisions being taken at the bitsandbytes level under the hood that may contribute.
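One placement decision that can be made explicit is the GPU onto which each process loads the quantized weights. The sketch below is an assumption, not a confirmed diagnosis or fix for this issue: it builds a Hugging Face style `device_map` from the `LOCAL_RANK` environment variable, which torchrun-style launchers export. The helper name `per_rank_device_map` is hypothetical, and whether this approach combines with DeepSpeed ZeRO-3 is not established in this thread.

```python
import os


def per_rank_device_map(env=None):
    """Build a device_map that places the whole model on this process's GPU.

    Assumes the launcher exports LOCAL_RANK; falls back to GPU 0 when the
    variable is absent, for example in a single-process run.
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", 0))
    # The empty-string key stands for "the entire model" in a device_map.
    return {"": local_rank}


# Hypothetical usage inside the LightningModule from the report:
# self.model = AutoModelForCausalLM.from_pretrained(
#     "mistralai/Mistral-7B-v0.1",
#     quantization_config=bnb_config,
#     device_map=per_rank_device_map(),
# )
```

The idea is that every rank then loads its copy onto its own GPU instead of all ranks defaulting to the same device, which would be consistent with the `cuda:1 and cuda:0` mismatch in the error message.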

isaacbmiller commented 6 months ago

Thanks for the help @lantiga

I'm going to repost this in that other thread because I think that is the actual problem. I doubt that PL is at fault, as the error seems to occur across DeepSpeed, FSDP, and DDP. I haven't figured out the answer, but here is a semi-minimal repro (the smallest I could get it). It still breaks on 2xA100.

Repro gist: https://gist.github.com/isaacbmiller/fc871d732d4d6a6b7ede3190a6979f40

nvidia-smi

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:4B:00.0 Off |                    0 |
| N/A   46C    P0    63W / 500W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:E3:00.0 Off |                    0 |
| N/A   48C    P0    65W / 500W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

deps

accelerate                0.28.0                   pypi_0    pypi
aiohttp                   3.9.3                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
annotated-types           0.6.0                    pypi_0    pypi
antlr4-python3-runtime    4.9.3                    pypi_0    pypi
anyio                     4.3.0                    pypi_0    pypi
appdirs                   1.4.4                    pypi_0    pypi
argon2-cffi               23.1.0                   pypi_0    pypi
argon2-cffi-bindings      21.2.0                   pypi_0    pypi
arrow                     1.3.0                    pypi_0    pypi
asttokens                 2.4.1                    pypi_0    pypi
async-lru                 2.0.4                    pypi_0    pypi
attrs                     23.2.0                   pypi_0    pypi
babel                     2.14.0                   pypi_0    pypi
beautifulsoup4            4.12.3                   pypi_0    pypi
bitsandbytes              0.42.0                   pypi_0    pypi
blas                      1.0                         mkl  
bleach                    6.1.0                    pypi_0    pypi
blis                      0.7.11                   pypi_0    pypi
bzip2                     1.0.8                h5eee18b_5  
ca-certificates           2024.3.11            h06a4308_0  
catalogue                 2.0.10                   pypi_0    pypi
certifi                   2024.2.2                 pypi_0    pypi
cffi                      1.16.0                   pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
click                     8.1.7                    pypi_0    pypi
cloudpathlib              0.16.0                   pypi_0    pypi
comm                      0.2.2                    pypi_0    pypi
confection                0.1.4                    pypi_0    pypi
contourpy                 1.2.0                    pypi_0    pypi
cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.4.127                      0    nvidia
cuda-runtime              12.1.0                        0    nvidia
cycler                    0.12.1                   pypi_0    pypi
cymem                     2.0.8                    pypi_0    pypi
datasets                  2.14.7                   pypi_0    pypi
debugpy                   1.8.1                    pypi_0    pypi
decorator                 5.1.1                    pypi_0    pypi
deepspeed                 0.14.0                   pypi_0    pypi
defusedxml                0.7.1                    pypi_0    pypi
dill                      0.3.7                    pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
editdistance              0.6.2                    pypi_0    pypi
einops                    0.8.0                    pypi_0    pypi
executing                 2.0.1                    pypi_0    pypi
fastjsonschema            2.19.1                   pypi_0    pypi
filelock                  3.13.1          py311h06a4308_0  
fonttools                 4.50.0                   pypi_0    pypi
fqdn                      1.5.1                    pypi_0    pypi
frozenlist                1.4.1                    pypi_0    pypi
fsspec                    2023.10.0                pypi_0    pypi
gcc                       5.4.0                         0    https://anaconda.org/brown-data-science/gcc/5.4.0/download
gitdb                     4.0.11                   pypi_0    pypi
gitpython                 3.1.42                   pypi_0    pypi
gmp                       6.2.1                h295c915_3  
gmpy2                     2.1.2           py311hc9b5ff0_0  
h11                       0.14.0                   pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
httpcore                  1.0.5                    pypi_0    pypi
httpx                     0.27.0                   pypi_0    pypi
huggingface-hub           0.21.4                   pypi_0    pypi
hydra-core                1.3.2                    pypi_0    pypi
idna                      3.6                      pypi_0    pypi
iniconfig                 2.0.0                    pypi_0    pypi
intel-openmp              2023.1.0         hdb19cb5_46306  
ipykernel                 6.25.2                   pypi_0    pypi
ipython                   8.22.2                   pypi_0    pypi
ipywidgets                8.1.2                    pypi_0    pypi
isoduration               20.11.0                  pypi_0    pypi
jedi                      0.19.1                   pypi_0    pypi
jinja2                    3.1.3           py311h06a4308_0  
joblib                    1.3.2                    pypi_0    pypi
json5                     0.9.24                   pypi_0    pypi
jsonpointer               2.4                      pypi_0    pypi
jsonschema                4.21.1                   pypi_0    pypi
jsonschema-specifications 2023.12.1                pypi_0    pypi
jupyter                   1.0.0                    pypi_0    pypi
jupyter-client            8.6.1                    pypi_0    pypi
jupyter-console           6.6.3                    pypi_0    pypi
jupyter-core              5.7.2                    pypi_0    pypi
jupyter-events            0.10.0                   pypi_0    pypi
jupyter-lsp               2.2.4                    pypi_0    pypi
jupyter-server            2.13.0                   pypi_0    pypi
jupyter-server-terminals  0.5.3                    pypi_0    pypi
jupyterlab                4.1.5                    pypi_0    pypi
jupyterlab-pygments       0.3.0                    pypi_0    pypi
jupyterlab-server         2.25.4                   pypi_0    pypi
jupyterlab-widgets        3.0.10                   pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
langcodes                 3.3.0                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1  
libcublas                 12.1.0.26                     0    nvidia
libcufft                  11.0.2.4                      0    nvidia
libcufile                 1.9.1.3                       0    nvidia
libcurand                 10.3.5.147                    0    nvidia
libcusolver               11.4.4.55                     0    nvidia
libcusparse               12.0.2.55                     0    nvidia
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libnpp                    12.0.2.50                     0    nvidia
libnvjitlink              12.1.105                      0    nvidia
libnvjpeg                 12.1.1.14                     0    nvidia
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
lightning                 2.2.1                    pypi_0    pypi
lightning-utilities       0.11.2                   pypi_0    pypi
lion-pytorch              0.1.4                    pypi_0    pypi
llvm-openmp               14.0.6               h9e868ea_0  
loralib                   0.1.2                    pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
matplotlib                3.8.4                    pypi_0    pypi
matplotlib-inline         0.1.6                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mistune                   3.0.2                    pypi_0    pypi
mkl                       2023.1.0         h213fc3f_46344  
mpc                       1.1.0                h10f8cd9_1  
mpfr                      4.0.2                hb69a4c5_1  
mpmath                    1.3.0           py311h06a4308_0  
multidict                 6.0.5                    pypi_0    pypi
multiprocess              0.70.15                  pypi_0    pypi
murmurhash                1.0.10                   pypi_0    pypi
nbclient                  0.10.0                   pypi_0    pypi
nbconvert                 7.16.3                   pypi_0    pypi
nbformat                  5.10.3                   pypi_0    pypi
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.6.0                    pypi_0    pypi
networkx                  3.2.1                    pypi_0    pypi
ninja                     1.11.1.1                 pypi_0    pypi
nltk                      3.8.1                    pypi_0    pypi
notebook                  7.1.2                    pypi_0    pypi
notebook-shim             0.2.4                    pypi_0    pypi
numpy                     1.26.0                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.18.1                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.99                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
omegaconf                 2.3.0                    pypi_0    pypi
openssl                   3.0.13               h7f8727e_0  
overrides                 7.7.0                    pypi_0    pypi
packaging                 24.0                     pypi_0    pypi
pandas                    2.2.2                    pypi_0    pypi
pandocfilters             1.5.1                    pypi_0    pypi
parso                     0.8.3                    pypi_0    pypi
pathtools                 0.1.2                    pypi_0    pypi
peft                      0.5.0                    pypi_0    pypi
pexpect                   4.9.0                    pypi_0    pypi
pillow                    10.3.0                   pypi_0    pypi
pip                       24.0                     pypi_0    pypi
platformdirs              4.2.0                    pypi_0    pypi
pluggy                    1.5.0                    pypi_0    pypi
preshed                   3.0.9                    pypi_0    pypi
prometheus-client         0.20.0                   pypi_0    pypi
prompt-toolkit            3.0.43                   pypi_0    pypi
protobuf                  4.25.3                   pypi_0    pypi
psutil                    5.9.8                    pypi_0    pypi
ptyprocess                0.7.0                    pypi_0    pypi
pure-eval                 0.2.2                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
pyarrow                   15.0.2                   pypi_0    pypi
pyarrow-hotfix            0.6                      pypi_0    pypi
pycparser                 2.22                     pypi_0    pypi
pydantic                  2.6.4                    pypi_0    pypi
pydantic-core             2.16.3                   pypi_0    pypi
pygments                  2.17.2                   pypi_0    pypi
pynvml                    11.5.0                   pypi_0    pypi
pyparsing                 3.1.2                    pypi_0    pypi
pytest                    8.2.0                    pypi_0    pypi
python                    3.11.8               h955ad1f_0  
python-dateutil           2.9.0.post0              pypi_0    pypi
python-json-logger        2.0.7                    pypi_0    pypi
pytorch                   2.2.2           py3.11_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-lightning         2.2.1                    pypi_0    pypi
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2024.1                   pypi_0    pypi
pyyaml                    6.0.1           py311h5eee18b_0  
pyzmq                     25.1.2                   pypi_0    pypi
qtconsole                 5.5.1                    pypi_0    pypi
qtpy                      2.4.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0  
referencing               0.34.0                   pypi_0    pypi
regex                     2023.12.25               pypi_0    pypi
requests                  2.31.0                   pypi_0    pypi
rfc3339-validator         0.1.4                    pypi_0    pypi
rfc3986-validator         0.1.1                    pypi_0    pypi
rich                      13.6.0                   pypi_0    pypi
rpds-py                   0.18.0                   pypi_0    pypi
safetensors               0.4.2                    pypi_0    pypi
scikit-learn              1.4.1.post1              pypi_0    pypi
scipy                     1.13.0                   pypi_0    pypi
seaborn                   0.13.2                   pypi_0    pypi
send2trash                1.8.2                    pypi_0    pypi
sentence-transformers     2.6.1                    pypi_0    pypi
sentry-sdk                1.43.0                   pypi_0    pypi
setproctitle              1.3.3                    pypi_0    pypi
setuptools                68.2.2          py311h06a4308_0  
six                       1.16.0                   pypi_0    pypi
smart-open                6.4.0                    pypi_0    pypi
smmap                     5.0.1                    pypi_0    pypi
sniffio                   1.3.1                    pypi_0    pypi
soupsieve                 2.5                      pypi_0    pypi
spacy                     3.7.4                    pypi_0    pypi
spacy-legacy              3.0.12                   pypi_0    pypi
spacy-loggers             1.0.5                    pypi_0    pypi
sqlite                    3.41.2               h5eee18b_0  
srsly                     2.4.8                    pypi_0    pypi
stack-data                0.6.3                    pypi_0    pypi
sympy                     1.12            py311h06a4308_0  
tbb                       2021.8.0             hdb19cb5_0  
terminado                 0.18.1                   pypi_0    pypi
thinc                     8.2.3                    pypi_0    pypi
threadpoolctl             3.4.0                    pypi_0    pypi
tinycss2                  1.2.1                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
tokenizers                0.15.2                   pypi_0    pypi
torch                     2.1.0                    pypi_0    pypi
torchdata                 0.7.1                    pypi_0    pypi
torchmetrics              1.3.2                    pypi_0    pypi
torchtriton               2.2.0                     py311    pytorch
tornado                   6.4                      pypi_0    pypi
tqdm                      4.66.1                   pypi_0    pypi
traitlets                 5.14.2                   pypi_0    pypi
transformers              4.40.0.dev0              pypi_0    pypi
triton                    2.0.0.dev20221202          pypi_0    pypi
trl                       0.7.1                    pypi_0    pypi
typer                     0.9.4                    pypi_0    pypi
types-python-dateutil     2.9.0.20240316           pypi_0    pypi
typing-extensions         4.10.0                   pypi_0    pypi
typing_extensions         4.9.0           py311h06a4308_1  
tzdata                    2024.1                   pypi_0    pypi
uri-template              1.3.0                    pypi_0    pypi
urllib3                   2.2.1                    pypi_0    pypi
wandb                     0.15.12                  pypi_0    pypi
wasabi                    1.1.2                    pypi_0    pypi
wcwidth                   0.2.13                   pypi_0    pypi
weasel                    0.3.4                    pypi_0    pypi
webcolors                 1.13                     pypi_0    pypi
webencodings              0.5.1                    pypi_0    pypi
websocket-client          1.7.0                    pypi_0    pypi
wheel                     0.43.0                   pypi_0    pypi
widgetsnbextension        4.0.10                   pypi_0    pypi
xxhash                    3.4.1                    pypi_0    pypi
xz                        5.4.6                h5eee18b_0  
yaml                      0.2.5                h7b6447c_0  
yarl                      1.9.4                    pypi_0    pypi
zlib                      1.2.13               h5eee18b_0

FractureSR commented 2 months ago

Exactly the same problem here. Is there any update on this issue?

Crazy-LittleBoy commented 2 months ago

This works for me: https://github.com/Lightning-AI/pytorch-lightning/discussions/17878

Maybe you can also try the following: https://github.com/huggingface/transformers/issues/28770