huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

an inplace operation preventing TorchDistributor training #25130

Status: Open. liqi6811 opened this issue 1 year ago

liqi6811 commented 1 year ago

System Info

Databricks

Who can help?

@ArthurZucker @younesbelkada

Hi team,

I get an error when training with TorchDistributor.

I checked the BertEmbeddings class (URL below): at line 238, embeddings += position_embeddings is an in-place operation. Would you be able to change it to embeddings = embeddings + position_embeddings so that training works with TorchDistributor?
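To illustrate the distinction this request relies on, here is a minimal sketch. It is not the BertEmbeddings code itself: the helper name backward_succeeds, the tensor shapes, and the use of torch.sigmoid are assumptions chosen for the demo. Autograd raises the in-place error only when the tensor being overwritten was saved for the backward pass, which is true of sigmoid's output. The out-of-place form allocates a new tensor and leaves the saved one intact.

```python
import torch


def backward_succeeds(inplace: bool) -> bool:
    """Hypothetical helper: run a tiny forward/backward pass and report
    whether autograd accepted it.

    torch.sigmoid stands in for any op whose backward reuses a saved tensor:
    sigmoid saves its *output*, so mutating that output in place bumps the
    tensor's version counter and backward() refuses to run.
    """
    x = torch.randn(4, 8, requires_grad=True)
    position_embeddings = torch.randn(4, 8)

    embeddings = torch.sigmoid(x)  # autograd saves this output for backward
    if inplace:
        embeddings += position_embeddings  # overwrites the saved tensor
    else:
        embeddings = embeddings + position_embeddings  # new tensor; saved output untouched

    try:
        embeddings.sum().backward()
    except RuntimeError:
        # "one of the variables needed for gradient computation has been
        #  modified by an inplace operation ..."
        return False
    return True


print(backward_succeeds(inplace=True))   # False
print(backward_succeeds(inplace=False))  # True
```

The out-of-place rewrite costs one extra allocation but never mutates a tensor that autograd may still need.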

BertEmbeddings url: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py

TorchDistributor sample code: https://docs.databricks.com/_extras/notebooks/source/deep-learning/torch-distributor-notebook.html

Thank you very much! Ling

Reproduction

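import os

import mlflow
import torch
from sentence_transformers import InputExample, SentenceTransformer, SentencesDataset

# Assumed to be defined in earlier notebook cells (not shown): create_log_dir,
# cosine_similarity_loss, readData, and the settings log_interval, batch_size,
# num_epochs, learning_rate, modelname, db_host, db_token, experiment_path.
single_node_single_gpu_dir = create_log_dir()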
print("Data is located at: ", single_node_single_gpu_dir)

def train_one_epoch(model, device, data_loader, optimizer, epoch):
  # Debugging aid: autograd errors also print the traceback of the forward op whose backward failed
  torch.autograd.set_detect_anomaly(True)
  model.train()
  for batch_idx, (data, labels) in enumerate(data_loader):
    inputs1, inputs2 = data[0], data[1]
    inputs1 = {key: val.to(device) for key, val in inputs1.items()}
    inputs2 = {key: val.to(device) for key, val in inputs2.items()}
    # labels = labels.float().to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    # Compute embeddings: the same DDP-wrapped model runs forward twice before one backward()
    embeddings1 = model(inputs1)['sentence_embedding']
    embeddings2 = model(inputs2)['sentence_embedding']

    # Compute loss
    loss = cosine_similarity_loss(embeddings1, embeddings2, labels)

    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      # data is a list of two feature dicts, so count samples via labels instead
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          epoch, batch_idx * len(labels), len(data_loader.sampler),
          100. * batch_idx / len(data_loader), loss.item()))

      if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('train_loss', loss.item())

def save_checkpoint(log_dir, model, optimizer, epoch):
  filepath = log_dir + '/checkpoint-{epoch}.pth.tar'.format(epoch=epoch)
  state = {
    'model': model.module.state_dict(),
    'optimizer': optimizer.state_dict(),
  }
  torch.save(state, filepath)

# For distributed training we will merge the train and test steps into 1 main function
def main_fn(directory):

  #### Added imports here ####
  import mlflow
  import torch.distributed as dist
  from torch.nn.parallel import DistributedDataParallel as DDP
  from torch.utils.data.distributed import DistributedSampler

  ############################

  ##### Setting up MLflow ####
  # Set these so that every worker process can reach MLflow on Databricks
  os.environ['DATABRICKS_HOST'] = db_host
  os.environ['DATABRICKS_TOKEN'] = db_token

  # We set the experiment details here
  experiment = mlflow.set_experiment(experiment_path)
  ############################

  print("Running distributed training")
  dist.init_process_group("nccl")

  local_rank = int(os.environ["LOCAL_RANK"])
  global_rank = int(os.environ["RANK"])

  if global_rank == 0:
    train_parameters = {'batch_size': batch_size, 'epochs': num_epochs, 'trainer': 'TorchDistributor'}
    mlflow.log_params(train_parameters)

  model = SentenceTransformer(modelname)

  filepath = "../../dbfs/mnt/path2data/"
  df_train = readData('train', filepath)
  df_train = df_train.head(10000)

  train_text = df_train[['sentA', 'sentB', 'score']].values.tolist()
  train_examples = [InputExample(texts=[a, b], label=s) for [a, b, s] in train_text]
  train_dataset = SentencesDataset(train_examples, model)
  #### Added Distributed Dataloader ####
  train_sampler = DistributedSampler(dataset=train_dataset)
  data_loader = torch.utils.data.DataLoader(
      train_dataset, batch_size=batch_size, sampler=train_sampler,
      collate_fn=model.smart_batching_collate)
  ######################################

  model = model.to(local_rank)
  #### Added Distributed Model ####
  ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
  #################################

  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

  for epoch in range(1, num_epochs + 1):
    train_one_epoch(ddp_model, local_rank, data_loader, optimizer, epoch)

    if global_rank == 0: 
      save_checkpoint(directory, ddp_model, optimizer, epoch)

  dist.destroy_process_group()

  return "finished" # can return any picklable object

# single node distributed run to quickly test that the whole process is working
with mlflow.start_run():
  mlflow.log_param('run_type', 'test_dist_code')
  main_fn(single_node_single_gpu_dir)

Expected behavior

The error shown below should no longer occur.

[Screenshot of the error message]

ydshieh commented 1 year ago

I would be very surprised if a model as widely used as BERT had such an issue.

Could you share your system environment, such as the PyTorch version?

You can run the command transformers-cli env and copy-paste its output.
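If running the CLI is awkward from a Databricks notebook, a rough stdlib-only alternative is to read installed package versions with importlib.metadata. This is a hypothetical helper, and the package list is an assumption. It reports only the Python and package versions, far less than transformers-cli env.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("transformers", "torch", "sentence-transformers", "pyspark")):
    """Hypothetical helper: map each distribution name to its installed version (None if absent)."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```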

ArthurZucker commented 1 year ago

Actually, @ydshieh, I think this is valid: we have a number of issues where in-place operations prevent FSDP training. It is not limited to the embeddings; I have seen other places where the code fails. See the linked issue for more details.

ydshieh commented 1 year ago

@ArthurZucker Thanks. I know this kind of problem exists; I was involved in #24525.

My main question: does this issue (for BERT) only happen with TorchDistributor (or FSDP, as you said)? In #24525 it seems to happen without these tools. BERT has been around for so long that I am somewhat confused about what exactly triggers this error.
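Relevant to this question: whether += fails depends on what autograd saved, not on the operator itself. The sketch below is a toy word + token-type + position embedding sum. Its sizes and the helper name word_embedding_grad are hypothetical, and it is only loosely modelled on the additions in BertEmbeddings. Here the += lands on the output of an addition, and the backward of addition needs no saved tensors, so both variants run backward and produce identical gradients. This would be consistent with BERT training fine in ordinary setups, but by itself it does not explain what triggers the error reported here.

```python
import torch


def word_embedding_grad(inplace: bool) -> torch.Tensor:
    """Hypothetical toy: sum three embeddings, backprop, return the word-embedding grad."""
    torch.manual_seed(0)
    word = torch.nn.Embedding(10, 4)
    token_type = torch.nn.Embedding(2, 4)
    position = torch.nn.Embedding(6, 4)

    input_ids = torch.tensor([[1, 2, 3]])
    token_type_ids = torch.zeros_like(input_ids)
    position_ids = torch.arange(3).unsqueeze(0)

    embeddings = word(input_ids) + token_type(token_type_ids)
    if inplace:
        # Addition's backward needs no saved tensors, so mutating the sum is safe
        embeddings += position(position_ids)
    else:
        embeddings = embeddings + position(position_ids)

    embeddings.sum().backward()  # runs without error in both branches
    return word.weight.grad


print(torch.equal(word_embedding_grad(True), word_embedding_grad(False)))  # True
```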

liqi6811 commented 1 year ago

@ydshieh My system environment is below:

liqi6811 commented 1 year ago

@ydshieh @ArthurZucker I am working in Azure Databricks. With Horovod for distributed training, the in-place operation does not cause any issue, but Horovod on 4 GPUs is only 1.6 times faster than on 1 GPU, whereas TorchDistributor can be nearly 4 times faster. However, TorchDistributor does not work because of the in-place operation. I tried subclassing to remove the in-place operations, but it was not easy :). Hopefully you can help by releasing an update. Thanks a lot.
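Expressed as scaling efficiency E (the speedup S_N over one GPU divided by the number of GPUs N), the approximate figures quoted here work out as follows:

```latex
E = \frac{S_N}{N}, \qquad
E_{\text{Horovod}} = \frac{1.6}{4} = 0.40, \qquad
E_{\text{TorchDistributor}} \approx \frac{4}{4} = 1.0
```

That is, roughly 40% versus close to 100% of ideal linear scaling on 4 GPUs.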

liqi6811 commented 1 year ago

@ydshieh @ArthurZucker I would suggest doing a thorough check for all in-place operations and getting rid of them all :).
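One rough way to start such an audit is to walk a module's syntax tree with Python's standard ast module. The walk flags augmented assignments (+=, *=, ...) and method calls whose names end in an underscore, which is PyTorch's naming convention for in-place methods such as add_. The helper find_inplace_candidates is hypothetical and only a heuristic. It cannot tell tensors from plain integers, so loop counters will show up. It also misses slice assignment (x[mask] = 0) and modules built with inplace=True.

```python
import ast
from typing import List, Tuple


def find_inplace_candidates(source: str) -> List[Tuple[int, str]]:
    """Hypothetical helper: heuristically flag lines that may modify a tensor in place.

    Flags augmented assignments (x += y, x *= y, ...) and method calls whose
    name ends in an underscore, the PyTorch convention for in-place methods
    such as add_ or masked_fill_. Names starting with an underscore
    (e.g. __init__) are skipped.
    """
    hits = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.AugAssign):
            op = type(node.op).__name__  # e.g. "Add" for +=
            hits.append((node.lineno, f"augmented assignment ({op})"))
        elif isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
            name = node.func.attr
            if name.endswith("_") and not name.startswith("_"):
                hits.append((node.lineno, f"in-place method call .{name}()"))
    return sorted(hits)  # ordered by line number


example = """
embeddings = inputs_embeds + token_type_embeddings
embeddings += position_embeddings
scores.masked_fill_(mask, -1e4)
out = embeddings.mul(2)
"""
for lineno, what in find_inplace_candidates(example):
    print(lineno, what)
```

To audit a real module, one could pass it inspect.getsource(transformers.models.bert.modeling_bert) and then review each hit by hand.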