huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

NotImplementedError: Cannot copy out of meta tensor; no data! with Multi-node training #26971

Closed ari9dam closed 1 year ago

ari9dam commented 1 year ago

System Info

A100
Cuda 11.7
PyTorch 2.0.1
# This dependencies file is produced by 'conda export'
{
  "channels": [
    "pytorch",
    "defaults"
  ],
  "dependencies": [
    "_libgcc_mutex=0.1=main",
    "_openmp_mutex=5.1=1_gnu",
    "ca-certificates=2023.01.10=h06a4308_0",
    "ld_impl_linux-64=2.38=h1181459_1",
    "libffi=3.4.4=h6a678d5_0",
    "libgcc-ng=11.2.0=h1234567_1",
    "libgomp=11.2.0=h1234567_1",
    "libstdcxx-ng=11.2.0=h1234567_1",
    "magma-cuda117=2.6.1=1",
    "ncurses=6.4=h6a678d5_0",
    "openssl=1.1.1t=h7f8727e_0",
    "pip=23.0.1=py38h06a4308_0",
    "python=3.8.16=h7a1cb2a_3",
    "readline=8.2=h5eee18b_0",
    "sqlite=3.41.2=h5eee18b_0",
    "tk=8.6.12=h1ccaba5_0",
    "xz=5.4.2=h5eee18b_0",
    "zlib=1.2.13=h5eee18b_0",
    {
      "pip": [
        "absl-py==2.0.0",
        "accelerate==0.24.0.dev0",
        "adal==1.2.7",
        "aiofiles==23.1.0",
        "aiohttp==3.8.4",
        "aiosignal==1.3.1",
        "altair==5.1.2",
        "antlr4-python3-runtime==4.9.3",
        "anyio==3.7.1",
        "apex==0.1",
        "applicationinsights==0.11.10",
        "argcomplete==2.1.2",
        "asttokens==2.4.0",
        "async-timeout==4.0.2",
        "attrs==23.1.0",
        "azure-common==1.1.28",
        "azure-core==1.26.4",
        "azure-graphrbac==0.61.1",
        "azure-identity==1.13.0",
        "azure-mgmt-authorization==3.0.0",
        "azure-mgmt-containerregistry==10.2.0",
        "azure-mgmt-core==1.4.0",
        "azure-mgmt-keyvault==10.3.0",
        "azure-mgmt-resource==22.0.0",
        "azure-mgmt-storage==21.0.0",
        "azure-ml==0.0.1",
        "azure-ml-component==0.9.18.post2",
        "azure-storage-blob==12.13.0",
        "azureml-automl-common-tools==1.51.0",
        "azureml-automl-core==1.51.0.post1",
        "azureml-contrib-services==1.51.0",
        "azureml-core==1.51.0",
        "azureml-dataprep==4.10.9",
        "azureml-dataprep-native==38.0.0",
        "azureml-dataprep-rslex==2.17.12",
        "azureml-dataset-runtime==1.51.0",
        "azureml-defaults==1.51.0",
        "azureml-inference-server-http==0.8.4.1",
        "azureml-mlflow==1.51.0",
        "azureml-pipeline==1.51.0",
        "azureml-pipeline-core==1.51.0",
        "azureml-pipeline-steps==1.51.0",
        "azureml-sdk==1.51.0",
        "azureml-telemetry==1.51.0",
        "azureml-train-automl-client==1.51.0.post1",
        "azureml-train-core==1.51.0",
        "azureml-train-restclients-hyperdrive==1.51.0",
        "backcall==0.2.0",
        "backports-tempfile==1.0",
        "backports-weakref==1.0.post1",
        "bcrypt==4.0.1",
        "bytecode==0.15.1",
        "cachetools==5.3.0",
        "cerberus==1.3.4",
        "certifi==2023.5.7",
        "cffi==1.15.1",
        "charset-normalizer==3.1.0",
        "click==8.1.7",
        "cloudpickle==2.2.1",
        "cmake==3.26.3",
        "coloredlogs==15.0.1",
        "comm==0.1.4",
        "contextlib2==21.6.0",
        "coverage==6.3.1",
        "cryptography==40.0.2",
        "cycler==0.12.1",
        "databricks-cli==0.18.0",
        "datasets==2.14.5",
        "debugpy==1.6.7.post1",
        "decorator==5.1.1",
        "deepspeed==0.9.1",
        "dill==0.3.7",
        "distro==1.8.0",
        "docker==6.1.3",
        "dotnetcore2==3.1.23",
        "einops==0.7.0",
        "entrypoints==0.4",
        "evaluate==0.4.1",
        "exceptiongroup==1.1.3",
        "executing==2.0.0",
        "fairscale==0.4.13",
        "fastapi==0.104.0",
        "ffmpy==0.3.1",
        "filelock==3.12.0",
        "flash-attn==2.3.2",
        "flask==2.2.5",
        "flask-cors==3.0.10",
        "flatbuffers==23.5.9",
        "fonttools==4.43.1",
        "frozenlist==1.3.3",
        "fsspec==2023.5.0",
        "fusepy==3.0.1",
        "gitdb==4.0.11",
        "gitpython==3.1.40",
        "google-api-core==2.11.0",
        "google-auth==2.19.0",
        "google-auth-oauthlib==0.4.6",
        "googleapis-common-protos==1.59.0",
        "gradio==3.23.0",
        "grpcio==1.59.0",
        "gunicorn==20.1.0",
        "h11==0.14.0",
        "h5py==3.8.0",
        "hjson==3.1.0",
        "horovod==0.24.2",
        "httpcore==0.18.0",
        "httpx==0.25.0",
        "huggingface-hub==0.17.3",
        "humanfriendly==10.0",
        "idna==3.4",
        "igraph==0.10.4",
        "importlib-metadata==6.6.0",
        "importlib-resources==6.1.0",
        "inference-schema==1.5.1",
        "inflector==3.1.0",
        "iniconfig==2.0.0",
        "intel-openmp==2021.4.0",
        "ipykernel==6.25.2",
        "ipython==8.12.3",
        "isodate==0.6.1",
        "itsdangerous==2.1.2",
        "jedi==0.19.1",
        "jeepney==0.8.0",
        "jinja2==3.1.2",
        "jmespath==1.0.1",
        "joblib==1.3.2",
        "jsonlines==4.0.0",
        "jsonpickle==3.0.2",
        "jsonschema==4.19.1",
        "jsonschema-specifications==2023.7.1",
        "jupyter-client==8.4.0",
        "jupyter-core==5.4.0",
        "kiwisolver==1.4.5",
        "knack==0.10.1",
        "lightning-utilities==0.8.0",
        "linkify-it-py==2.0.2",
        "lit==16.0.5",
        "lxml==4.9.2",
        "markdown==3.5",
        "markdown-it-py==2.2.0",
        "markdown2==2.4.10",
        "markupsafe==2.1.2",
        "matplotlib==3.5.3",
        "matplotlib-inline==0.1.6",
        "mdit-py-plugins==0.3.3",
        "mdurl==0.1.2",
        "mkl==2021.4.0",
        "mkl-include==2021.4.0",
        "mlflow-skinny==2.7.1",
        "mpi4py==3.1.1",
        "mpmath==1.3.0",
        "msal==1.22.0",
        "msal-extensions==1.0.0",
        "msccl==2.3.0",
        "msrest==0.7.1",
        "msrestazure==0.6.4",
        "multidict==6.0.4",
        "multiprocess==0.70.15",
        "ndg-httpsclient==0.5.1",
        "nebulaml==0.16.2",
        "nest-asyncio==1.5.6",
        "networkx==3.1",
        "ninja==1.10.2",
        "nltk==3.8.1",
        "numpy==1.22.2",
        "oauthlib==3.2.2",
        "omegaconf==2.3.0",
        "onnx==1.14.0",
        "onnxruntime-gpu==1.16.1",
        "onnxruntime-training==1.14.1",
        "opencensus==0.11.2",
        "opencensus-context==0.1.3",
        "opencensus-ext-azure==1.1.9",
        "opencensus-ext-logging==0.1.1",
        "orjson==3.9.9",
        "packaging==23.0",
        "pandas==2.0.3",
        "paramiko==3.3.1",
        "parso==0.8.3",
        "pathspec==0.11.2",
        "pexpect==4.8.0",
        "pickleshare==0.7.5",
        "pillow==9.5.0",
        "pkginfo==1.9.6",
        "pkgutil-resolve-name==1.3.10",
        "platformdirs==3.11.0",
        "pluggy==1.0.0",
        "portalocker==2.7.0",
        "prompt-toolkit==3.0.39",
        "protobuf==3.20.3",
        "psutil==5.8.0",
        "ptyprocess==0.7.0",
        "pure-eval==0.2.2",
        "py==1.11.0",
        "py-cpuinfo==5.0.0",
        "py-spy==0.3.12",
        "pyarrow==9.0.0",
        "pyasn1==0.5.0",
        "pyasn1-modules==0.3.0",
        "pybind11==2.11.1",
        "pycparser==2.21",
        "pydantic==1.10.8",
        "pydash==7.0.6",
        "pydub==0.25.1",
        "pygments==2.16.1",
        "pyjwt==2.7.0",
        "pynacl==1.5.0",
        "pyopenssl==23.2.0",
        "pyparsing==3.1.1",
        "pysocks==1.7.1",
        "pytest==7.1.0",
        "pytest-mpi==0.6",
        "python-dateutil==2.8.2",
        "python-multipart==0.0.6",
        "pytorch-lightning==1.9.3",
        "pytz==2023.3.post1",
        "pyyaml==6.0",
        "pyzmq==25.1.1",
        "referencing==0.30.2",
        "regex==2023.10.3",
        "requests==2.31.0",
        "requests-oauthlib==1.3.1",
        "responses==0.18.0",
        "rouge-score==0.1.2",
        "rpds-py==0.10.6",
        "rsa==4.9",
        "ruamel-yaml==0.17.16",
        "ruamel-yaml-clib==0.2.8",
        "safetensors==0.4.0",
        "scipy==1.7.3",
        "secretstorage==3.3.3",
        "semantic-version==2.10.0",
        "sentencepiece==0.1.99",
        "setuptools==67.6.0",
        "six==1.16.0",
        "smmap==5.0.1",
        "sniffio==1.3.0",
        "sqlparse==0.4.4",
        "stack-data==0.6.3",
        "starlette==0.27.0",
        "supervisor==4.2.5",
        "svgwrite==1.4.3",
        "sympy==1.12",
        "tabulate==0.9.0",
        "tbb==2021.9.0",
        "tensorboard==2.11.2",
        "tensorboard-data-server==0.6.1",
        "tensorboard-plugin-wit==1.8.1",
        "texttable==1.6.7",
        "timm==0.9.7",
        "tokenizers==0.14.1",
        "toml==0.10.2",
        "tomli==2.0.1",
        "toolz==0.12.0",
        "torch==2.0.1+cu117",
        "torch-nebula==0.16.2",
        "torch-ort==1.14.0",
        "torch-tb-profiler==0.4.3",
        "torchaudio==2.0.2+cu117",
        "torchmetrics==0.11.3",
        "torchsnapshot==0.1.0",
        "torchvision==0.15.2+cu117",
        "tornado==6.3.3",
        "tqdm==4.62.3",
        "traitlets==5.11.2",
        "transformers==4.35.0.dev0",
        "triton==2.0.0",
        "tutel==0.1",
        "typing-extensions==4.8.0",
        "tzdata==2023.3",
        "uc-micro-py==1.0.2",
        "urllib3==1.26.16",
        "uvicorn==0.23.2",
        "wavedrom==2.0.3.post3",
        "wcwidth==0.2.8",
        "websocket-client==1.6.4",
        "websockets==11.0.3",
        "werkzeug==3.0.0",
        "wheel==0.40.0",
        "wrapt==1.12.1",
        "xxhash==3.4.1",
        "yarl==1.9.2",
        "z3-solver==4.12.2.0",
        "zipp==3.15.0"
      ]
    }
  ],
  "name": "ptca",
  "prefix": "/opt/conda/envs/ptca"
}

Who can help?

@muellerz @pacman100


Reproduction

    # tokenizer, training_args, and data_module are defined elsewhere in the script
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        torch_dtype=torch.bfloat16,
        use_flash_attention_2=True,
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        compute_metrics=None,
        **data_module,
    )

    trainer.train()

The training job works on a single node with 8 A100 GPUs. It fails with the following error when the job uses more than one node:

File "./trainer.py", line 206, in <module>
    train()
  File "./trainer.py", line 157, in train
    model = transformers.AutoModelForCausalLM.from_pretrained(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py", line 565, in from_pretrained
    return model_class.from_pretrained(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3333, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3723, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py", line 744, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/accelerate/utils/modeling.py", line 317, in set_module_tensor_to_device
    new_value = value.to(device)
NotImplementedError: Cannot copy out of meta tensor; no data!
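For context, the error comes from PyTorch's "meta" device: a meta tensor records shape and dtype but has no backing storage, so it cannot be copied to a real device. A minimal sketch of the failure, independent of transformers:

```python
import torch

# A meta tensor carries shape/dtype metadata but no actual data.
t = torch.empty(2, 2, device="meta")
print(t.shape)  # metadata is still available

try:
    t.to("cpu")  # materializing requires data the tensor does not have
except NotImplementedError as err:
    print(f"caught: {err}")
```

In the traceback above, `_load_state_dict_into_meta_model` hits exactly this path: on some ranks the weights were placed on the meta device, so the subsequent `value.to(device)` has no data to copy.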

Expected behavior

No error

ari9dam commented 1 year ago

Relevant: https://github.com/huggingface/transformers/pull/26631 @pacman100

ari9dam commented 1 year ago
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
  fsdp_use_orig_params: true
main_training_function: main
mixed_precision: bf16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
pacman100 commented 1 year ago

Hello @ari9dam,

The PR you tagged above should resolve this issue. Please recreate the FSDP config via the accelerate config command and answer False when asked about RAM-efficient loading of the pretrained model.
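For reference, answering False should produce a line like the following in the regenerated YAML (the exact key name, fsdp_cpu_ram_efficient_loading, is an assumption based on recent accelerate versions):

```yaml
fsdp_config:
  # Disables loading the checkpoint onto the meta device on non-zero ranks
  fsdp_cpu_ram_efficient_loading: false
```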

ari9dam commented 1 year ago

Thank you, that solved it. I have one more question @pacman100:

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        use_flash_attention_2=True,
    )

Should I pass torch_dtype here while loading the model? I'm using bf16 in the accelerate config. I get these warnings:

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour.
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').
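A hedged sketch of what passing the dtype explicitly might look like (the model name is illustrative, and the imports are deferred into the function so the sketch stands alone):

```python
def load_model(model_name: str = "mistralai/Mistral-7B-v0.1"):
    # Imports deferred so the sketch can be defined without the heavy deps.
    import torch
    import transformers

    # An explicit torch_dtype addresses the first Flash Attention 2.0 warning;
    # bfloat16 matches mixed_precision: bf16 in the accelerate config above.
    return transformers.AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_flash_attention_2=True,
    )
```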

Muennighoff commented 10 months ago

I also had this issue and fixed it by changing

        if (
            is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0
        ) or (is_fsdp_enabled() and not is_local_dist_rank_0()):
            map_location = "meta"

to

        if (
            (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
            and torch.distributed.is_initialized()
            and (torch.distributed.get_rank() % 8 != 0)
        ):
            map_location = "meta"

here: https://github.com/huggingface/transformers/blob/29e7a1e1834f331a4916853ecd58549ed78235d6/src/transformers/modeling_utils.py#L512 (this assumes 8 GPUs per node; with 4 GPUs per node the modulus should be 4, and so on)
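The hardcoded per-node GPU count could likely be avoided by keying off the local rank instead; torchrun and accelerate launch export it as the LOCAL_RANK environment variable. A sketch (not the actual transformers implementation, which uses its own is_local_dist_rank_0 helper):

```python
import os

def is_local_rank_zero() -> bool:
    # torchrun / accelerate launch set LOCAL_RANK per worker on each node;
    # default to 0 so single-process runs behave like the loading rank.
    return int(os.environ.get("LOCAL_RANK", "0")) == 0
```

With this check, only the first process on each node loads real weights and the remaining processes map to "meta", regardless of how many GPUs each node has.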