allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

Multi-GPU training hangs #5088

Closed: aleSuglia closed this issue 3 years ago

aleSuglia commented 3 years ago


Description

I am trying to run multi-GPU training (using 4 GPUs), but it hangs after a few iterations (roughly 15). This happens both with my custom model and with models in allennlp-models (I tried roberta-large).


Environment

OS: Deep Learning AMI (Ubuntu 18.04) Version 42.1 -- AWS EC2 p3.8xlarge

Python version: Python 3.8 installed via Anaconda

Steps to reproduce

I installed allennlp-models and changed the configuration file reported above as follows:

local transformer_model = "roberta-base";
local transformer_dim = 768;

{
  "dataset_reader":{
    "type": "boolq",
    "token_indexers": {
      "tokens": {
        "type": "pretrained_transformer",
        "model_name": transformer_model,
      }
    },
    "tokenizer": {
      "type": "pretrained_transformer",
      "model_name": transformer_model,
    }
  },
  "train_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/train.jsonl",
  "validation_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/val.jsonl",
  "test_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/test.jsonl",
  "model": {
    "type": "basic_classifier",
    "text_field_embedder": {
      "token_embedders": {
        "tokens": {
          "type": "pretrained_transformer",
          "model_name": transformer_model,
        }
      }
    },
    "seq2vec_encoder": {
       "type": "bert_pooler",
       "pretrained_model": transformer_model,
       "dropout": 0.1,
    },
    "namespace": "tags",
    "num_labels": 2,
  },
  "data_loader": {
    "batch_sampler": {
      "type": "bucket",
      "sorting_keys": ["tokens"],
      "batch_size" : 4
    }
  },
  "distributed": {
      "cuda_devices": [0,1,2,3]
  },
  "trainer": {
    "num_epochs": 10,
    "num_gradient_accumulation_steps": 2,
    "validation_metric": "+accuracy",
    "learning_rate_scheduler": {
      "type": "slanted_triangular",
      "num_epochs": 10,
      "num_steps_per_epoch": 3088,
      "cut_frac": 0.06
    },
    "optimizer": {
      "type": "huggingface_adamw",
      "lr": 1e-5,
      "weight_decay": 0.1,
    }
  },
}

@epwalsh Any ideas?

epwalsh commented 3 years ago

What version of AllenNLP?

aleSuglia commented 3 years ago

The latest version on the main branch.

epwalsh commented 3 years ago

So far I haven't been able to reproduce it, but I've only run it on CPU (I set "cuda_devices": [-1, -1, -1, -1]).

aleSuglia commented 3 years ago

Yeah, I think something weird is going on with the GPU.

Sorry, ignore my comment on gradient accumulation. It happens without it too!

epwalsh commented 3 years ago

Can you also reproduce it on 2 GPUs? It would be much easier for me to debug on just 2 GPUs.

dirkgr commented 3 years ago

I just reproduced a hang while testing https://github.com/allenai/allennlp/pull/5077, so now I'm trying it without those changes to see if it still hangs.

dirkgr commented 3 years ago

I'm seeing a hang when training the PIQA model. It happens at the end of the first epoch. The GPUs stay busy, though. That's a little suspicious: with a proper hang, utilization should drop to 0%.

dirkgr commented 3 years ago

GPUs stay busy at 100% when this hang happens. No deviation from 100% at all.

aleSuglia commented 3 years ago

Exactly what happens in my case as well. How can I help debug this? @epwalsh I see the same behaviour when using 2 GPUs (on the same instance).

epwalsh commented 3 years ago

@dirkgr were you training on a single GPU or multiple?

aleSuglia commented 3 years ago

@epwalsh I think it's an issue with the multiple GPU training setup!

dirkgr commented 3 years ago

Yes, multiple GPUs. When the GPU hangs at 100% utilization it's usually a CUDA problem, not an AllenNLP problem. But this seems pretty serious, so regardless of whose problem it is, we need a solution. It would be interesting to see if this happens with older CUDA versions as well. The other thing that would really help is a stack trace of where it's hanging. I could not get one when I ran it on Thursday.
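One way to get such a stack trace from a hung worker, assuming a Unix host, is the standard-library faulthandler module: register a signal handler in each worker process, then send that signal once the run hangs. This is a sketch, not part of AllenNLP; the helper name enable_stack_dumps and the choice of SIGUSR1 are assumptions. Because mp.spawn starts fresh interpreters, the call would have to run inside each worker (for example, at the top of the worker function), not just in the parent process.

```python
import faulthandler
import signal
import sys


def enable_stack_dumps(sig: int = signal.SIGUSR1) -> None:
    """Hypothetical helper: dump every thread's Python stack to stderr on `sig`.

    Unix only (faulthandler.register is unavailable on Windows). With
    chain=False the process keeps running after the dump, so it can be
    triggered repeatedly while the job is stuck.
    """
    faulthandler.register(sig, file=sys.stderr, all_threads=True, chain=False)


if __name__ == "__main__":
    import os

    enable_stack_dumps()
    print(f"Run `kill -USR1 {os.getpid()}` from another shell to dump stacks.")
    # Send the signal to ourselves once to show what the dump looks like.
    os.kill(os.getpid(), signal.SIGUSR1)
```

Running kill -USR1 against each worker PID then shows which ones are parked inside a collective and which are not.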

dirkgr commented 3 years ago

With "max_instances": 2000, it works. With "max_instances": 2001, it hangs. That looks like an AllenNLP problem then.
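To see why a single extra instance can matter, here is a back-of-the-envelope sketch. It assumes, for illustration only, that instances are sharded round-robin across workers (instance i goes to rank i % world_size) and that each worker's batch sampler yields ceil(n / batch_size) batches, or floor(n / batch_size) when the last partial batch is dropped. The 4 workers and batch size of 4 come from the config at the top of this issue, not from the PIQA run.

```python
import math
from typing import List


def batches_per_worker(
    num_instances: int, world_size: int, batch_size: int, drop_last: bool = False
) -> List[int]:
    """Batches each rank would see under round-robin sharding (an assumption)."""
    counts = []
    for rank in range(world_size):
        # Instances rank, rank + world_size, rank + 2 * world_size, ... land here.
        shard_size = len(range(rank, num_instances, world_size))
        if drop_last:
            counts.append(shard_size // batch_size)
        else:
            counts.append(math.ceil(shard_size / batch_size))
    return counts


print(batches_per_worker(2000, world_size=4, batch_size=4))  # [125, 125, 125, 125]
print(batches_per_worker(2001, world_size=4, batch_size=4))  # [126, 125, 125, 125]
```

Under these assumptions, with 2001 instances rank 0 would call model.forward() one more time than the other ranks.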

dirkgr commented 3 years ago

Looks like this was already broken in 2.0.0 😢. I feel a new test case coming on.

dirkgr commented 3 years ago

This suggests a terrible workaround: make sure the number of instances is an even multiple of the batch size.
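A sketch of that workaround, assuming the same round-robin sharding as above: truncate the dataset (for example through the dataset reader's max_instances setting) to a multiple of world_size × batch_size × num_gradient_accumulation_steps, so every worker gets the same number of complete batch groups. Including the accumulation steps is a conservative assumption, not something stated in this thread, and safe_max_instances is a hypothetical helper.

```python
def safe_max_instances(
    num_instances: int, world_size: int, batch_size: int, grad_accum_steps: int = 1
) -> int:
    """Largest count <= num_instances that splits evenly into full batch groups."""
    if min(world_size, batch_size, grad_accum_steps) < 1:
        raise ValueError("world_size, batch_size and grad_accum_steps must be >= 1")
    unit = world_size * batch_size * grad_accum_steps
    return (num_instances // unit) * unit


# Values from the config in this issue: 4 GPUs, batch_size 4, 2 accumulation steps.
print(safe_max_instances(2001, world_size=4, batch_size=4, grad_accum_steps=2))  # 1984
```

The price is that up to unit − 1 instances are silently skipped every epoch, which is why this is only a stopgap.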

We'll make a release as soon as this is fixed. This is really bad.

dirkgr commented 3 years ago

I finally got a stack trace from a worker:

######### ProcessID=82627, ThreadID=4617211392 #########
File: "<string>", line 1, in <module>
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
  exitcode = _main(fd, parent_sentinel)
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/multiprocessing/spawn.py", line 129, in _main
  return self._bootstrap(parent_sentinel)
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
  self.run()
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/multiprocessing/process.py", line 108, in run
  self._target(*self._args, **self._kwargs)
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
  fn(i, *args)
File: "/Users/dirkg/Documents/allennlp/allennlp/commands/train.py", line 469, in _train_worker
  metrics = train_loop.run()
File: "/Users/dirkg/Documents/allennlp/allennlp/commands/train.py", line 531, in run
  return self.trainer.train()
File: "/Users/dirkg/Documents/allennlp/allennlp/training/trainer.py", line 735, in train
  metrics, epoch = self._try_train()
File: "/Users/dirkg/Documents/allennlp/allennlp/training/trainer.py", line 767, in _try_train
  train_metrics = self._train_epoch(epoch)
File: "/Users/dirkg/Documents/allennlp/allennlp/training/trainer.py", line 503, in _train_epoch
  batch_outputs = self.batch_outputs(batch, for_training=True)
File: "/Users/dirkg/Documents/allennlp/allennlp/training/trainer.py", line 391, in batch_outputs
  output_dict = self._pytorch_model(**batch)
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
  result = self.forward(*input, **kwargs)
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 610, in forward
  self._sync_params()
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1045, in _sync_params
  self._distributed_broadcast_coalesced(
File: "/Users/dirkg/miniconda3/envs/allennlp-models/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 978, in _distributed_broadcast_coalesced
  dist._broadcast_coalesced(

So we're stuck inside of pytorch. 😳

dirkgr commented 3 years ago

Ah, but the other worker is not. I think I know what's going on.

dirkgr commented 3 years ago

I think #5100 fixes this. @aleSuglia, can you confirm that you were using gradient accumulation?

dirkgr commented 3 years ago

Yes, your training config says so. Can you try that fix and see if it works for you?

aleSuglia commented 3 years ago

@dirkgr Mmmm, unfortunately it doesn't work for me. The only thing my custom trainer does is the following:

        done_early = False
        for batch_group in batch_group_generator_tqdm:
            if done_early:
                break

            # Zero gradients.
            # NOTE: this is actually more efficient than calling `self.optimizer.zero_grad()`
            # because it avoids a read op when the gradients are first updated below.
            for param_group in self.optimizer.param_groups:
                for p in param_group["params"]:
                    p.grad = None

            batch_loss = 0.0
            batch_group_outputs = []

            for batch in batch_group:
                if self._distributed:
                    # Check whether the other workers have stopped already (due to differing amounts of
                    # data in each). If so, we can't proceed because we would hang when we hit the
                    # barrier implicit in Model.forward. We use an IntTensor instead of a BoolTensor
                    # here because NCCL process groups apparently don't support BoolTensor.
                    done = torch.tensor(0, device=self.cuda_device)
                    torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM)
                    if done.item() > 0:
                        done_early = True
                        logger.warning(
                            f"Worker {torch.distributed.get_rank()} finishing training early! "
                            "This implies that there is an imbalance in your training "
                            "data across the workers and that some amount of it will be "
                            "ignored. A small amount of this is fine, but a major imbalance "
                            "should be avoided. Note: This warning will appear unless your "
                            "data is perfectly balanced."
                        )
                        break

                if self.num_bptt_steps > 0:
                    splits = split_data_batch(batch, ignore_keys=self.ignore_keys, split_size=self.num_bptt_steps)
                else:
                    splits = [batch]

                hidden_states = None

                for split_batch in splits:
                    batches_this_epoch += 1
                    self._batch_num_total += 1
                    with amp.autocast(self._use_amp):
                        batch_outputs = self.batch_outputs(split_batch, hidden_states=hidden_states, for_training=True)
                        batch_group_outputs.append(batch_outputs)
                        loss = batch_outputs["loss"]
                        hidden_states = batch_outputs["hidden_states"]
                        reg_loss = batch_outputs.get("reg_loss")
                        if torch.isnan(loss):
                            raise ValueError("nan loss encountered")
                        loss = loss / len(batch_group)

                        batch_loss += loss.item()
                        if reg_loss is not None:
                            reg_loss = reg_loss / len(batch_group)
                            batch_reg_loss = reg_loss.item()
                            train_reg_loss += batch_reg_loss  # type: ignore

                        if self._scaler is not None:
                            self._scaler.scale(loss).backward()
                        else:
                            loss.backward()

                    train_loss += batch_loss
                    batch_loss /= len(splits)

            if len(batch_group_outputs) <= 0:
                continue

In my case, though, training gets stuck when it is about to start the first epoch. It loads max_instances_in_memory data points and then it hangs...

This is the stacktrace if I press CTRL+C:

^CTraceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/embai/bin/allennlp", line 8, in <module>
    sys.exit(run())
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/site-packages/allennlp/__main__.py", line 34, in run
    main(prog="allennlp")
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/site-packages/allennlp/commands/__init__.py", line 119, in main
    args.func(args)
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/site-packages/allennlp/commands/train.py", line 110, in train_model_from_args
    train_model_from_file(
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/site-packages/allennlp/commands/train.py", line 170, in train_model_from_file
    return train_model(
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/site-packages/allennlp/commands/train.py", line 308, in train_model
    mp.spawn(
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 75, in join
    ready = multiprocessing.connection.wait(
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt
^CError in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/embai/lib/python3.8/multiprocessing/popen_fork.py", line 27, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt

I see something else that is very weird: there are multiple CUDA context processes on each GPU instead of a single process per GPU. They occupy roughly 350 MB each, so every GPU has 4 CUDA processes. Here is the status before training starts:

[Screenshot: nvidia-smi output before training starts]

This is when the instances have been loaded and the training actually "starts":

[Screenshot: nvidia-smi output after the instances have been loaded]

dirkgr commented 3 years ago

I can't tell from the stack trace what's going on, but your custom trainer still has the problem I fixed in the PR: You might end up calling model.forward() a different number of times in the workers. It's important to call model.forward() exactly the same number of times in each worker, otherwise they go out of sync.

FWIW, I think this is a terrible programming model, and I hope that the DeepSpeed work we're doing will allow us to move away from multi-process training completely. But for now, we need this one to work.

dirkgr commented 3 years ago

Are you also using multi-process data loading?

aleSuglia commented 3 years ago

No, I have set num_workers to 0 and start_method to spawn.

aleSuglia commented 3 years ago

> I can't tell from the stack trace what's going on, but your custom trainer still has the problem I fixed in the PR: You might end up calling model.forward() a different number of times in the workers. It's important to call model.forward() exactly the same number of times in each worker, otherwise they go out of sync.

Unfortunately, because I'm splitting batches along the time dimension with the function split_data_batch (for efficiency reasons), I will inevitably end up with a different total number of batches per worker. Would your fix apply to this case as well?

dirkgr commented 3 years ago

> Would your fix apply to this case as well?

No, but you can use the idea behind the fix in your custom trainer. Make sure that the check whether we're done (the one I moved in my PR) happens inside your for split_batch in splits loop. That means you might process a partial batch.
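The idea can be illustrated without GPUs. The sketch below is a pure-Python simulation, not AllenNLP code: threads stand in for workers, a hypothetical FakeAllReduce built on threading.Barrier stands in for torch.distributed.all_reduce with ReduceOp.SUM, and each worker's splits_per_batch lists how many time-dimension splits each of its batches produces. Because the done check runs before every split, all workers make the same number of forward calls even when their split counts differ.

```python
import threading
from typing import List, Optional


class FakeAllReduce:
    """Hypothetical stand-in for all_reduce(SUM): blocks until every rank contributes."""

    def __init__(self, world_size: int) -> None:
        self._values = [0] * world_size
        # The timeout turns a protocol bug into BrokenBarrierError instead of a hang.
        self._barrier = threading.Barrier(world_size, timeout=10)

    def all_reduce_sum(self, rank: int, value: int) -> int:
        self._values[rank] = value
        self._barrier.wait()  # every rank has written its value
        total = sum(self._values)
        self._barrier.wait()  # every rank has read before the next round overwrites
        return total


def worker(rank: int, splits_per_batch: List[int], comm: FakeAllReduce,
           forward_counts: List[Optional[int]]) -> None:
    forward_calls = 0
    done_early = False
    for num_splits in splits_per_batch:
        for _ in range(num_splits):
            # Done check inside the split loop: one collective per forward pass.
            if comm.all_reduce_sum(rank, 0) > 0:
                done_early = True
                break
            forward_calls += 1  # model.forward() and backward() would run here
        if done_early:
            break
    if not done_early:
        # This rank ran out of data; its 1 stops the others at their next check.
        comm.all_reduce_sum(rank, 1)
    forward_counts[rank] = forward_calls


def simulate(splits_by_rank: List[List[int]]) -> List[Optional[int]]:
    world_size = len(splits_by_rank)
    comm = FakeAllReduce(world_size)
    counts: List[Optional[int]] = [None] * world_size
    threads = [
        threading.Thread(target=worker, args=(rank, splits, comm, counts))
        for rank, splits in enumerate(splits_by_rank)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return counts


# Rank 0 has 2 + 3 = 5 splits, rank 1 has 3 + 3 = 6; both stop after 5 forward calls.
print(simulate([[2, 3], [3, 3]]))  # [5, 5]
```

If the check were done only once per batch instead, rank 1 in this example would make a sixth forward call that rank 0 never matches; in real DDP training that unmatched call is where a worker would block.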

The other way to circumvent this is to set drop_last on the data loader. That way you know you'll always have complete batches.

I am wondering if you could have solved this problem without resorting to splitting batches at the last second. TransformerQA strides over the text sequence if it is too long. Predictions are then only for one "chunk", and to get a real final evaluation over a sequence that is too long, you have to run the predictor, which combines multiple calls to the model into a single prediction. It's a bit messy, but less messy than hacking the trainer.
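The striding approach boils down to cutting a long token sequence into overlapping fixed-size windows, running the model on each window separately, and merging the per-window outputs afterwards. The sketch below is hypothetical and is not TransformerQA's actual code: window and stride are illustrative parameters, and keeping the best-scoring chunk is just one simple way a predictor could combine per-chunk results.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def stride_chunks(tokens: Sequence[T], window: int, stride: int) -> List[Sequence[T]]:
    """Overlapping windows of at most `window` tokens, starting every `stride` tokens."""
    if window < 1 or not 1 <= stride <= window:
        raise ValueError("need window >= 1 and 1 <= stride <= window")
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + window])
        if start + window >= len(tokens):
            break  # this window already reaches the end of the sequence
        start += stride
    return chunks


def best_chunk(scores: Sequence[float]) -> Tuple[int, float]:
    """Combine per-chunk model scores by keeping the highest one (an assumption)."""
    index = max(range(len(scores)), key=lambda i: scores[i])
    return index, scores[index]


print(stride_chunks(list(range(10)), window=4, stride=3))
# [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]
```

Since every training instance is then a fixed-size chunk, the trainer never has to split batches itself, so each worker's batch count depends only on the data loader.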

dirkgr commented 3 years ago

If you're not using multiple processes for data loading, and you didn't hack some other parts of the data loading pipeline, I can't think of anything that would spawn those extra processes. @epwalsh, can you think of something?

epwalsh commented 3 years ago

If num_workers is 0, the data loader does not spawn any workers. From your stack trace, it looks like the only spawning going on is the spawning of the distributed trainer workers. Am I missing something?

dirkgr commented 3 years ago

His nvidia-smi trace shows 4 processes per GPU.

Though actually, that's not what it shows. @aleSuglia, it shows that you have four processes, and each of them connects to every GPU. Though only the one actually doing work needs any real memory. I can't explain that, but I also don't think it's important.

aleSuglia commented 3 years ago

No, I haven't changed the data loading procedure. I have only extended the MultiProcessDataLoader to change its collate function; the core functionality is the same. It looks like there are multiple CUDA contexts resulting from additional memory allocation steps. I'll try to debug it a bit more and let you know!

OhadRubin commented 3 years ago

Hey, what is the status of this issue? I've been experiencing this as well.

dirkgr commented 3 years ago

Sorry for the delay! Can you make a new issue explaining what's going on? Everything in this issue is a bit dated now and the code has changed since then.