huggingface / optimum-neuron

Easy, fast and very cheap training and inference on AWS Trainium and Inferentia chips.
Apache License 2.0

Unable to resume training after saving checkpoint, while using Zero-1 Optimization #694

Open unography opened 1 month ago

unography commented 1 month ago

System Info

Platform:

- Platform: Linux-5.15.0-1056-aws-x86_64-with-glibc2.29
- Python version: 3.8.10

Python packages:

- `optimum-neuron` version: 0.0.24
- `neuron-sdk` version: 2.19.1
- `optimum` version: 1.20.0
- `transformers` version: 4.41.1
- `huggingface_hub` version: 0.24.5
- `torch` version: 2.1.2+cu121
- `aws-neuronx-runtime-discovery` version: 2.9
- `libneuronxla` version: 2.0.2335
- `neuronx-cc` version: 2.14.227.0+2d4f85be
- `neuronx-distributed` version: 0.8.0
- `neuronx-hwm` version: NA
- `torch-neuronx` version: 2.1.2.2.2.0
- `torch-xla` version: 2.1.3
- `transformers-neuronx` version: 0.11.351

Neuron Driver:

WARNING: apt does not have a stable CLI interface. Use with caution in scripts.

aws-neuronx-collectives/now 2.20.22.0-c101c322e amd64 [installed,local]
aws-neuronx-dkms/now 2.16.7.0 amd64 [installed,local]
aws-neuronx-oci-hook/now 2.3.0.0 amd64 [installed,local]
aws-neuronx-runtime-lib/now 2.20.22.0-1b3ca6425 amd64 [installed,local]
aws-neuronx-tools/now 2.17.1.0 amd64 [installed,local]

Who can help?

@michaelbenayoun

Reproduction (minimal, reproducible, runnable)

Pretraining a TinyLlama-like model on the wikitext dataset, using the same tokenizer as TinyLlama.

Create a Trainium instance following the steps here

Get the official training script:

wget https://raw.githubusercontent.com/huggingface/optimum-neuron/main/examples/language-modeling/run_clm.py

Export the BF16 environment variable:

export XLA_USE_BF16=true

Compile:

neuron_parallel_compile torchrun --nproc_per_node=2 run_clm.py --config_name dhruv-huq/tiny-tinyllama --tokenizer_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --per_device_train_batch_size 1 --block_size 1024 --bf16 --zero_1 --overwrite_output_dir --output_dir my_training/ --preprocessing_num_workers 64

Train:

torchrun --nproc_per_node=2 run_clm.py --config_name dhruv-huq/tiny-tinyllama --tokenizer_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --per_device_train_batch_size 1 --block_size 1024 --bf16 --zero_1 --overwrite_output_dir --output_dir my_training/ --preprocessing_num_workers 64
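
For what it's worth, the crash hits on the first optimizer step after the first checkpoint is written (checkpoint-500 in the trace below). Assuming the script's standard `TrainingArguments` apply here, lowering the save interval with `--save_steps` should make the failure reproduce faster, e.g.:

torchrun --nproc_per_node=2 run_clm.py --config_name dhruv-huq/tiny-tinyllama --tokenizer_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --per_device_train_batch_size 1 --block_size 1024 --bf16 --zero_1 --overwrite_output_dir --output_dir my_training/ --preprocessing_num_workers 64 --save_steps 10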

Error with stack trace:

[INFO|trainers.py:832] 2024-09-09 15:03:21,372 >> ***** Running training *****
[INFO|trainers.py:833] 2024-09-09 15:03:21,372 >>   Num examples = 1,380
[INFO|trainers.py:834] 2024-09-09 15:03:21,372 >>   Num Epochs = 3
[INFO|trainers.py:835] 2024-09-09 15:03:21,372 >>   Instantaneous batch size per device = 1
[INFO|trainers.py:840] 2024-09-09 15:03:21,373 >>   Total train batch size (w. parallel, distributed & accumulation) = 2
[INFO|trainers.py:843] 2024-09-09 15:03:21,373 >>   Gradient Accumulation steps = 1
[INFO|trainers.py:844] 2024-09-09 15:03:21,373 >>   Total optimization steps = 4,140
[INFO|trainers.py:845] 2024-09-09 15:03:21,373 >>   Number of trainable parameters = 336,611,328
 12%|████████████████████████████                                                                                                                                                                                                            | 500/4140 [00:57<06:32,  9.27it/s]
[INFO|trainers.py:504] 2024-09-09 15:04:19,002 >> Saving model checkpoint to my_training/checkpoint-500
[INFO|configuration_utils.py:472] 2024-09-09 15:04:20,099 >> Configuration saved in my_training/checkpoint-500/config.json
[INFO|configuration_utils.py:731] 2024-09-09 15:04:20,099 >> Configuration saved in my_training/checkpoint-500/generation_config.json
[INFO|modeling_utils.py:2618] 2024-09-09 15:04:25,503 >> Model weights saved in my_training/checkpoint-500/model.safetensors
[INFO|tokenization_utils_base.py:2513] 2024-09-09 15:04:25,739 >> tokenizer config file saved in my_training/checkpoint-500/tokenizer_config.json
[INFO|tokenization_utils_base.py:2522] 2024-09-09 15:04:25,739 >> Special tokens file saved in my_training/checkpoint-500/special_tokens_map.json
Traceback (most recent call last):
  File "run_clm.py", line 673, in <module>
    main()
  File "run_clm.py", line 621, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/optimum/neuron/trainers.py", line 1414, in train
    result = super().train(
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/optimum/neuron/utils/require_utils.py", line 51, in wrapper
    return func(*args, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/optimum/neuron/trainers.py", line 1054, in _inner_training_loop
    self.optimizer.step()
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/optimum/neuron/utils/require_utils.py", line 51, in wrapper
    return func(*args, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/optimum/neuron/accelerate/optimizer.py", line 104, in step
    self.optimizer.step(closure)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch/optim/optimizer.py", line 373, in wrapper
    out = func(*args, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch_xla/distributed/zero_redundancy_optimizer.py", line 390, in step
    self._clip_grad_norm(max_norm=self.max_norm)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/neuronx_distributed/optimizer/zero_redundancy_optimizer.py", line 74, in _clip_grad_norm
    self._grad_norm = clip_grad_norm(
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/neuronx_distributed/parallel_layers/grads.py", line 172, in clip_grad_norm
    total_norm = get_grad_norm(
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/neuronx_distributed/parallel_layers/grads.py", line 90, in get_grad_norm
    dtype = parameters[0].dtype
IndexError: list index out of range
[2024-09-09 15:04:29,914] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1768 closing signal SIGTERM
[2024-09-09 15:04:31,079] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 1 (pid: 1769) of binary: /opt/aws_neuron_venv_pytorch/bin/python3.8
Traceback (most recent call last):
  File "/opt/aws_neuron_venv_pytorch/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/aws_neuron_venv_pytorch/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
run_clm.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-09-09_15:04:29
  host      : ip-172-31-54-170.us-west-2.compute.internal
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 1769)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

There is no error if `--zero_1` isn't used; training runs as expected.
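
For reference, the IndexError is raised inside `get_grad_norm` in `neuronx_distributed/parallel_layers/grads.py`, which reads `parameters[0].dtype` (or `parameters[0].device` in the trace further below) before computing the norm, so it seems the list of gradient-carrying parameters handed to ZeRO-1's gradient clipping is empty on this rank right after the checkpoint save. A minimal, hypothetical sketch of that failure mode (not the actual neuronx_distributed code):

```python
# Hypothetical sketch, not the neuronx_distributed implementation; it only shows
# why an empty list of gradient-carrying parameters yields the IndexError above.
import torch


def get_grad_norm_sketch(parameters, norm_type=2.0):
    params_with_grad = [p for p in parameters if p.grad is not None]
    dtype = params_with_grad[0].dtype    # IndexError when no parameter has a grad
    device = params_with_grad[0].device
    total_norm = torch.norm(
        torch.stack([p.grad.detach().norm(norm_type) for p in params_with_grad]),
        norm_type,
    )
    return total_norm.to(dtype=dtype, device=device)


params = [torch.nn.Parameter(torch.randn(4))]  # no backward pass, so p.grad is None
get_grad_norm_sketch(params)                   # -> IndexError: list index out of range
```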

Expected behavior

Model resumes training correctly after saving the checkpoint
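
In the meantime, the run completes when ZeRO-1 is disabled (as noted above), at the cost of keeping a full copy of the optimizer state on every device, i.e. the same command without the `--zero_1` flag:

torchrun --nproc_per_node=2 run_clm.py --config_name dhruv-huq/tiny-tinyllama --tokenizer_name TinyLlama/TinyLlama-1.1B-Chat-v1.0 --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --per_device_train_batch_size 1 --block_size 1024 --bf16 --overwrite_output_dir --output_dir my_training/ --preprocessing_num_workers 64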

wolanlu commented 1 week ago

+1, similar error on `optimum-neuron` version 0.0.25 and Neuron SDK 2.20:

Traceback (most recent call last):
  File "/home/ubuntu/ml-specialized-hardware/purpose-built-accelerators/notebooks/src/train.py", line 101, in <module>
    trainer.train()
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/optimum/neuron/trainers.py", line 1456, in train
    result = super().train(
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/optimum/neuron/utils/require_utils.py", line 51, in wrapper
    return func(*args, **kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/optimum/neuron/trainers.py", line 1096, in _inner_training_loop
    self.optimizer.step()
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/optimum/neuron/utils/require_utils.py", line 51, in wrapper
    return func(*args, **kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/optimum/neuron/accelerate/optimizer.py", line 104, in step
    self.optimizer.step(closure)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch/optim/optimizer.py", line 373, in wrapper
    out = func(*args, **kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch_xla/distributed/zero_redundancy_optimizer.py", line 336, in step
    self._clip_grad_norm(max_norm=self.max_norm)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/neuronx_distributed/optimizer/zero_redundancy_optimizer.py", line 98, in _clip_grad_norm
    all_parameters, self._grad_norm = self._get_params_and_grad_norm(norm_type)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/neuronx_distributed/optimizer/zero_redundancy_optimizer.py", line 83, in _get_params_and_grad_norm
    grad_norm = get_grad_norm(
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/grads.py", line 116, in get_grad_norm
    device = parameters[0].device
IndexError: list index out of range
 98%|█████████▊| 1000/1022 [02:22<00:03,  7.01it/s]
[2024-10-15 14:16:00,812] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 7613) of binary: /opt/aws_neuronx_venv_pytorch_2_1/bin/python3
Traceback (most recent call last):
  File "/opt/aws_neuronx_venv_pytorch_2_1/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/aws_neuronx_venv_pytorch_2_1/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: