UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Error in Fully Sharded Data Parallelism (FSDP) set up #2931

Open MohammedAlhajji opened 1 month ago

MohammedAlhajji commented 1 month ago

I'm trying to finetune BAAI/bge-m3, a model whose max sequence length is 8k, on a retrieval task. Here's my trainer setup:

    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainingArguments,
    )
    from sentence_transformers.losses import MultipleNegativesRankingLoss
    from sentence_transformers.training_args import BatchSamplers

    model = SentenceTransformer(model_id, device="cuda")
    loss = MultipleNegativesRankingLoss(model)

    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=output_path,
        fsdp=["full_shard", "auto_wrap"],
        fsdp_config={"transformer_layer_cls_to_wrap": "BertLayer"},
        # Optional training parameters:
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=lr,
        warmup_ratio=warmup_ratio,
        weight_decay=weight_decay,
        fp16=False,
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicates
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        logging_steps=200,
        disable_tqdm=False,
    )
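
For completeness, here's a rough sketch of how the trainer itself is then constructed (the snippet above stops at the arguments; `train_dataset` and `eval_dataset` are placeholders for the retrieval data prepared elsewhere in train.py):

    from sentence_transformers import SentenceTransformerTrainer

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,  # placeholder: (query, positive) retrieval pairs
        eval_dataset=eval_dataset,    # placeholder: held-out evaluation split
        loss=loss,
    )
    trainer.train()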

I run it via `torchrun --nproc_per_node=8 train.py`.

When I run it, though, I get the following error:

Traceback (most recent call last):
  File "/eph/nvme0/azureml/cr/j/b62bd9e283c3429ea92836088ca33b14/exe/wd/train.py", line 291, in <module>
    main(
  File "/eph/nvme0/azureml/cr/j/b62bd9e283c3429ea92836088ca33b14/exe/wd/train.py", line 213, in main
    trainer.train()
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 2085, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/accelerator.py", line 1326, in prepare
    result = tuple(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/accelerator.py", line 1327, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/accelerator.py", line 1200, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/accelerator.py", line 1468, in prepare_model
    self.state.fsdp_plugin.set_auto_wrap_policy(model)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/utils/dataclasses.py", line 1554, in set_auto_wrap_policy
    raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.")
ValueError: Could not find the transformer layer class BertLayer in the model.

Any thoughts on what's wrong in my code?

MohammedAlhajji commented 1 month ago

If I do a `full_shard`, I get the following error:

Traceback (most recent call last):
  File "/eph/nvme0/azureml/cr/j/e783517ada544c6e901713a9aef8f300/exe/wd/train.py", line 291, in <module>
    main(
  File "/eph/nvme0/azureml/cr/j/e783517ada544c6e901713a9aef8f300/exe/wd/train.py", line 213, in main
    trainer.train()
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 2085, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/accelerator.py", line 1326, in prepare
    result = tuple(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/accelerator.py", line 1327, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/accelerator.py", line 1200, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/accelerate/accelerator.py", line 1484, in prepare_model
    model = FSDP(model, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 503, in __init__
    _init_param_handle_from_module(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 574, in _init_param_handle_from_module
    state.compute_device = _get_compute_device(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 1037, in _get_compute_device
    raise ValueError(
ValueError: Inconsistent compute device and `device_id` on rank 3: cuda:0 vs cuda:3

tomaarsen commented 1 month ago

Hello!

As for the

    ValueError: Could not find the transformer layer class BertLayer in the model.

error: this is because FSDP looks for which layers to shard across devices, and there is no BertLayer in BGE-M3, since it is based on XLM-RoBERTa instead. I suspect you have to use XLMRobertaLayer, so:

        fsdp_config={"transformer_layer_cls_to_wrap": "XLMRobertaLayer"},

Let me know if that gets you a bit further. FSDP hasn't been fully tested here because DDP is faster with most models: the primary use case for FSDP is when the model itself is so big that sharding it across devices lets you reach a much higher batch size. At least, that is my understanding.
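
If you're unsure which layer class name to pass, here's a quick sketch (not from the original report) that simply lists the module class names inside the wrapped Hugging Face model:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("BAAI/bge-m3")
    # model[0] is the Transformer wrapper; auto_model is the underlying Hugging Face model.
    layer_classes = {module.__class__.__name__ for module in model[0].auto_model.modules()}
    print("XLMRobertaLayer" in layer_classes)  # expected: True for BGE-M3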

MohammedAlhajji commented 1 month ago

I just noticed that a couple of hours ago and made the fix, which resolved the layer issue. I should have been more careful.

I didn't want to use FSDP, but for some reason I kept getting CUDA out-of-memory errors while training. My model is only 500M parameters and my dataset is a retrieval dataset with some long articles. My max_seq_length is 8192. If I truncate the articles to 300 words, I edge close to my memory limit of 80 GB (A100); when I increase that to 1000 words, I get the CUDA out-of-memory error. My suspicion is that this happens because attention scales roughly quadratically with the sequence length, which is why I thought FSDP was what I should do. Is there anything else I could do other than FSDP in my case?
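
One non-FSDP lever, as a minimal sketch (assuming only the standard SentenceTransformer API, nothing specific to this setup): capping max_seq_length bounds the quadratic attention cost by truncating long articles at tokenization time.

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("BAAI/bge-m3")
    print(model.max_seq_length)  # 8192 for BGE-M3
    model.max_seq_length = 1024  # inputs longer than this are truncated when tokenized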

Anyway, I tried to fix the inconsistent compute device issue with the following, but ran into another error:

    import os
    import torch

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = model.to(f"cuda:{local_rank}")

and ran it with `torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 train.py`.

However, I then received this error:

[rank0]:   File "/eph/nvme0/azureml/cr/j/ac16e93f7de44de883a5279b67300a74/exe/wd/train.py", line 322, in <module>
[rank0]:     main(
[rank0]:   File "/eph/nvme0/azureml/cr/j/ac16e93f7de44de883a5279b67300a74/exe/wd/train.py", line 247, in main
[rank0]:     trainer.train()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 1938, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2279, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3318, in training_step
[rank0]:     loss = self.compute_loss(model, inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/sentence_transformers/trainer.py", line 344, in compute_loss
[rank0]:     loss = loss_fn(features, labels)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/sentence_transformers/losses/MultipleNegativesRankingLoss.py", line 99, in forward
[rank0]:     reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/sentence_transformers/losses/MultipleNegativesRankingLoss.py", line 99, in <listcomp>
[rank0]:     reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/operations.py", line 820, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/operations.py", line 808, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/sentence_transformers/SentenceTransformer.py", line 668, in forward
[rank0]:     input = module(input, **module_kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/sentence_transformers/models/Transformer.py", line 118, in forward
[rank0]:     output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py", line 827, in forward
[rank0]:     embedding_output = self.embeddings(
[rank0]:                        ^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py", line 116, in forward
[rank0]:     inputs_embeds = self.word_embeddings(input_ids)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 164, in forward
[rank0]:     return F.embedding(
[rank0]:            ^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/functional.py", line 2267, in embedding
[rank0]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: 'weight' must be 2-D

MohammedAlhajji commented 1 month ago

Just an update: I managed to fix my CUDA out-of-memory issue. Here's the fix.

In the SentenceTransformer documentation:

attn_implementation: The attention implementation to use in the model (if relevant). Can be any of “eager” (manual implementation of the attention), “sdpa” (using F.scaled_dot_product_attention), or “flash_attention_2” (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual “eager” implementation.

This is true for most models, but not for XLMRobertaModel. For XLMRobertaModel, a PR adding SDPA support was recently merged in the transformers library but has not been included in a release yet, so you need to install transformers from the GitHub repo. That reduces the memory requirements heavily.
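
For reference, a minimal sketch of how an attention implementation can be requested through SentenceTransformer once a transformers build that supports it for XLM-RoBERTa is installed (model_kwargs is forwarded to the underlying Hugging Face model):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "BAAI/bge-m3",
        model_kwargs={"attn_implementation": "sdpa"},  # or "flash_attention_2" with flash-attn installed
    )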

We can close this issue for now, or I can test different configurations of FSDP just for the sake of getting it right, if needed.