pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Custom learning rate scheduler affects TPU performance #4083

Open DanielRoeder1 opened 2 years ago

DanielRoeder1 commented 2 years ago

❓ Questions and Help

I have trained my transformer model once on a single GPU and once on a multi-core TPU. In both cases a batch size of 256 is used (times 8 across the TPU cores). My training results show that the TPU loss after 400 update steps is almost equal to the GPU loss after 400 updates, even though the effective batch size is 8 times as high, and this trend continues. This leads me to believe that the TPU cores are somehow misaligned and are each training their own model. I use a custom learning rate scheduler to update the LR at each training step; see Train.py, Scheduler. If I remove this scheduler, the training loss during TPU training drops significantly faster, but the training becomes very unstable.

In the training loop, the optimizer is initialized on each core and wrapped in the Scheduler, which updates the learning rate before each training step:

import torch
from torch.optim import Adam
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def map_fn(index, flags):
  torch.manual_seed(flags['seed'])
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device).train()
  # Optimizer is created per core and wrapped in the custom Scheduler (see Train.py)
  scheduler = Scheduler(Adam(model.parameters(), betas=(0.9, 0.98), eps=10e-9), config)  # <-----
  loss_fn = torch.nn.CrossEntropyLoss(ignore_index=config.pad_idx)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
    dataset["train"],
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  train_loader = torch.utils.data.DataLoader(
      dataset["train"],
      batch_size=flags['batch_size'],
      sampler=train_sampler,
      num_workers=flags['num_workers'],
      drop_last=True)

  def train_epoch(loader):
    model.train()
    for batch_num, batch in enumerate(loader):
      src_input, trgt_seq = batch["input_ids"], batch["labels"]
      trgt_input = trgt_seq[:,:-1]
      trgt_label = trgt_seq[:,1:]
      scheduler.optimizer.zero_grad() #<-----
      pred = model(src_input, trgt_input)
      loss = loss_fn(pred.transpose(1,2), trgt_label)
      loss.backward()

      scheduler.update_learning_rate() #<-----
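      # xm.optimizer_step() all-reduces (averages) the gradients across all TPU cores before stepping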
      xm.optimizer_step(scheduler.optimizer) #<-----

      if batch_num % flags['log_steps'] == 0:
        xm.master_print(f'[{batch_num}/ {len(loader)}] Loss={loss} Time={get_time()}')

  for epoch in range(flags['num_epochs']):
    para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    train_epoch(para_train_loader)
    xm.master_print("Finished training epoch {}".format(epoch))

Any help in mitigating the performance difficulties encountered when using the scheduler is more than welcome! Thanks

Extra Information

Model: "Attention Is All You Need Transformer" (self-coded)

Environment: Colab TPUv2 using torch xla 1.12, Colab GPU T4 (non xla torch)

Train settings: Batch size 256, same LR schedule, same loss function (CrossEntropy), TPU uses 8 cores so 8 * 256 batch

Data: WMT14 4.5million sentences de-en

JackCaoG commented 2 years ago

@AlexWertheim do you have cycles to take a look at this one?

AlexWertheim commented 2 years ago

Sure, I can take a look!

miladm commented 2 years ago

@DanielRoeder1

Thanks for sharing your model bug with us. Can you please provide the following details so that we can reproduce exactly what you experienced on your end? That way we will be able to circle back with concrete steps you can take to mitigate the problem.

Thanks.

DanielRoeder1 commented 2 years ago

Sure, apologies for the late response. You can find the complete training code in the following Colab notebooks:

TPU: https://colab.research.google.com/drive/1fSTCbKq7b2iYaDQwrkVe18E81qDZdt3N?usp=sharing

GPU: https://colab.research.google.com/drive/1hW9_pr4B1yDI9sfMs8DRyGUFYQXkybft?usp=sharing

The hyperparameters are the same between both notebooks. The majority of the parameters are set in config.json.