Training works with nprocs = 1 but not nprocs = 8 #2937

gabrielwong1991 commented 3 years ago

❓ Questions and Help

Hi all, I am training a transformer model with a custom dataset (~800 MB of data) on a Colab (or Kaggle) TPU. My code works with nprocs = 1, but with nprocs = 8 it hangs after reaching xm.mark_step() in the first step, before entering the second step.

Unfortunately, I cannot run met.metrics_report() since it never gets past the first step, though I can produce a report when nprocs = 1.

I tried troubleshooting by checking whether my computation graph changes between steps, which would force XLA to recompile before feeding the TPU devices. The model is pretty standard and, other than the dropout layer, I don't see any problem at all. I also pad all sentences to the same length so that the input tensor shape doesn't change between batches. I also make sure I do not use .item(), if statements, or anything else that would make tensors pass back and forth between device and host.
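
The fixed-length padding described above can be sketched in plain Python. This is a minimal illustration, not the code from the notebook: `pad_to_max_len` is a hypothetical helper, and `PAD_ID = 1` is an assumed placeholder (the real value should come from `tokenizer.pad_token_id`).

```python
MAX_LEN = 128  # same value as in the post
PAD_ID = 1     # assumption: placeholder pad id, check tokenizer.pad_token_id


def pad_to_max_len(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Truncate or right-pad a list of token ids to exactly max_len.

    Returns (ids, attention_mask), both of length max_len.
    """
    ids = list(token_ids[:max_len])
    mask = [1] * len(ids)
    n_pad = max_len - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


# One short and one over-long sequence both end up with the same length,
# so every batch presents the same tensor shape to XLA.
batch = [[0, 100, 200, 2], list(range(300))]
padded = [pad_to_max_len(seq) for seq in batch]
shapes = {(len(ids), len(mask)) for ids, mask in padded}
print(shapes)
```

If `shapes` ever contains more than one entry, the input shape varies between batches and XLA would likely have to recompile.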

Below is the code related to the problem at hand. I believe the data loading part is fine, as I am able to train when nprocs = 1.

The computation graph is pretty standard; I just modified it from some tutorials:

class TweetsClass(torch.nn.Module):
    def __init__(self):
        super(TweetsClass, self).__init__()
        self.l1 = tfm.AutoModel.from_pretrained("vinai/bertweet-base")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(768, NUM_LABELS)

    def forward(self, input_ids, attention_mask, token_type_ids):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        hidden_state = output_1[0]
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output

The training loop is also pretty standard (I print at each stage to see how far the loop gets during training):

## Define how loss is averaged out of the 8 TPUs
def reduce_fn(vals):
  # take average
  return sum(vals) / len(vals)   

# Define training loop function
def train_loop_fn(data_loader, model, optimizer, device, scheduler = None):
    tracker = xm.RateTracker()
    model.train() # Put model to training mode
    for bi, data in enumerate(data_loader):
        start_time = time.time()

        print("Extract data")
        # Extract data
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.long)
        # Reset the gradient
        print("Zero Grad")
        optimizer.zero_grad()

        # Pass ids, mask, token_type_ids to model 
        outputs = model(ids, mask, token_type_ids)

        # Create loss function (cross-entropy loss for multi-class classification)
        loss_function = torch.nn.CrossEntropyLoss()
        loss = loss_function(outputs, targets)   

        # Print every 20 steps
        if bi%20==0:
           # since the loss is on all 8 cores, reduce the loss values and print the average (as defined in reduce_fn)
           print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
           xm.get_ordinal(), bi, loss.item(), tracker.rate(),
           tracker.global_rate(), time.asctime()), flush=True)

        # loss_reduced = xm.mesh_reduce('loss_reduce',loss,reduce_fn) 
        # # master_print will only print once (not from all 8 cores)
        # xm.master_print(f'bi={bi}, loss={loss_reduced}')

        # Backprop ninja
        loss.backward()

        print("Step Optimizer")
        # Use PyTorch XLA optimizer stepping
        xm.optimizer_step(optimizer)

        if scheduler is not None:
            scheduler.step()

        # Update model params in XLA tensor
        print("Mark Step")
        xm.mark_step()

        end_time = time.time()
        print(f"Time for steps {bi}: {end_time - start_time}")

    # Set model to evaluation mode
    model.eval()

The main map function:

def map_fn(index, flags):
    # Sets a common random seed - both for initialization and ensuring graph is the same

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    device = xm.xla_device()  

    # Use one instance to download datasets 
    if not xm.is_master_ordinal():

    train_dataset = pd.read_csv("train_set.csv")
    val_dataset =  pd.read_csv("dev_set.csv")

    if not xm.is_master_ordinal():

    tokenizer = tfm.AutoTokenizer.from_pretrained("vinai/bertweet-base", use_fast=False, return_tensors='pt')
    # Use dataloader #
    train_set = TweetsData(train_dataset, tokenizer, MAX_LEN)
    val_set = TweetsData(val_dataset, tokenizer, MAX_LEN)

    # Training dataset loader #
    # Wrap our Class imbalance Sampler with DistributedSamplerWrapper
    train_sampler = DistributedSamplerWrapper(
        sampler=BalanceClassSampler(labels=train_dataset.label.values, mode='upsampling'),

    train_loader = DataLoader(train_set,

    val_loader = DataLoader(val_set,

    # Push our neural network to TPU 
    model = TweetsClass()

    # Don't decay normalized layer
    param_optimizer = list(model.named_parameters()) # model parameters to optimize
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # apply to weight decay
    optimizer_grouped_parameters = [
      {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
      {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    # Create loss function (cross-entropy loss for multi-class classification) and optimizer (using the AdamW optimizer)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = AdamW(params = optimizer_grouped_parameters , lr = LEARNING_RATE * xm.xrt_world_size())

    # Create number of training steps
    num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS) 

    # Scheduler for optimizer for learning decay
    scheduler = get_linear_schedule_with_warmup(

    xm.master_print(f"Train for {len(train_dataset)} steps per epoch")
    xm.master_print(f'num_training_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')

    for epoch in range(EPOCHS):
        xm.master_print(f"Starting training in epoch: {epoch}")

        ## Training Part ##
        xm.master_print("Entering training loop")
        para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        # Call Training Loop
        train_loop_fn(para_train_loader, model, optimizer, device, scheduler=scheduler)
        del para_train_loader
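
The `num_train_steps` formula in `map_fn` can be checked against the values that appear in the logs. The sketch below assumes `TRAIN_BATCH_SIZE = 32` and `EPOCHS = 1`; these are hypothetical values (the post does not show them), picked only because they happen to reproduce the logged `num_training_steps = 49368` for the 12638343 rows printed in the debug log. Note that `len(train_dataset)` is the number of rows in the DataFrame, so the "steps per epoch" message reports rows rather than steps.

```python
def num_training_steps(num_rows, batch_size, world_size, epochs):
    """Same arithmetic as num_train_steps in map_fn: steps taken by each core."""
    return int(num_rows / batch_size / world_size * epochs)


NUM_ROWS = 12638343    # value printed as "Train for ... steps per epoch" in the log
WORLD_SIZE = 8         # xm.xrt_world_size() with nprocs=8
TRAIN_BATCH_SIZE = 32  # assumption: not shown in the post
EPOCHS = 1             # assumption: not shown in the post

steps = num_training_steps(NUM_ROWS, TRAIN_BATCH_SIZE, WORLD_SIZE, EPOCHS)
print(steps)
```

Other combinations with the same batch-size-to-epochs ratio (for example 64 and 2) give the same result, so this does not pin down the actual settings.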

Calling XMP.SPAWN:

MAX_LEN = 128
flags = {}

xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

The outputs:

num_training_steps = 49368, world_size=8
Starting training in epoch: 0
Entering training loop
Start
Extract data
Zero Grad
Model
Loss
[xla:0](0) Loss=1.10938 Rate=0.00 GlobalRate=0.00 Time=Mon May 10 16:42:54 2021
Backward
Step Optimizer
Scheduler
Mark Step
Time for steps 0: 63.811904430389404

It hangs and will not enter the second step, i.e. bi = 1. Any ideas?

Many thanks.

gabrielwong1991 commented 3 years ago

In addition, I generated a metrics report with nprocs = 1, pausing at step 8000. Most metrics look okay (they coincide with the page), and nothing there seems bad enough to explain hanging forever when using 8 TPU cores...

jysohn23 commented 3 years ago

As you say, on a single core it looks like training is working fine and performing well.

On the 8 core setup, could you try re-running with the following environment variables?

PT_XLA_DEBUG=1 TF_CPP_LOG_THREAD_ID=1 TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=tensor=5,computation_client=5,xrt_computation_client=5,aten_xla_type=1

That should tell us exactly what is happening within our framework when it's hanging.
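
One way to apply these variables when launching a training script is to build the environment in Python and pass it to a subprocess. This is a minimal sketch: `build_env`, `build_command` and `launch` are hypothetical helpers, and the script name in the usage note is a placeholder. Only the variable names and values quoted above come from this thread.

```python
import os
import subprocess
import sys

# Debug variables quoted above. TF_CPP_VMODULE is a single value made of
# comma-separated module=level pairs.
XLA_DEBUG_ENV = {
    "PT_XLA_DEBUG": "1",
    "TF_CPP_LOG_THREAD_ID": "1",
    "TF_CPP_MIN_LOG_LEVEL": "0",
    "TF_CPP_VMODULE": ",".join([
        "tensor=5",
        "computation_client=5",
        "xrt_computation_client=5",
        "aten_xla_type=1",
    ]),
}


def build_env(base=None):
    """Return a copy of the base environment with the debug variables set."""
    env = dict(os.environ if base is None else base)
    env.update(XLA_DEBUG_ENV)
    return env


def build_command(script):
    """Run the script unbuffered (-u) so log lines appear as they are written."""
    return [sys.executable, "-u", script]


def launch(script):
    """Start the script with debug logging enabled and wait for it to exit."""
    return subprocess.run(build_command(script), env=build_env(), check=False)
```

Usage would be something like `launch("train_tpu.py")`, where `train_tpu.py` stands in for the actual script. On a shell command line the `TF_CPP_VMODULE` pairs need to stay joined by commas; separated by spaces, each pair would be treated as its own environment variable.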

Also, I noticed that data_loader in your train loop fn is already a pl.ParallelLoader, so you don't need to manually do:

gabrielwong1991 commented 3 years ago

Many thanks, @jysohn23! I implemented your suggestion and, as a baseline, confirmed it works when nprocs = 1. However, nprocs = 8 still hangs. This time, removing xm.mark_step() allows it to move on to the second step, but it gets stuck when calling the model: outputs = model(ids, mask, token_type_ids)

I expect by enabling the environment variables there will be some logs coming out from the interpreter?




gabrielwong1991 commented 3 years ago

Hi! I have finally been able to pull information out by enabling the environment variables and running the script as a .py file instead of in the .ipynb, like this: !PT_XLA_DEBUG=1 TF_CPP_LOG_THREAD_ID=1 TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=tensor=5 computation_client=5 xrt_computation_client=5 aten_xla_type=1 python -u

Here are the results, though they don't seem very informative:

2021-05-17 23:29:57.812187: I    3547 tensorflow/stream_executor/platform/default/] Successfully opened dynamic library
WARNING:root:TPU has started up successfully with version pytorch-1.8.1
2021-05-17 23:30:00.692445: I    3560 tensorflow/compiler/xla/xla_client/] Waiting to connect to client mesh master (300 seconds) b0f6c9bce863:53413
2021-05-17 23:30:00.706905: I    3564 tensorflow/compiler/xla/xla_client/] Waiting to connect to client mesh master (300 seconds) b0f6c9bce863:53413
2021-05-17 23:30:00.713162: I    3570 tensorflow/compiler/xla/xla_client/] Waiting to connect to client mesh master (300 seconds) b0f6c9bce863:53413
2021-05-17 23:30:00.724224: I    3576 tensorflow/compiler/xla/xla_client/] Waiting to connect to client mesh master (300 seconds) b0f6c9bce863:53413
2021-05-17 23:30:00.734355: I    3580 tensorflow/compiler/xla/xla_client/] Waiting to connect to client mesh master (300 seconds) b0f6c9bce863:53413
2021-05-17 23:30:00.744493: I    3584 tensorflow/compiler/xla/xla_client/] Waiting to connect to client mesh master (300 seconds) b0f6c9bce863:53413
2021-05-17 23:30:00.754056: I    3588 tensorflow/compiler/xla/xla_client/] Waiting to connect to client mesh master (300 seconds) b0f6c9bce863:53413
2021-05-17 23:30:25.099310: I    3576 tensorflow/compiler/xla/xla_client/] Fetching mesh configuration for worker tpu_worker (host_ordinal=0):0 from mesh service at b0f6c9bce863:53413
2021-05-17 23:30:26.269124: I    3570 tensorflow/compiler/xla/xla_client/] Fetching mesh configuration for worker tpu_worker (host_ordinal=0):0 from mesh service at b0f6c9bce863:53413
2021-05-17 23:30:26.974932: I    3564 tensorflow/compiler/xla/xla_client/] Fetching mesh configuration for worker tpu_worker (host_ordinal=0):0 from mesh service at b0f6c9bce863:53413
2021-05-17 23:30:27.845523: I    3584 tensorflow/compiler/xla/xla_client/] Fetching mesh configuration for worker tpu_worker (host_ordinal=0):0 from mesh service at b0f6c9bce863:53413
2021-05-17 23:30:28.775142: I    3588 tensorflow/compiler/xla/xla_client/] Fetching mesh configuration for worker tpu_worker (host_ordinal=0):0 from mesh service at b0f6c9bce863:53413
2021-05-17 23:30:28.874595: I    3560 tensorflow/compiler/xla/xla_client/] Fetching mesh configuration for worker tpu_worker (host_ordinal=0):0 from mesh service at b0f6c9bce863:53413
2021-05-17 23:30:29.900606: I    3580 tensorflow/compiler/xla/xla_client/] Fetching mesh configuration for worker tpu_worker (host_ordinal=0):0 from mesh service at b0f6c9bce863:53413
2021-05-17 23:30:37.928487: I    3559 tensorflow/core/profiler/rpc/] Profiler server listening on [::]:9012 selected port:9012
2021-05-17 23:30:46.457900: I    3559 torch_xla/csrc/tensor_util.cpp:28] Using BF16 data type for floating point values
Train for 12638343 steps per epoch
num_training_steps = 49368, world_size=8
Starting training in epoch: 0
Entering training loop
2021-05-17 23:30:51.221439: I    3559 torch_xla/csrc/tensor.cpp:1373] 204 live tensors: devices=()
2021-05-17 23:30:51.221548: I    3559 torch_xla/csrc/tensor.cpp:1122] Waiting on device barrier for device TPU:0 ...
2021-05-17 23:30:51.221591: I    3559 torch_xla/csrc/tensor.cpp:1128] Waiting on device barrier for device TPU:0 done!
2021-05-17 23:30:51.272911: I    3559 torch_xla/csrc/tensor.cpp:1168] Tensors graph hash f8539b4c16d32784d76c0592dedb929d on device TPU:0
2021-05-17 23:30:51.274962: I    3559 torch_xla/csrc/tensor.cpp:1592] Parameter sequence graph hash e9be56370f81f759e99c0d5c7da37fa2
2021-05-17 23:30:51.357767: I    3559 torch_xla/csrc/tensor.cpp:1556] Compiling IR graph hash e9be56370f81f759e99c0d5c7da37fa2 on device TPU:0 ...
2021-05-17 23:30:51.404462: I    3559 torch_xla/csrc/tensor.cpp:1561] Compiling IR graph hash e9be56370f81f759e99c0d5c7da37fa2 on device TPU:0 done!
2021-05-17 23:30:51.405821: I    3559 torch_xla/csrc/tensor.cpp:1563] Graph hash e9be56370f81f759e99c0d5c7da37fa2 is computation hash e622f901bc78b50e06778eb7026a41da
2021-05-17 23:30:51.406392: I    3559 torch_xla/csrc/tensor.cpp:1602] TensorsGraphSize=1020
2021-05-17 23:30:51.410112: I    3693 torch_xla/csrc/tensor.cpp:1298] Executing IR graph hash e9be56370f81f759e99c0d5c7da37fa2 on device TPU:0 ...
2021-05-17 23:30:51.419282: I    3693 torch_xla/csrc/tensor.cpp:1303] Executing IR graph hash e9be56370f81f759e99c0d5c7da37fa2 on device TPU:0 done!
Extract data
Zero Grad
2021-05-17 23:31:24.800329: I    3559 torch_xla/csrc/tensor.cpp:1122] Waiting on device barrier for device TPU:0 ...
2021-05-17 23:31:24.800419: I    3559 torch_xla/csrc/tensor.cpp:1128] Waiting on device barrier for device TPU:0 done!
2021-05-17 23:31:24.800470: I    3559 torch_xla/csrc/tensor.cpp:1168] Tensors graph hash 2c69fbdc154c900d1b019feaa507cd74 on device TPU:0
2021-05-17 23:31:24.801801: I    3559 torch_xla/csrc/tensor.cpp:1592] Parameter sequence graph hash 8e5c6ca6d938cff3d4dcb32fbcffb22e
2021-05-17 23:31:24.914196: I    3559 torch_xla/csrc/tensor.cpp:1556] Compiling IR graph hash 8e5c6ca6d938cff3d4dcb32fbcffb22e on device TPU:0 ...
2021-05-17 23:31:24.959115: I    3559 torch_xla/csrc/tensor.cpp:1561] Compiling IR graph hash 8e5c6ca6d938cff3d4dcb32fbcffb22e on device TPU:0 done!
2021-05-17 23:31:24.961928: I    3559 torch_xla/csrc/tensor.cpp:1563] Graph hash 8e5c6ca6d938cff3d4dcb32fbcffb22e is computation hash 986ddacef03810dea5d00e9ec233f56
2021-05-17 23:31:24.962524: I    3559 torch_xla/csrc/tensor.cpp:1602] TensorsGraphSize=2273
2021-05-17 23:31:24.962581: I    3693 torch_xla/csrc/tensor.cpp:1298] Executing IR graph hash 8e5c6ca6d938cff3d4dcb32fbcffb22e on device TPU:0 ...
2021-05-17 23:31:25.021284: I    3693 torch_xla/csrc/tensor.cpp:1303] Executing IR graph hash 8e5c6ca6d938cff3d4dcb32fbcffb22e on device TPU:0 done!
[xla:0](0) Loss=1.28125 Rate=0.00 GlobalRate=0.00 Time=Mon May 17 23:31:25 2021
2021-05-17 23:31:25.059904: I    3559 torch_xla/csrc/tensor.cpp:1373] 439 live tensors: devices=()
2021-05-17 23:31:25.059993: I    3559 torch_xla/csrc/tensor.cpp:1122] Waiting on device barrier for device TPU:0 ...
2021-05-17 23:31:25.060011: I    3559 torch_xla/csrc/tensor.cpp:1128] Waiting on device barrier for device TPU:0 done!
2021-05-17 23:31:25.061081: I    3559 torch_xla/csrc/tensor.cpp:1168] Tensors graph hash bd01710a26633b903f90743d92f0327d on device TPU:0
2021-05-17 23:31:25.062721: I    3559 torch_xla/csrc/tensor.cpp:1592] Parameter sequence graph hash dd5421c93af5a320b6606bceae91e861
2021-05-17 23:31:25.132474: I    3559 torch_xla/csrc/tensor.cpp:1556] Compiling IR graph hash dd5421c93af5a320b6606bceae91e861 on device TPU:0 ...
2021-05-17 23:31:25.206239: I    3559 torch_xla/csrc/tensor.cpp:1561] Compiling IR graph hash dd5421c93af5a320b6606bceae91e861 on device TPU:0 done!
2021-05-17 23:31:25.211813: I    3559 torch_xla/csrc/tensor.cpp:1563] Graph hash dd5421c93af5a320b6606bceae91e861 is computation hash ae93d9c8850271bb1f23a071977519e0
2021-05-17 23:31:25.213132: I    3559 torch_xla/csrc/tensor.cpp:1602] TensorsGraphSize=4522
2021-05-17 23:31:25.216170: I    3693 torch_xla/csrc/tensor.cpp:1298] Executing IR graph hash dd5421c93af5a320b6606bceae91e861 on device TPU:0 ...
Step Optimizer
2021-05-17 23:31:25.384339: I    3693 torch_xla/csrc/tensor.cpp:1303] Executing IR graph hash dd5421c93af5a320b6606bceae91e861 on device TPU:0 done!
Extract data
Zero Grad
2021-05-17 23:31:25.448908: I    3559 torch_xla/csrc/tensor.cpp:1122] Waiting on device barrier for device TPU:0 ...
2021-05-17 23:31:25.448985: I    3559 torch_xla/csrc/tensor.cpp:1128] Waiting on device barrier for device TPU:0 done!
2021-05-17 23:31:25.449014: I    3559 torch_xla/csrc/tensor.cpp:1168] Tensors graph hash c87bb48399127cedece4ca286d0e1ed2 on device TPU:0
2021-05-17 23:31:25.452055: I    3559 torch_xla/csrc/tensor.cpp:1592] Parameter sequence graph hash 83f75c15d76c2047dda16e371c8f7416
2021-05-17 23:31:25.520943: I    3559 torch_xla/csrc/tensor.cpp:1556] Compiling IR graph hash 83f75c15d76c2047dda16e371c8f7416 on device TPU:0 ...
2021-05-17 23:31:25.590626: I    3559 torch_xla/csrc/tensor.cpp:1561] Compiling IR graph hash 83f75c15d76c2047dda16e371c8f7416 on device TPU:0 done!
2021-05-17 23:31:25.597732: I    3559 torch_xla/csrc/tensor.cpp:1563] Graph hash 83f75c15d76c2047dda16e371c8f7416 is computation hash de01864c3fcb3971fbaa97117e5af3c9
2021-05-17 23:31:25.599633: I    3559 torch_xla/csrc/tensor.cpp:1602] TensorsGraphSize=7809
2021-05-17 23:31:25.599766: I    3693 torch_xla/csrc/tensor.cpp:1298] Executing IR graph hash 83f75c15d76c2047dda16e371c8f7416 on device TPU:0 ...

It is stuck in the second step, at the loss... Is there any more information I can provide? Many thanks

jysohn23 commented 3 years ago

OK, thanks for the info! It looks like the graph execution is hanging, as you can see in the last log line. To help us debug, could you dump the graph that is hanging on execution by setting XLA_SAVE_TENSORS_FILE=path/to/file.hlo XLA_SAVE_TENSORS_FMT=hlo XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1?
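
The "hanging execution" reading of the log can be checked mechanically by pairing each `Executing IR graph hash ... ...` line with its `done!` line. A minimal sketch, assuming the log format shown above; `find_unfinished_executions` is a hypothetical helper, not part of torch_xla.

```python
import re

# Matches both the start ("...") and the end ("done!") of a graph execution.
EXEC_RE = re.compile(
    r"Executing IR graph hash (?P<hash>[0-9a-f]+) "
    r"on device (?P<device>\S+) (?P<state>\.\.\.|done!)"
)


def find_unfinished_executions(lines):
    """Return sorted (graph_hash, device) pairs that started but never logged done!."""
    pending = {}
    for line in lines:
        match = EXEC_RE.search(line)
        if match is None:
            continue
        key = (match["hash"], match["device"])
        if match["state"] == "...":
            pending[key] = pending.get(key, 0) + 1
        elif pending.get(key, 0) > 0:
            pending[key] -= 1
            if pending[key] == 0:
                del pending[key]
    return sorted(pending)


# Three lines taken from the end of the log above.
sample_log = [
    "2021-05-17 23:31:25.216170: I    3693 torch_xla/csrc/tensor.cpp:1298] Executing IR graph hash dd5421c93af5a320b6606bceae91e861 on device TPU:0 ...",
    "2021-05-17 23:31:25.384339: I    3693 torch_xla/csrc/tensor.cpp:1303] Executing IR graph hash dd5421c93af5a320b6606bceae91e861 on device TPU:0 done!",
    "2021-05-17 23:31:25.599766: I    3693 torch_xla/csrc/tensor.cpp:1298] Executing IR graph hash 83f75c15d76c2047dda16e371c8f7416 on device TPU:0 ...",
]
unfinished = find_unfinished_executions(sample_log)
print(unfinished)
```

On these lines the only unfinished entry is graph hash `83f75c15d76c2047dda16e371c8f7416` on `TPU:0`, which is the last line of the log.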

What is weird, though, is that I don't see anything printed after "Loss", whereas I'd expect "Backward" and "Mark Step" to be logged as execution continues.

gabrielwong1991 commented 3 years ago

Hi, this is the graph for the first step. Since it does not finish the second step, I think that graph is not saved.

I am trying to make the script accessible so that someone can perhaps reproduce the problem.


gabrielwong1991 commented 3 years ago

Hi, I have made the notebook accessible; it should work straight away. Thanks.

gabrielwong1991 commented 3 years ago

Hi... Is there anything I can do to understand the issue more? Thanks

jysohn23 commented 3 years ago

Hi sorry for the slow response, running it myself currently.

You can also view the debug logging you turned on under Runtime > View Runtime Logs.

Also, given that it looks like you're trying to train using the HF transformers library, have you given the Trainer-based scripts a try? We test the Trainer-based examples on a nightly basis, so those should be less likely to have breakages. Here is an example command we run in nightly testing for a classification fine-tuning task:

python examples/pytorch/ \
  --num_cores 8 \
  examples/pytorch/text-classification/ \
    --logging_dir=./tensorboard-metrics \
    --task_name MNLI \
    --cache_dir ./cache_dir \
    --do_train \
    --do_eval \
    --num_train_epochs 3 \
    --max_seq_length 128 \
    --learning_rate 3e-5 \
    --output_dir MNLI \
    --overwrite_output_dir \
    --logging_steps 30 \
    --save_steps 3000 \
    --overwrite_cache \
    --tpu_metrics_debug \
    --model_name_or_path roberta-large \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 16

gabrielwong1991 commented 3 years ago

Thanks, Daniel, I will have a look at the script. The reason for writing this script is to learn the underlying nuts and bolts of the task at hand rather than calling a higher-level API to do the job.

Anyhow, apparently there is something inside TPU that is out of my expertise as this script works with 1 core.

Please, only when you have time, look into this issue as I am curious as to why this happens.

I shall try now. Thanks again.

