pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Distributed training on multi-host v4/v5 TPU Pods is too slow #8020

Open ayukh opened 3 days ago

ayukh commented 3 days ago

❓ Questions and Help

Hi! We are trying to train Gemma-2-9B on a v4-64 and a v5-128 Pod, as mentioned in this comment. We use an FSDP+SPMD setup on torch XLA 2.4.0 and transformers 4.44.2. We have tested both our own code (with the default HuggingFace Trainer and an xmp.spawn launch) and the HuggingFace example run_clm.py following this GCP documentation, and in both cases training is very slow. We also observe that the batch size per device is equal to the total batch size:

WARNING:run_clm:Process rank: 1, device: xla:0, n_gpu: 0, distributed training: False, 16-bits training: False
Instantaneous batch size per device = 16
Total train batch size (w. parallel, distributed & accumulation) = 16
  1. The distributed type is set to TPU, and distributed training is therefore reported as False in the output above. Is the total train batch size supposed to be equal to the batch size per device in this case?
  2. If we use a bigger batch size per device we quickly get an OOM error, and with a small batch size training is extremely slow, with train_samples_per_second around 0.606. Does this mean that the training is not parallel/distributed?
  3. Are there any benchmarks on training similarly sized LLMs with the same setup?

We can't figure out what is causing such slow performance on TPUs - we are using a default example and it is not working for us. We would be thankful for any help here!

JackCaoG commented 2 days ago

The batch size per device being the global batch size is expected: in SPMD there is conceptually only one device, and XLA handles sending the correct slice of the batch to each physical device.
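
For reference, this is roughly how the global batch gets split under SPMD - a minimal sketch (not the HF Trainer's exact code), assuming a 1-D "fsdp" mesh over all devices:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# The mesh spans every physical device, but the program sees one logical device.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ("fsdp",))

# The "global" batch is created on the single logical XLA device...
batch = torch.randn(64, 512, device=xm.xla_device())
# ...and sharding it along the batch dimension tells XLA to place
# 64 / num_devices rows on each physical device.
xs.mark_sharding(batch, mesh, ("fsdp", None))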

We never tried training gemma2, and I vaguely remember it uses sliding window attention, so I am not sure if that's the bottleneck. Maybe try running with the PT_XLA_DEBUG=1 env var and see if there is a graph break / graph recompilation issue. Also, as I said in the previous issue, can you take a profile and share it? https://github.com/pytorch/xla/blob/master/examples/debug/train_resnet_profile.py is a short example. I also have 2 videos, https://www.youtube.com/watch?v=LK3A3vjo-KQ&t=924s and https://www.youtube.com/watch?v=40jYVhQHGEA&t=1030s, that explain how to take and understand the profile and give basic debugging tips.
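
In case it helps, the linked example boils down to roughly this (a minimal sketch; the port and logdir below are just placeholders):

import torch_xla.debug.profiler as xp

# Start the profiler server once, early in the training script.
server = xp.start_server(9012)

# Later (e.g. from another thread once training has warmed up), capture a
# trace for a fixed duration into a TensorBoard-compatible logdir.
xp.trace("localhost:9012", "/tmp/xla_profile", duration_ms=20000)

# Optionally annotate steps so they show up as named regions in the trace:
# with xp.StepTrace("train_step", step_num=step):
#     ...  # forward / backward / optimizer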

giuliano-97 commented 18 minutes ago

Hi @JackCaoG, I'm working with Hanna and I've collected some information that will hopefully help troubleshoot the issue. It's quite a bit, so I'll start with what you asked, i.e. running with PT_XLA_DEBUG=1, and be as detailed as I can.

For starters, I'm using a TPU v4-32 Pod. The environment was set up with a script like this:

#!/bin/bash

source ./hf_env/bin/activate

# Step 1: install torch, torch-xla, libtpu
pip install torch~=2.4.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch_xla[tpu]~=2.4.0 -f https://storage.googleapis.com/libtpu-releases/index.html

# Step 2: install HF
pip3 install datasets accelerate evaluate scikit-learn

git clone https://github.com/huggingface/transformers -b v4.44.2 transformers-4.44.2
cd transformers-4.44.2 
pip install -e .

So I'm using the run_clm.py example script from HF transformers 4.44.2. To enable SPMD and profiling I applied the following patch to examples/pytorch/language-modeling/run_clm.py:

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 15ab7c4c4..18ef59831 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -53,6 +53,12 @@ from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version

+import torch_xla
+import torch_xla.debug.profiler as xp
+import torch_xla.runtime as xr
+
+xr.use_spmd()

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.44.0")
@@ -295,6 +301,9 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)

+    server = xp.start_server(9012)
+    logger.info(f'Profiling server started: {str(server)}')
+
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).

and this one to src/transformers/trainer.py:

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 68ba7babf..7cd948e5a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -34,6 +34,7 @@ import time
 import warnings
 from collections.abc import Mapping
 from pathlib import Path
+from threading import Thread
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

@@ -188,6 +189,7 @@ if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
     from torch_xla import __version__ as XLA_VERSION
+    import torch_xla.debug.profiler as xp

     IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
     if IS_XLA_FSDPV2_POST_2_2:
@@ -702,6 +704,7 @@ class Trainer:
             # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
             # Tensor axis is just a placeholder where it will not be used in FSDPv2.
             num_devices = xr.global_runtime_device_count()
+            logger.info(f"Global device count: {num_devices}")
             xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))

     def _activate_neftune(self, model):
@@ -2206,6 +2209,7 @@ class Trainer:

         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
+            epoch_start_time = time.time()
             epoch_iterator = train_dataloader
             if hasattr(epoch_iterator, "set_epoch"):
                 epoch_iterator.set_epoch(epoch)
@@ -2233,6 +2237,11 @@ class Trainer:
                 rng_to_sync = True

             step = -1
+            profile_step = int(os.environ.get('PROFILE_STEP', -1))
+            profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
+            profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
+            profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
+            logger.info("Profiling will start at step {} and epoch {}".format(profile_step, profile_epoch))   
             for step, inputs in enumerate(epoch_iterator):
                 total_batched_samples += 1

@@ -2274,9 +2283,14 @@ class Trainer:

                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+                
+                if step == profile_step and epoch == profile_epoch:
+                    trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir, profile_duration)
+                    Thread(target=trace).start()

                 with self.accelerator.accumulate(model):
-                    tr_loss_step = self.training_step(model, inputs)
+                    with xp.StepTrace("Training_step", step_num=step):
+                        tr_loss_step = self.training_step(model, inputs)

                 if (
                     args.logging_nan_inf_filter
@@ -2384,6 +2398,14 @@ class Trainer:
                         "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                         "configured. Check your training configuration if this is unexpected."
                     )
+            epoch_speed_metrics = speed_metrics(
+                "epoch",
+                epoch_start_time,
+                num_samples=num_train_samples / args.num_train_epochs,
+                num_tokens=num_train_tokens / args.num_train_epochs if num_train_tokens is not None else None,
+            )
+            self.log(epoch_speed_metrics)
+
             if self.control.should_training_stop:
                 break

Then I used this script to launch the training:

#!/bin/bash

. ~/hf_env/bin/activate

export PJRT_DEVICE=TPU

export PROFILE_STEP=20
export PROFILE_EPOCH=0
export PROFILE_DURATION_MS=300000
export PROFILE_LOGDIR=/home/giuliano/xla-exp/profiles-run-clm
export PT_XLA_DEBUG=1
export XLA_FLAGS="--xla_dump_to=/home/giuliano/xla-exp/dumps"
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1

# Run
cd ~/transformers-4.44.2

examples/pytorch/language-modeling/run_clm.py \
  --log_level info \
  --debug tpu_metrics_debug \
  --include_tokens_per_second \
  --dataloader_drop_last \
  --token MY_TOKEN \
  --model_name_or_path google/gemma-2-9b \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --per_device_train_batch_size 64 \
  --num_train_epochs 4 \
  --block_size 512 \
  --optim adamw_torch \
  --do_train \
  --save_strategy no \
  --logging_strategy no \
  --fsdp full_shard \
  --fsdp_config fsdp_config.json \
  --output_dir /home/giuliano/xla-exp \
  --overwrite_output_dir

where fsdp_config.json looks like this:

{
    "fsdp_transformer_layer_cls_to_wrap": [
        "Gemma2DecoderLayer"
    ],
    "min_num_params": 0,
    "xla": true,
    "xla_fsdp_v2": true,
    "xla_fsdp_grad_ckpt": true
}
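
For context, my understanding (an assumption on my part, not the exact Trainer code) is that with "xla": true and "xla_fsdp_v2": true the Trainer does roughly the following under the hood:

import functools
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
from transformers import AutoModelForCausalLM
from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer

xr.use_spmd()

# Global 1-D "fsdp" mesh over all devices ("tensor" is just a placeholder axis).
num_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ("fsdp", "tensor")))

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")

# Wrap each Gemma2DecoderLayer (per fsdp_transformer_layer_cls_to_wrap) and shard
# its parameters along the "fsdp" axis of the global mesh.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={Gemma2DecoderLayer}
)
model = FSDPv2(model, auto_wrap_policy=auto_wrap_policy)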

The training was launched on all the workers, as is typical for a TPU Pod:

gcloud compute tpus tpu-vm ssh $TPU_NAME --worker=all --zone=$TPU_ZONE --command="bash transformers-4.44.2/train_gemma2.sh"

Unfortunately, that goes OOM by a few MB. Here is the full output:

[INFO|trainer.py:2137] 2024-09-19 13:58:38,488 >> ***** Running training *****
[INFO|trainer.py:2138] 2024-09-19 13:58:38,489 >>   Num examples = 4,800
[INFO|trainer.py:2139] 2024-09-19 13:58:38,489 >>   Num Epochs = 4
[INFO|trainer.py:2140] 2024-09-19 13:58:38,489 >>   Instantaneous batch size per device = 64
[INFO|trainer.py:2143] 2024-09-19 13:58:38,489 >>   Total train batch size (w. parallel, distributed & accumulation) = 64
[INFO|trainer.py:2144] 2024-09-19 13:58:38,489 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2145] 2024-09-19 13:58:38,489 >>   Total optimization steps = 300
[INFO|trainer.py:2146] 2024-09-19 13:58:38,490 >>   Number of trainable parameters = 10,159,209,984
  0%|          | 0/300 [00:00<?, ?it/s][INFO|trainer.py:2244] 2024-09-19 13:58:38,493 >> Profiling will start at step 20 and epoch 0
[INFO|trainer.py:2137] 2024-09-19 13:58:38,662 >> ***** Running training *****
[INFO|trainer.py:2138] 2024-09-19 13:58:38,662 >>   Num examples = 4,800
[INFO|trainer.py:2139] 2024-09-19 13:58:38,662 >>   Num Epochs = 4
[INFO|trainer.py:2140] 2024-09-19 13:58:38,662 >>   Instantaneous batch size per device = 64
[INFO|trainer.py:2143] 2024-09-19 13:58:38,662 >>   Total train batch size (w. parallel, distributed & accumulation) = 64
[INFO|trainer.py:2144] 2024-09-19 13:58:38,662 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2145] 2024-09-19 13:58:38,662 >>   Total optimization steps = 300
[INFO|trainer.py:2146] 2024-09-19 13:58:38,664 >>   Number of trainable parameters = 10,159,209,984
  0%|          | 0/300 [00:00<?, ?it/s][INFO|trainer.py:2244] 2024-09-19 13:58:38,666 >> Profiling will start at step 20 and epoch 0
[INFO|trainer.py:2137] 2024-09-19 13:58:38,964 >> ***** Running training *****
[INFO|trainer.py:2138] 2024-09-19 13:58:38,964 >>   Num examples = 4,800
[INFO|trainer.py:2139] 2024-09-19 13:58:38,964 >>   Num Epochs = 4
[INFO|trainer.py:2140] 2024-09-19 13:58:38,964 >>   Instantaneous batch size per device = 64
[INFO|trainer.py:2143] 2024-09-19 13:58:38,964 >>   Total train batch size (w. parallel, distributed & accumulation) = 64
[INFO|trainer.py:2144] 2024-09-19 13:58:38,964 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2145] 2024-09-19 13:58:38,964 >>   Total optimization steps = 300
[INFO|trainer.py:2146] 2024-09-19 13:58:38,966 >>   Number of trainable parameters = 10,159,209,984
  0%|          | 0/300 [00:00<?, ?it/s][INFO|trainer.py:2244] 2024-09-19 13:58:38,968 >> Profiling will start at step 20 and epoch 0
[INFO|trainer.py:2137] 2024-09-19 13:58:39,253 >> ***** Running training *****
[INFO|trainer.py:2138] 2024-09-19 13:58:39,253 >>   Num examples = 4,800
[INFO|trainer.py:2139] 2024-09-19 13:58:39,253 >>   Num Epochs = 4
[INFO|trainer.py:2140] 2024-09-19 13:58:39,253 >>   Instantaneous batch size per device = 64
[INFO|trainer.py:2143] 2024-09-19 13:58:39,253 >>   Total train batch size (w. parallel, distributed & accumulation) = 64
[INFO|trainer.py:2144] 2024-09-19 13:58:39,253 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:2145] 2024-09-19 13:58:39,253 >>   Total optimization steps = 300
[INFO|trainer.py:2146] 2024-09-19 13:58:39,255 >>   Number of trainable parameters = 10,159,209,984
  0%|          | 0/300 [00:00<?, ?it/s][INFO|trainer.py:2244] 2024-09-19 13:58:39,257 >> Profiling will start at step 20 and epoch 0
/home/giuliano/hf_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1623: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py:171: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py:172: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
/home/giuliano/hf_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1623: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py:171: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py:172: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
/home/giuliano/hf_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1623: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py:171: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py:172: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
/home/giuliano/hf_env/lib/python3.10/site-packages/torch/nn/modules/module.py:1623: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
  warnings.warn("For backward hooks to be called,"
/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py:171: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py:172: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   mark_step when exiting a profiler StepTrace region
Compilation Analysis: Graph Info: 
Compilation Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Compilation Analysis:   Number of Graph Inputs: 521
Compilation Analysis:   Number of Graph Outputs: 1434
Compilation Analysis: Python Frame Triggered Execution: 
Compilation Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Compilation Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Compilation Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Compilation Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Compilation Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Compilation Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   mark_step when exiting a profiler StepTrace region
Compilation Analysis: Graph Info: 
Compilation Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Compilation Analysis:   Number of Graph Inputs: 521
Compilation Analysis:   Number of Graph Outputs: 1434
Compilation Analysis: Python Frame Triggered Execution: 
Compilation Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Compilation Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Compilation Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Compilation Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Compilation Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Compilation Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   mark_step when exiting a profiler StepTrace region
Compilation Analysis: Graph Info: 
Compilation Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Compilation Analysis:   Number of Graph Inputs: 521
Compilation Analysis:   Number of Graph Outputs: 1434
Compilation Analysis: Python Frame Triggered Execution: 
Compilation Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Compilation Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Compilation Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Compilation Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Compilation Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Compilation Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   mark_step when exiting a profiler StepTrace region
Compilation Analysis: Graph Info: 
Compilation Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Compilation Analysis:   Number of Graph Inputs: 521
Compilation Analysis:   Number of Graph Outputs: 1434
Compilation Analysis: Python Frame Triggered Execution: 
Compilation Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Compilation Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Compilation Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Compilation Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Compilation Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Compilation Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 2.365468 GB
Post Compilation Analysis: Graph output size: 67.803850 GB
Post Compilation Analysis: Aliased Input size: 1.938169 GB
Post Compilation Analysis: Intermediate tensor size: 5.577612 GB
Post Compilation Analysis: Compiled program size: 0.275043 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   mark_step when exiting a profiler StepTrace region
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Execution Analysis:   Number of Graph Inputs: 521
Execution Analysis:   Number of Graph Outputs: 1434
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Execution Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Execution Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Execution Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Execution Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Execution Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
  0%|          | 1/300 [12:22<61:41:02, 742.68s/it]pt-xla-profiler: CompileTime too slow: longest instance took 12m17s215ms206.329us. Please open a GitHub issue with the graph dump for our team to optimize.
Traceback (most recent call last):
  File "/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py", line 666, in <module>
    main()
  File "/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py", line 614, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/giuliano/transformers-4.44.2/src/transformers/trainer.py", line 1941, in train
    return inner_training_loop(
  File "/home/giuliano/transformers-4.44.2/src/transformers/trainer.py", line 2291, in _inner_training_loop
    with self.accelerator.accumulate(model):
  File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/accelerator.py", line 1073, in accumulate
    self._do_sync()
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/accelerator.py", line 1027, in _do_sync
    self.gradient_state._set_sync_gradients((self.step % self.gradient_state.num_steps) == 0)
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/state.py", line 1214, in _set_sync_gradients
    xm.mark_step()
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1055, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 196.00M. That was not possible. There are 134.44M free.; (1x0x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
  0%|          | 1/300 [12:22<61:42:17, 742.93s/it]
##### Command execution on worker 3 failed with exit status 1. Continuing.

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 2.365468 GB
Post Compilation Analysis: Graph output size: 67.803850 GB
Post Compilation Analysis: Aliased Input size: 1.938169 GB
Post Compilation Analysis: Intermediate tensor size: 5.577612 GB
Post Compilation Analysis: Compiled program size: 0.275043 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   mark_step when exiting a profiler StepTrace region
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Execution Analysis:   Number of Graph Inputs: 521
Execution Analysis:   Number of Graph Outputs: 1434
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Execution Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Execution Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Execution Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Execution Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Execution Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
  0%|          | 1/300 [12:29<62:13:10, 749.13s/it]pt-xla-profiler: CompileTime too slow: longest instance took 12m24s661ms257.954us. Please open a GitHub issue with the graph dump for our team to optimize.
Traceback (most recent call last):
  File "/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py", line 666, in <module>
    main()
  File "/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py", line 614, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/giuliano/transformers-4.44.2/src/transformers/trainer.py", line 1941, in train
    return inner_training_loop(
  File "/home/giuliano/transformers-4.44.2/src/transformers/trainer.py", line 2291, in _inner_training_loop
    with self.accelerator.accumulate(model):
  File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/accelerator.py", line 1073, in accumulate
    self._do_sync()
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/accelerator.py", line 1027, in _do_sync
    self.gradient_state._set_sync_gradients((self.step % self.gradient_state.num_steps) == 0)
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/state.py", line 1214, in _set_sync_gradients
    xm.mark_step()
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1055, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 196.00M. That was not possible. There are 134.44M free.; (1x0x1_HBM0): while running replica 0 and partition 4 of a replicated computation (other replicas may have failed as well).
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
  0%|          | 1/300 [12:29<62:14:24, 749.38s/it]
##### Command execution on worker 1 failed with exit status 1. Continuing.

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 2.365468 GB
Post Compilation Analysis: Graph output size: 67.803850 GB
Post Compilation Analysis: Aliased Input size: 1.938169 GB
Post Compilation Analysis: Intermediate tensor size: 5.577612 GB
Post Compilation Analysis: Compiled program size: 0.275043 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   mark_step when exiting a profiler StepTrace region
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Execution Analysis:   Number of Graph Inputs: 521
Execution Analysis:   Number of Graph Outputs: 1434
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Execution Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Execution Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Execution Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Execution Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Execution Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
pt-xla-profiler: CompileTime too slow: longest instance took 12m30s585ms189.902us. Please open a GitHub issue with the graph dump for our team to optimize.
  0%|          | 1/300 [12:35<62:42:58, 755.11s/it]Traceback (most recent call last):
  File "/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py", line 666, in <module>
    main()
  File "/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py", line 614, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/giuliano/transformers-4.44.2/src/transformers/trainer.py", line 1941, in train
    return inner_training_loop(
  File "/home/giuliano/transformers-4.44.2/src/transformers/trainer.py", line 2291, in _inner_training_loop
    with self.accelerator.accumulate(model):
  File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/accelerator.py", line 1073, in accumulate
    self._do_sync()
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/accelerator.py", line 1027, in _do_sync
    self.gradient_state._set_sync_gradients((self.step % self.gradient_state.num_steps) == 0)
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/state.py", line 1214, in _set_sync_gradients
    xm.mark_step()
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1055, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 196.00M. That was not possible. There are 134.44M free.; (1x0x3_HBM0): while running replica 0 and partition 12 of a replicated computation (other replicas may have failed as well).
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
  0%|          | 1/300 [12:35<62:44:13, 755.36s/it]
##### Command execution on worker 2 failed with exit status 1. Continuing.

Post Compilation Analysis: ================================================================================
Post Compilation Analysis: Graph input size: 2.365468 GB
Post Compilation Analysis: Graph output size: 67.803850 GB
Post Compilation Analysis: Aliased Input size: 1.938169 GB
Post Compilation Analysis: Intermediate tensor size: 5.577612 GB
Post Compilation Analysis: Compiled program size: 0.275043 GB
Post Compilation Analysis: --------------------------------------------------------------------------------
Post Compilation Analysis: ================================================================================

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   mark_step when exiting a profiler StepTrace region
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Execution Analysis:   Number of Graph Inputs: 521
Execution Analysis:   Number of Graph Outputs: 1434
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Execution Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Execution Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Execution Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Execution Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Execution Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
pt-xla-profiler: CompileTime too slow: longest instance took 13m35s865ms460.010us. Please open a GitHub issue with the graph dump for our team to optimize.
  0%|          | 1/300 [12:40<63:08:40, 760.27s/it]Traceback (most recent call last):
  File "/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py", line 666, in <module>
    main()
  File "/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py", line 614, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/giuliano/transformers-4.44.2/src/transformers/trainer.py", line 1941, in train
    return inner_training_loop(
  File "/home/giuliano/transformers-4.44.2/src/transformers/trainer.py", line 2291, in _inner_training_loop
    with self.accelerator.accumulate(model):
  File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/accelerator.py", line 1073, in accumulate
    self._do_sync()
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/accelerator.py", line 1027, in _do_sync
    self.gradient_state._set_sync_gradients((self.step % self.gradient_state.num_steps) == 0)
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/accelerate/state.py", line 1214, in _set_sync_gradients
    xm.mark_step()
  File "/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1055, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 196.00M. That was not possible. There are 134.44M free.; (1x0x2_HBM0): while running replica 0 and partition 8 of a replicated computation (other replicas may have failed as well).
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================
  0%|          | 1/300 [12:40<63:09:57, 760.53s/it]
##### Command execution on worker 0 failed with exit status 1. Continuing.

Now the OOM is one problem (after replicating the same training loop without all the bells and whistles of the HF Trainer I am able to make it work, but I'll talk about that later), but the part I want to highlight is this:

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   mark_step when exiting a profiler StepTrace region
Execution Analysis: Graph Info: 
Execution Analysis:   Graph Hash: 1f46bb607822f83dfbaaead361199c7e
Execution Analysis:   Number of Graph Inputs: 521
Execution Analysis:   Number of Graph Outputs: 1434
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:   mark_step (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1055)
Execution Analysis:   __exit__ (/home/giuliano/hf_env/lib/python3.10/site-packages/torch_xla/debug/profiler.py:170)
Execution Analysis:   _inner_training_loop (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:2292)
Execution Analysis:   train (/home/giuliano/transformers-4.44.2/src/transformers/trainer.py:1941)
Execution Analysis:   main (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:614)
Execution Analysis:   <module> (/home/giuliano/transformers-4.44.2/examples/pytorch/language-modeling/run_clm.py:666)
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
pt-xla-profiler: CompileTime too slow: longest instance took 13m35s865ms460.010us. Please open a GitHub issue with the graph dump for our team to optimize.
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:================================================================================

Compilation is extremely slow for this model. You can find the dumps attached. dumps.zip

Hope you can give us a hand. I'll follow up with some more info about profiling and the performance issue later.

Let me know if you have any questions and/or if you need more info to help troubleshoot the problem.