ksw0306 / FloWaveNet

A Pytorch implementation of "FloWaveNet: A Generative Flow for Raw Audio"

Distributed Training with Apex #22

Closed · 1ytic closed this 5 years ago

1ytic commented 5 years ago

The Apex utilities (https://github.com/NVIDIA/apex) handle some issues with specific nodes in the FloWaveNet architecture.

List of changes made in train.py (sketched in the code blocks below):

  1. Determine local_rank and world_size for torch.distributed.init_process_group
  2. Set a current device with torch.cuda.set_device
  3. Wrap dataset with torch.utils.data.distributed.DistributedSampler
  4. Apply amp.scale_loss at each backward pass
  5. Clip gradient with amp.master_params
  6. Divide step_size by world_size (not sure if this is necessary)
  7. Initialize model and optimizer with amp.initialize
  8. Wrap model with apex.parallel.DistributedDataParallel
  9. Handle evaluation and messages on the first node using args.local_rank
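A rough sketch of the setup side of these changes (steps 1–3 and 6–8), assuming `model`, `optimizer`, `train_dataset`, and `step_size` are already defined as in train.py; the argument names and the `opt_level` here are placeholders, not necessarily the exact identifiers used in train_apex.py:

```python
import argparse

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from apex import amp
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # injected by torch.distributed.launch
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=8)
args = parser.parse_args()

# 1-2. One process per GPU; bind this process to its device.
dist.init_process_group(backend='nccl', init_method='env://')
world_size = dist.get_world_size()
torch.cuda.set_device(args.local_rank)

# 3. Each process draws a disjoint shard of the training set.
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                          sampler=train_sampler, num_workers=args.num_workers,
                          pin_memory=True)

# 6. The effective global batch grows with world_size, so the LR scheduler's
#    step_size is divided accordingly (possibly unnecessary, as noted above).
step_size = step_size // world_size

# 7. Patch model and optimizer for mixed precision.
model, optimizer = amp.initialize(model.cuda(), optimizer, opt_level='O1')

# 8. Wrap the model so gradients are all-reduced across processes.
model = DistributedDataParallel(model)
```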

For example, to run on 4 GPUs: `python -m torch.distributed.launch --nproc_per_node=4 train_apex.py --num_workers 2 --epochs 1000`
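Continuing from the setup sketch above, the per-iteration changes (steps 4, 5, and 9) amount to something like the following; the loss expression and clipping threshold are illustrative only, not the exact code in train_apex.py:

```python
for x, c in train_loader:
    x, c = x.cuda(args.local_rank), c.cuda(args.local_rank)
    optimizer.zero_grad()

    # Flow training maximizes log p(z) + log|det(Jacobian)| (illustrative loss).
    log_p, logdet = model(x, c)
    loss = -(log_p + logdet)

    # 4. Scale the loss so FP16 gradients do not underflow.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # 5. Clip the FP32 master gradients rather than the FP16 model gradients.
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
    optimizer.step()

    # 9. Evaluation and progress messages only on the first process.
    if args.local_rank == 0:
        print(f'loss: {loss.item():.4f}')
```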

Resolves: #13 See also: #16

L0SG commented 5 years ago

Huge thanks for the Apex & DistributedDataParallel integration (plus the nicer tqdm bar)! We also verified that the log-determinant changes across iterations when using this (rather than DataParallel), so the current incompatibility appears to be specific to a reference-counting issue in DataParallel.