pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/

torch.distributed.elastic.multiprocessing.errors.ChildFailedError #1710

Open Vattikondadheeraj opened 4 hours ago

Vattikondadheeraj commented 4 hours ago

Context: I am trying to run distributed training on two A100 GPUs with 40 GB of VRAM each. The batch size is 3 and gradient_accumulation_steps is 1. I have attached the config file and the error output below for more details. The thing is, I am not able to pinpoint the problem because the error message itself is unclear. Is this a CUDA memory issue, or am I missing something else?

INFO:torchtune.utils._logging:Running FullFinetuneRecipeDistributed with resolved config:

batch_size: 3
checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /home/toolkit/scratch/LLMcode/Train/llama-3.1-instruct/original
  checkpoint_files:
  - consolidated.00.pth
  model_type: LLAMA3
  output_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
  recipe_checkpoint: null
custom_sharded_layers:
- tok_embeddings
- output
dataset:
  _component_: torchtune.datasets.alpaca_dataset
device: cuda
dtype: bf16
enable_activation_checkpointing: true
epochs: 10
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 2.0e-05
output_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /home/toolkit/scratch/LLMcode/Train/llama-3.1-instruct/original/tokenizer.model

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 4088951359. Local seed is seed + rank = 4088951359 + 0
Writing logs to /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1/log_1727576992.txt
INFO:torchtune.utils._logging:FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint took 17.02 secs
INFO:torchtune.utils._logging:Memory stats after model init:
        GPU peak memory allocation: 8.52 GiB
        GPU peak memory reserved: 8.67 GiB
        GPU peak memory active: 8.52 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
441
aux1
aux2
  0%|                                                                                                                                                                                           | 0/70 [00:00<?, ?it/s]765
aux1
aux1
/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
W0929 02:30:29.754000 139982797317952 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 2171 closing signal SIGTERM
E0929 02:30:36.777000 139982797317952 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: -9) local_rank: 1 (pid: 2172) of binary: /home/toolkit/.conda/envs/torch/bin/python
Traceback (most recent call last):
  File "/home/toolkit/.conda/envs/torch/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune/torchtune/_cli/run.py", line 183, in _run_cmd
    self._run_distributed(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune/torchtune/_cli/run.py", line 89, in _run_distributed
    run(args)
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/home/toolkit/scratch/LLMcode/Train/torchtune/recipes/full_finetune_distributed.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-09-29_02:30:29
  host      : 92b61c39-4dd9-4911-be6f-522a27802a4a
  rank      : 1 (local_rank: 1)
  exitcode  : -9 (pid: 2172)
  error_file: <N/A>
  traceback : Signal 9 (SIGKILL) received by PID 2172
============================================================

Config file:

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /home/toolkit/scratch/LLMcode/Train/llama-3.1-instruct/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /home/toolkit/scratch/LLMcode/Train/llama-3.1-instruct/original
  checkpoint_files: [
    consolidated.00.pth,
  ]
  recipe_checkpoint: null
  output_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 3
epochs: 10

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
custom_sharded_layers: ['tok_embeddings', 'output']

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
output_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
log_every_n_steps: 1
log_peak_memory_stats: True
Vattikondadheeraj commented 4 hours ago

@SagarChandra07 You are asking me to download some random file without any context. I really don't understand your intent here. I request the moderators to delete the comment.

ebsmothers commented 3 hours ago

@Vattikondadheeraj yeah, that error message is not too informative; even with distributed training I am used to seeing more detail. I guess you are getting a SIGKILL signal, but I think that can be caused by multiple things. As a debugging step, I would suggest checking whether this succeeds if you run on a single device (you can run with batch size 1 to hopefully avoid any OOMs). You can also log the memory stats or insert calls to torch.distributed.breakpoint() into your code to better pinpoint where the error is occurring.
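
For the single-device check, something like "tune run full_finetune_single_device --config llama3_1/8B_full_single_device batch_size=1" should work (check tune ls for the exact config name in your install). And here is a rough sketch of the kind of hooks I mean; the optimizer variable and where you place this inside the training loop are illustrative, not the recipe's exact code:

import torch
import torch.distributed as dist

def log_peak_mem(tag: str) -> None:
    # Peak GPU memory since process start (or since the last reset_peak_memory_stats()).
    alloc = torch.cuda.max_memory_allocated() / 1024**3
    reserved = torch.cuda.max_memory_reserved() / 1024**3
    print(f"[rank {dist.get_rank()}] {tag}: peak alloc {alloc:.2f} GiB, peak reserved {reserved:.2f} GiB")

# Inside the training loop, around the region you suspect:
log_peak_mem("before optimizer.step")
# dist.breakpoint()  # pauses all ranks and drops rank 0 into pdb; type 'c' to continue
optimizer.step()
log_peak_mem("after optimizer.step")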

Btw sorry about the spam comment, I've deleted it.

Vattikondadheeraj commented 3 hours ago

@ebsmothers, I tried a few experiments as you suggested. The first one was adding the logging environment variables shown below:

# os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'
# os.environ['CUDA_LAUNCH_BLOCKING']="1"
# os.environ['NCCL_DEBUG']="INFO"
# os.environ['TORCH_SHOW_CPP_STACKTRACES']="1"
# os.environ['TORCH_CPP_LOG_LEVEL']="INFO"

After adding them, it ran for one epoch and gave the above error.

The second experiment was to add torch.distributed.breakpoint() into my pipeline, and with it everything worked fine; it ran for 4 batches without any problem. Do you think the problem is bad synchronization between the 2 GPUs caused by data-loading delays, or something else? Previously it ran for only one batch, but after adding torch.distributed.breakpoint() and stepping through manually, it ran well.

When I added different print statements, I found out that it is failing at the optimizer.step() call.
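
For reference, here is a rough way to watch host RAM as well as GPU memory around that call (the optimizer variable is illustrative). Exit code -9 means the process received SIGKILL, which is often the kernel OOM killer reacting to host-RAM pressure rather than a CUDA OOM:

import resource
import torch
import torch.distributed as dist

def host_rss_gib() -> float:
    # ru_maxrss is reported in kilobytes on Linux.
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024**2

rank = dist.get_rank()
print(f"[rank {rank}] host RSS before step: {host_rss_gib():.2f} GiB, "
      f"GPU allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")
optimizer.step()
print(f"[rank {rank}] host RSS after step: {host_rss_gib():.2f} GiB")

If the kernel OOM killer is involved, it also leaves a "Killed process" entry in dmesg.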

Any tips for this error?

Vattikondadheeraj commented 2 hours ago

Hey @ebsmothers, I have pinpointed the error in more detail. I have attached the _fused_adamw function from torch's AdamW implementation (with my debug prints) along with the log file below; in the log, both ranks make it through the "dhedhe-66666666666" print, i.e. right up to the torch._fused_adamw_ call, before the crash.

# Copy of _fused_adamw from torch/optim/adamw.py, with debug prints added to trace how far the step gets:
def _fused_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
    has_complex: bool,  # Needed for consistency.
) -> None:
    if not params:
        print("problem-1111111111")
        return
    if differentiable:
        print("problem-22222222222")
        raise RuntimeError("Adam with fused=True does not support differentiable=True")

    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )

    print("dhedhe-11111111111111")
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
    # treating it as a scalar.
    lr_dict: Optional[DeviceDict] = (
        {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
    )

    print("dhedhe-22222222222")

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]
    )
    for (device, _), (
        (
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
        ),
        _,
    ) in grouped_tensors.items():
        device_grad_scale, device_found_inf = None, None
        print("dhedhe-3333333333333")
        if grad_scale is not None:
            device_grad_scale = grad_scale_dict.setdefault(
                device, grad_scale.to(device, non_blocking=True)
            )
        print("dhedhe-44444444444")
        if found_inf is not None:
            device_found_inf = found_inf_dict.setdefault(
                device, found_inf.to(device, non_blocking=True)
            )
        print("dhedhe-55555555555")
        if lr_dict is not None and device not in lr_dict:
            lr = lr_dict.setdefault(
                device, lr.to(device=device, non_blocking=True)  # type: ignore[union-attr]
            )
        print("dhedhe-66666666666")
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adamw_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )

The log file is:

batch_size: 3
checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /home/toolkit/scratch/LLMcode/Train/llama-3.1-instruct/original
  checkpoint_files:
  - consolidated.00.pth
  model_type: LLAMA3
  output_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
  recipe_checkpoint: null
custom_sharded_layers:
- tok_embeddings
- output
dataset:
  _component_: torchtune.datasets.alpaca_dataset
device: cuda
dtype: bf16
enable_activation_checkpointing: true
epochs: 10
gradient_accumulation_steps: 3
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 2.0e-05
output_dir: /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /home/toolkit/scratch/LLMcode/Train/llama-3.1-instruct/original/tokenizer.model

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 457363046. Local seed is seed + rank = 457363046 + 0
Writing logs to /home/toolkit/scratch/LLMcode/Checkpoints/full_finetuning_results-1/log_1727584924.txt
INFO:torchtune.utils._logging:FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint took 4.41 secs
INFO:torchtune.utils._logging:Memory stats after model init:
        GPU peak memory allocation: 8.52 GiB
        GPU peak memory reserved: 8.67 GiB
        GPU peak memory active: 8.52 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}

!!! ATTENTION !!!

Type 'up' to get to the frame that called dist.breakpoint(rank=0)

> /home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/distributed/__init__.py(94)breakpoint()
-> meta_in_tls = torch._C._meta_in_tls_dispatch_include()
(Pdb) c
  0%|                                                                                                                                                                                           | 0/23 [00:00<?, ?it/s]463
465
aux1
aux1
aux2
aux1
hihihihihihihhihihihi----1111111
/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
hihihihihihihhihihihi----1111111
/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
hihihihihihihhihihihi----222222222
hihihihihihihhihihihi----222222222
464
104
aux1
aux1
aux2
aux2
hihihihihihihhihihihi----1111111
hihihihihihihhihihihi----1111111
hihihihihihihhihihihi----222222222
hihihihihihihhihihihi----222222222
226
222
aux2
aux1
aux1
aux1
hihihihihihihhihihihi----1111111
hihihihihihihhihihihi----1111111
hihihihihihihhihihihi----222222222
hihihihihihihhihihihi----3333333
hihihihihihihhihihihi----222222222
hihihihihihihhihihihi----3333333
byebyebye-1111111111
byebyebye-1111111111
byebyebye-222222222222
midmidmidmdimdidmdimdimdidmdi <function _fused_adamw at 0x7f5d9bd94180>
dhedhe-11111111111111
dhedhe-22222222222
dhedhe-3333333333333
dhedhe-44444444444
dhedhe-55555555555
dhedhe-66666666666
byebyebye-222222222222
midmidmidmdimdidmdimdimdidmdi <function _fused_adamw at 0x7f8798f98180>
dhedhe-11111111111111
dhedhe-22222222222
dhedhe-3333333333333
dhedhe-44444444444
dhedhe-55555555555
dhedhe-66666666666
W0929 04:42:35.154000 140506051815232 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 11643 closing signal SIGTERM
E0929 04:42:42.132000 140506051815232 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: -9) local_rank: 1 (pid: 11644) of binary: /home/toolkit/.conda/envs/torch/bin/python
Traceback (most recent call last):
  File "/home/toolkit/.conda/envs/torch/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune/torchtune/_cli/run.py", line 183, in _run_cmd
    self._run_distributed(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune/torchtune/_cli/run.py", line 89, in _run_distributed
    run(args)
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/home/toolkit/scratch/LLMcode/Train/torchtune/recipes/full_finetune_distributed.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-09-29_04:42:35
  host      : 92b61c39-4dd9-4911-be6f-522a27802a4a
  rank      : 1 (local_rank: 1)
  exitcode  : -9 (pid: 11644)
  error_file: <N/A>
  traceback : Signal 9 (SIGKILL) received by PID 11644
============================================================