ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Train] Ray distributed training hangs in `loss.backward()` call #43753

Closed: arunppsg closed this issue 6 months ago

arunppsg commented 6 months ago

What happened + What you expected to happen

Training on a single node with a single GPU works, but when I scale the training to multiple nodes with multiple GPUs, it hangs (probably in the loss.backward() call).

Log output I observed:

(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2487 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to ^lo,docker,veth
(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2487 [0] NCCL INFO Bootstrap : Using ens3:172.31.1.71<0>
(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2487 [0] NCCL INFO NET/Plugin: Failed to find ncclNetPlugin_v7 symbol.
(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2487 [0] NCCL INFO NET/Plugin: Loaded net plugin AWS Libfabric (v6)
(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2487 [0] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v7 symbol.
(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2487 [0] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin symbol (>= v5). ncclCollNetPlugin symbols v4 and lower are not supported.
(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2487 [0] NCCL INFO cudaDriverVersion 12000
(RayTrainWorker pid=2285, ip=172.31.1.71) NCCL version 2.19.3+cuda12.3
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/OFI Initializing aws-ofi-nccl 1.7.4-aws
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/OFI Using CUDA runtime version 12010
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/OFI Configuring AWS-specific options
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/OFI Setting provider_filter to efa
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/OFI Setting FI_EFA_FORK_SAFE environment variable to 1
(RayTrainWorker pid=2286, ip=172.31.1.71)
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] configure_nvls_option:293 NCCL WARN NET/OFI Could not find ncclGetVersion symbol
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/OFI Disabling NVLS support due to NCCL version 0
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/OFI Internode latency set at 150.0 us
(RayTrainWorker pid=2286, ip=172.31.1.71)
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] nccl_net_ofi_init:1239 NCCL WARN NET/OFI aws-ofi-nccl initialization failed
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO net.cc:54 -> 2
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/IB : No device found.
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NET/Socket : Using [0]ens3:172.31.1.71<0>
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO Using non-device net plugin version 0
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO Using network Socket
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO comm 0x7f4c60d183c0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 1e0 commId 0xfac3265df608750c - Init START
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO NCCL_NVLS_ENABLE set by environment to 0.
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO P2P Chunksize set to 131072
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO Channel 00 : 1[1] -> 0[0] via SHM/direct/direct
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO Channel 01 : 1[1] -> 0[0] via SHM/direct/direct
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO Connected all rings
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO Connected all trees
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
(RayTrainWorker pid=2286, ip=172.31.1.71) ip-172-31-1-71:2286:2533 [1] NCCL INFO comm 0x7f4c60d183c0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 1e0 commId 0xfac3265df608750c - Init COMPLETE
(RayTrainWorker pid=2286, ip=172.31.1.71) [rank1]:[W Utils.hpp:106] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarString)
(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2532 [0] NCCL INFO Channel 00/02 :    0   1
(RayTrainWorker pid=2285, ip=172.31.1.71) ip-172-31-1-71:2285:2532 [0] NCCL INFO Channel 01/02 :    0   1

(pid=2430, ip=172.31.1.71) Running 0:   0%|          | 0/200 [00:00<?, ?it/s]
(pid=2430, ip=172.31.1.71) Running: 0.0/129.0 CPU, 0.0/6.0 GPU, 0.0 MiB/72.94 GiB object_store_memory:   0%|          | 0/200 [00:00<?, ?it/s]

...

(pid=2430, ip=172.31.1.71) Running: 0.0/129.0 CPU, 0.0/6.0 GPU, 0.05 MiB/72.94 GiB object_store_memory:  44%|████▎     | 87/200 [10:27<00:00, 143.66it/s]
2024-03-06 16:41:23,783    ERROR tune_controller.py:1374 -- Trial task failed for trial TorchTrainer_4b358_00000
Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.8/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/ray/_private/worker.py", line 2624, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(DistBackendError): ray::_Inner.train() (pid=2214, ip=172.31.1.71, actor_id=5ee8ecf38be7688e2039c8b406000000, repr=TorchTrainer)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/ray/tune/trainable/trainable.py", line 342, in train
    raise skipped from exception_cause(skipped)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/ray/train/_internal/utils.py", line 43, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(DistBackendError): ray::_RayTrainWorker__execute.get_next() (pid=2286, ip=172.31.1.71, actor_id=141493b6121bd3432184cebc06000000, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7f4d85c015e0>)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/ray/train/_internal/worker_group.py", line 33, in __execute
    raise skipped from exception_cause(skipped)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/ray/train/_internal/utils.py", line 118, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "train2.py", line 50, in train_loop_per_worker
    loss.backward()
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.distributed.DistBackendError: NCCL communicator was aborted on rank 1.

Versions / Dependencies

Ray: 2.9.2
OS: Ubuntu 20.04
PyTorch: 2.2.0+cu121

Reproduction script

import torch
from transformers import RobertaForMaskedLM, RobertaConfig
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
from transformers.data.data_collator import DataCollatorForLanguageModeling
import ray
from ray import train
import ray.train.torch
from ray.train.torch import TorchTrainer, TorchConfig
from ray.train import ScalingConfig
from ray.runtime_env import RuntimeEnv

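# Verbose NCCL and torch.distributed logging on every worker
# (the torch variable is TORCH_DISTRIBUTED_DEBUG, with levels OFF/INFO/DETAIL)
runtime_env = RuntimeEnv(env_vars={"NCCL_DEBUG": "INFO", "TORCH_DISTRIBUTED_DEBUG": "DETAIL"})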
ray.init(runtime_env=runtime_env)

def train_loop_per_worker(config):
    device = ray.train.torch.get_device()

    tokenizer_path = 'seyonec/PubChem10M_SMILES_BPE_60k'
    tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer_path)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
    batch_size = 32 

    model = RobertaForMaskedLM(RobertaConfig(vocab_size=tokenizer.vocab_size))
    model = ray.train.torch.prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    train_data_shard = train.get_dataset_shard("train")
    train_dataloader = train_data_shard.iter_batches(batch_size=batch_size)

    for epoch in range(0, config["num_epochs"]):
        for i, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            inputs = batch['smiles']
            tokens = tokenizer(inputs.tolist(),
                               padding=True,
                               return_tensors="pt")
            inputs, labels = data_collator.torch_mask_tokens(
                tokens['input_ids'])
            inputs = {
                'input_ids': inputs.to(device),
                'labels': labels.to(device),
                'attention_mask': tokens['attention_mask'].to(device),
            }
            outputs = model(**inputs)
            loss = outputs.get("loss")
            loss.backward()
            optimizer.step()

            if i % 20 == 0:
                ray.train.report(metrics={"loss": loss.detach().cpu().item()})

if __name__ == '__main__':
    use_gpu = True
    num_workers = 2

    train_dataset = ray.data.from_items([{'smiles': 'Cc1[nH]ccc1C(=O)N[C@@H](C)CNC(=O)c1cccc2[nH]ccc21'}] * 1000)
    train_loop_config = {"num_epochs": 5}

    torch_config = TorchConfig(backend='nccl', timeout_s=600)
    trainer = TorchTrainer(train_loop_per_worker=train_loop_per_worker,
                           train_loop_config=train_loop_config,
                           datasets={"train": train_dataset},
                           scaling_config=ScalingConfig(
                               num_workers=num_workers, use_gpu=use_gpu,
                               resources_per_worker = {"GPU": 1}),
                           torch_config=torch_config)
    result = trainer.fit()

What puzzles me is that other, similar training scripts work in the same environment.

Issue Severity

High: It blocks me from completing my task.

arunppsg commented 6 months ago

My code had the following lines in train_loop_per_worker to report metrics:

if i % 20 == 0:
    ray.train.report(metrics={"loss": loss.detach().cpu().item()})

But, as the docs say, train.report() has to be called on every worker. If I am not wrong, the condition prevents that: train.report() only runs on a worker when i % 20 == 0. Removing the condition fixes the issue.
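To illustrate the rule that every worker must reach train.report(), here is a toy model in plain Python. It is a sketch, not Ray's implementation: it assumes train.report() behaves like a barrier that all workers must reach before any of them continues, and it stands in for it with threading.Barrier. The names run_workers and should_report are hypothetical. Each thread plays one worker looping over its batches; a worker left waiting at the barrier past the timeout plays the role of the hung rank.

```python
import threading


def run_workers(batches_per_worker, should_report, timeout=1.0):
    """Toy model of workers that call a report() barrier inside their loop.

    Assumption (not Ray's actual code): ray.train.report() is modeled as a
    barrier that every worker must reach before any of them moves on.
    Returns True if all workers finish, False if any worker is left waiting.
    """
    report_barrier = threading.Barrier(len(batches_per_worker))
    finished = [False] * len(batches_per_worker)

    def worker(rank, num_batches):
        try:
            for i in range(num_batches):
                # forward pass, loss.backward() and optimizer.step() would go here
                if should_report(i):
                    # stand-in for ray.train.report(...)
                    report_barrier.wait(timeout=timeout)
            finished[rank] = True
        except threading.BrokenBarrierError:
            # A peer never arrived within the timeout: the toy version of the hang.
            pass

    threads = [
        threading.Thread(target=worker, args=(rank, n))
        for rank, n in enumerate(batches_per_worker)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return all(finished)


if __name__ == "__main__":
    # Same number of report calls on both workers: both finish.
    print(run_workers([16, 16], lambda i: True))
    # Worker 0 reports at i=0 and i=20, worker 1 only at i=0: worker 0 times out.
    print(run_workers([21, 20], lambda i: i % 20 == 0))
```

In this model the loop completes whenever every worker makes the same number of report calls; as soon as one worker makes an extra call, it waits for a peer that never arrives until the timeout breaks the barrier, loosely mirroring how one rank sat waiting until the NCCL communicator was aborted.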