danpovey opened this issue 3 years ago
this is with mmi_bigram_train.py, with --world-size=1, --local-rank=0
... i.e. the defaults, no extra args.
We might need to make the port configurable. For a quick work-around you can change it here:
https://github.com/k2-fsa/snowfall/blob/master/snowfall/dist.py#L8
yeah I did that.
can't we randomly choose it in a range? or does that create problems for co-ordination?
is the user just supposed to launch multiple copies of the job? Or does it spawn?
We can choose it randomly - although I think with torch.distributed.launch we'd have to choose it outside of the python script, and with torch.distributed.spawn we can choose it inside the python script, before spawning jobs. All the spawned tasks would have to see the same address + port. It might be easier to implement with the spawn approach that @csukuangfj has been using in another PR.
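For example, something along these lines could work (just a sketch, not the actual snowfall code; run_worker and the port-finding helper are made up for illustration):

import os
import socket

import torch.distributed as dist
import torch.multiprocessing as mp


def find_free_port() -> int:
    # Ask the OS for any free port by binding to port 0.
    # (Small race: the port could be taken by another process before we use it.)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def run_worker(rank: int, world_size: int, port: int):
    # Every spawned worker sees the same address + port chosen by the parent.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    # ... training code ...
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    port = find_free_port()  # chosen once, in the parent process
    mp.spawn(run_worker, args=(world_size, port), nprocs=world_size)

With torch.distributed.launch, the port would instead have to be picked in the shell and passed via --master_port.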
Getting this error:
de-74279-k2-dev-2-0331181900-7b69767657-72fhf:simple_v1: ngpus=2
de-74279-k2-dev-2-0331181900-7b69767657-72fhf:simple_v1: python3 -m torch.distributed.launch --nproc_per_node=$ngpus ./mmi_bigram_train.py --world_size $ngpus &
[1] 1489648
de-74279-k2-dev-2-0331181900-7b69767657-72fhf:simple_v1: *****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
de-74279-k2-dev-2-0331181900-7b69767657-72fhf:simple_v1: World size: 2 Rank: 0
World size: 2 Rank: 1
de-74279-k2-dev-2-0331181900-7b69767657-72fhf:simple_v1: Traceback (most recent call last):
Traceback (most recent call last):
File "./mmi_bigram_train.py", line 474, in <module>
File "./mmi_bigram_train.py", line 474, in <module>
main()
File "./mmi_bigram_train.py", line 238, in main
main()
File "./mmi_bigram_train.py", line 238, in main
setup_dist(rank=args.local_rank, world_size=args.world_size)
File "/ceph-dan/.local/lib/python3.8/site-packages/snowfall-0.1-py3.8.egg/snowfall/dist.py", line 9, in setup_dist
setup_dist(rank=args.local_rank, world_size=args.world_size)
File "/ceph-dan/.local/lib/python3.8/site-packages/snowfall-0.1-py3.8.egg/snowfall/dist.py", line 9, in setup_dist
File "/ceph-dan/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 442, in init_process_group
File "/ceph-dan/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 442, in init_process_group
barrier()
File "/ceph-dan/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1947, in barrier
barrier()
File "/ceph-dan/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1947, in barrier
work = _default_pg.barrier()
RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:784, invalid usage, NCCL version 2.7.8
work = _default_pg.barrier()
RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:784, invalid usage, NCCL version 2.7.8
Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 192, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/ceph-dan/.local/lib/python3.8/site-packages/torch/distributed/launch.py", line 260, in <module>
main()
File "/ceph-dan/.local/lib/python3.8/site-packages/torch/distributed/launch.py", line 255, in main
raise subprocess.CalledProcessError(returncode=process.returncode,
subprocess.CalledProcessError: Command '['/usr/bin/python3', '-u', './mmi_bigram_train.py', '--local_rank=1', '--world_size', '2']' returned non-zero exit status 1.
Hmm, I've never seen this one before...
I found something about why, when we use torch.distributed.launch, it was hanging at the end. (Caution: my lhotse was not fully up to date, although sampling.py doesn't seem to have changed in the interim). Firstly, at startup, I got a bunch of messages like this (I was confused by how many there were):
2021-04-10 23:35:58,870 INFO [mmi_bigram_train.py:274] About to create train dataset
2021-04-10 23:35:59,509 INFO [mmi_bigram_train.py:284] Using BucketingSampler.
2021-04-10 23:35:59,813 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,814 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,814 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,814 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,815 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,815 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,815 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,816 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,816 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,816 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,816 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,817 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,817 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local ran
...
2021-04-10 23:35:59,820 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,820 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,820 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,821 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,821 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,821 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1427)].
2021-04-10 23:35:59,822 INFO [sampling.py:523] Distributed training with world size of 2 detected (node's local rank is 0. Splitting cuts into 2 partitions (this partition has cut IDs range [(0, 1426)].
2021-04-10 23:35:59,824 INFO [mmi_bigram_train.py:298] About to create train dataloader
... that's just FYI. I am confused why it never says the local rank is 1, and why the last message says 1426. I am using world-size=2. The rank=1 process doesn't print many logs, but it does print some, which are printed near the start of the log file for some reason, and not at all in sync with where you'd expect them to be:
2021-04-10 23:36:00,287 INFO [mmi_bigram_train.py:401] epoch 0, learning rate 0.001
2021-04-10 23:36:02,585 INFO [mmi_bigram_train.py:176] batch 0, epoch 0/10 global average 2021-04-10 23:36:02,597 INFO [distributed.py:607] Reducer buckets have been rebuilt in this iteration.
2021-04-11 00:06:37,689 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/best-epoch-info
2021-04-11 00:06:37,689 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-0-info
2021-04-11 00:06:37,690 INFO [mmi_bigram_train.py:401] epoch 1, learning rate 0.001
2021-04-11 00:37:08,277 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/best-epoch-info
2021-04-11 00:37:08,277 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-1-info
2021-04-11 00:37:08,278 INFO [mmi_bigram_train.py:401] epoch 2, learning rate 0.001
2021-04-11 01:07:37,757 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/best-epoch-info
2021-04-11 01:07:37,757 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-2-info
2021-04-11 01:07:37,758 INFO [mmi_bigram_train.py:401] epoch 3, learning rate 0.001
2021-04-11 01:38:04,478 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/best-epoch-info
2021-04-11 01:38:04,478 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-3-info
2021-04-11 01:38:04,478 INFO [mmi_bigram_train.py:401] epoch 4, learning rate 0.001
2021-04-11 02:08:40,364 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/best-epoch-info
2021-04-11 02:08:40,365 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-4-info
2021-04-11 02:08:40,365 INFO [mmi_bigram_train.py:401] epoch 5, learning rate 0.001
2021-04-11 02:39:10,325 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/best-epoch-info
2021-04-11 02:39:10,326 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-5-info
2021-04-11 02:39:10,326 INFO [mmi_bigram_train.py:401] epoch 6, learning rate 0.001
2021-04-11 03:09:50,414 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/best-epoch-info
2021-04-11 03:09:50,415 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-6-info
2021-04-11 03:09:50,415 INFO [mmi_bigram_train.py:401] epoch 7, learning rate 0.0008
2021-04-11 03:36:48,757 WARNING [cut.py:1023] To perform mix, energy must be non-zero and non-negative (got 0.0). Cut with id "c1c0349b-2b4f-3277-d8a3-c1a701cc4c32" will not be mixed in.
2021-04-11 03:40:24,571 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-7-info
2021-04-11 03:40:24,571 INFO [mmi_bigram_train.py:401] epoch 8, learning rate 0.0006400000000000002
2021-04-11 04:11:20,545 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-8-info
2021-04-11 04:11:20,546 INFO [mmi_bigram_train.py:401] epoch 9, learning rate 0.0005120000000000001
2021-04-11 04:42:13,339 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/best-epoch-info
2021-04-11 04:42:13,340 INFO [common.py:201] write training info to exp-lstm-adam-mmi-bigram-musan-dist/epoch-9-info
2021-04-11 04:42:13,340 WARNING [mmi_bigram_train.py:466] Done
erage objf: 1.353939 over 1285235.0 frames (100.0% kept), current batch average objf: 1.197140 over 9905 frames (100.0% kept) avg time waiting for batch 0.012s
2021-04-10 23:39:06,673 INFO [mmi_bigram_train.py:176] batch 140, epoch 0/10 global average objf: 1.346037 over 1383290.0 frames (100.0% kept), current batch average objf: 1.175396 over 9741 frames (100.0% kept) avg time waiting for batch 0.012s
Anyway, when I check the times of where it starts each epoch, this rank=1 job does not seem to be correctly synchronized with the rank=0 job. It finishes epoch 0 when the rank=0 job has only finished about 90% of its minibatches.
Wouldn't it be easier, in order to support distributed training, to just have the sampler process things as normal and then return 1 out of every world_size minibatches? The time that it processes that metadata will probably overlap with GPU stuff anyway; I don't really think that's going to be the limiting factor. That way we can more easily ensure the number of minibatches is exactly the same between jobs.
I confirm that if different nodes have a different number of utterances in their dataloaders, the node with the most utterances will hang at the end.
I suspect the reason is the allreduce called inside backward. Since nodes with fewer utterances exit earlier, the node with the most utterances waits inside allreduce; but it does not receive responses from the other, already exited nodes. Therefore, it waits indefinitely.
A minimal example to reproduce it is given below.
The current approach to partitioning the dataset over different nodes cannot guarantee that every node receives the same number of utterances. See https://github.com/lhotse-speech/lhotse/blob/3c6dea9e90536e01b6ce7b937682ab85ae50d680/lhotse/dataset/sampling.py#L519-L522
total = len(data_source)
per_partition = int(ceil(total / float(world_size)))
partition_start = rank * per_partition
partition_end = min(partition_start + per_partition, total)
The node with the largest rank value receives fewer if total % world_size != 0.
That may explain why the node with rank==0 hangs in the end.
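To make the imbalance concrete, here is the formula above evaluated with made-up numbers (10 cuts over 3 nodes; not the actual dataset sizes):

from math import ceil

total, world_size = 10, 3
per_partition = int(ceil(total / float(world_size)))  # = 4
for rank in range(world_size):
    partition_start = rank * per_partition
    partition_end = min(partition_start + per_partition, total)
    print(rank, partition_end - partition_start)  # rank 0 -> 4, rank 1 -> 4, rank 2 -> 2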
#!/usr/bin/env python3

import datetime
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def run(rank: int, world_size: int):
    print(f'world_size: {world_size}')
    device = torch.device('cuda', rank)
    if rank != 0:
        data = [torch.tensor([1], device=device, dtype=torch.float32) for _ in range(world_size)]
    else:
        data = [torch.tensor([1], device=device, dtype=torch.float32) for _ in range(world_size * 100)]
    # NOTE: `data` on rank 0 has more entries
    dist.barrier()
    model = torch.nn.Linear(1, 1).to(device)
    model = DDP(model, device_ids=[rank])
    for i, d in enumerate(data):
        model.zero_grad()
        y = model(d)
        y.backward()
    print(f'rank {rank} done')


# The node with rank==0 will exit after the timeout (5 seconds).
# The default timeout is 30 minutes, but it comes into effect
# only if one of the following environment variables is set:
#   - NCCL_ASYNC_ERROR_HANDLING
#   - NCCL_BLOCKING_WAIT
# See https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
def init_process(rank: int, world_size: int, fn):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12357'
    dist.init_process_group('nccl',
                            rank=rank,
                            world_size=world_size,
                            timeout=datetime.timedelta(0, 5))
    fn(rank, world_size)


if __name__ == '__main__':
    print(f'dist.is_available: {dist.is_available()}')
    world_size = 3
    processes = []
    mp.set_start_method('spawn')
    for rank in range(world_size):
        p = mp.Process(target=init_process, args=(rank, world_size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
Its output is shown below. (NOTE: the environment variable NCCL_ASYNC_ERROR_HANDLING=1 has to be set.)
$ NCCL_ASYNC_ERROR_HANDLING=1 ./foo.py
dist.is_available: True
world_size: 3
world_size: 3
world_size: 3
rank 1 done
rank 2 done
Process Process-1:
Traceback (most recent call last):
File "/root/fangjun/open-source/pyenv/versions/3.8.6/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/root/fangjun/open-source/pyenv/versions/3.8.6/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/xxx/foo.py", line 47, in init_process
fn(rank, world_size)
File "/xxx/foo.py", line 28, in run
y.backward()
File "/root/fangjun/py38/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/root/fangjun/py38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
Variable._execution_engine.run_backward(
RuntimeError: NCCL communicator was aborted.
@pzelasko How about dropping some utterances in partition_cut_ids so that all nodes get the same number of utterances?
Oooh, now it all finally makes sense. Thanks for debugging this, guys. I'll add a fix to the cut IDs partitioning in the sampler.
I'm going to use @danpovey's solution rather than @csukuangfj's solution -- unfortunately, it is not straightforward to estimate how many utterances should be dropped in partition_cut_ids, since we have a dynamic batch size. That means an equal number of cuts in each partition doesn't guarantee an equal number of batches, due to variations in duration and the max_duration constraint. I think that dropping all batches except something like (batch_idx + rank) % world_size == 0 is much more viable. I'll test it and make a PR in Lhotse.
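A minimal sketch of that filtering idea (not the actual Lhotse implementation; the wrapper name and types are made up, and the real change also has to handle shuffling and epochs):

from typing import Iterable, Iterator, List


class RankFilteringSampler:
    # Wrap any sampler that yields batches (here: lists of cut IDs) and keep
    # only every world_size-th batch, offset by rank, so that each DDP worker
    # iterates over the same number of batches and no rank exits early.
    def __init__(self, batches: Iterable[List[str]], rank: int, world_size: int):
        self.batches = batches
        self.rank = rank
        self.world_size = world_size

    def __iter__(self) -> Iterator[List[str]]:
        for batch_idx, batch in enumerate(self.batches):
            if (batch_idx + self.rank) % self.world_size == 0:
                yield batch
        # If the total number of batches is not a multiple of world_size, the
        # trailing remainder still needs to be dropped (or padded) so that all
        # ranks end up with exactly the same count.


# Usage (hypothetical): sampler = RankFilteringSampler(base_sampler, rank, world_size)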
@danpovey @csukuangfj can you please try out the version in PR https://github.com/lhotse-speech/lhotse/pull/267 and let me know if it helped? I won't be able to test the snowfall distributed training setup today, but based on the unit tests I wrote it seems to have fixed the issues with an unequal number of batches in each worker.
@pzelasko Thanks. Trying it.
@pzelasko I confirm that the current change can solve the hanging problem.
Here is the tensorboard log of DDP training with 3 GPUs:
And the WERs are
2021-04-13 09:12:14,242 INFO [common.py:356] [test-clean] %WER 7.39% [3885 / 52576, 505 ins, 330 del, 3050 sub ]
2021-04-13 09:14:29,223 INFO [common.py:356] [test-other] %WER 18.82% [9849 / 52343, 1149 ins, 863 del, 7837 sub ]
The WERs are worse than those of single-GPU training. I believe the reason is the learning rate. You can compare the learning rate from the above tensorboard log with the one from single-GPU training (https://tensorboard.dev/experiment/h3xiWY0oQ4WGd2dRgG8NWw).
I believe if we train it for more epochs, it can achieve similar results.
NOTE: The training time per epoch with 3 GPUs is about 16 minutes, which is about 1/3 of single GPU training.
Great!! I think in order to get comparable results to the baseline we'd have to divide the minibatch size by the number of workers. Let's merge this?
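(Roughly like this; the variable and parameter names below are assumptions about the training script, not its actual code:)

# Hypothetical sketch: keep the *global* dynamic batch size equal to the
# single-GPU baseline by shrinking each worker's budget by world_size.
# BucketingSampler and max_duration are the knobs discussed above; the exact
# argument names may differ depending on the lhotse version.
per_worker_max_duration = max_duration / world_size
train_sampler = BucketingSampler(cuts_train,
                                 max_duration=per_worker_max_duration,
                                 shuffle=True)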
Oh it's a PR to lhotse, we'll wait for Piotr to merge.
Merged!
FYI this could be of interest to us https://huggingface.co/blog/accelerate-library
When I try to run more than one training (each a single job) on the same machine, I get this: