pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.
https://pytorch.org/examples
BSD 3-Clause "New" or "Revised" License

fp16 (half precision) training doesn't work with 2 or more GPUs #311

Open tstandley opened 6 years ago

tstandley commented 6 years ago

For instance, when I use the code from @csarofeen's fp16 example, everything works fine on 1 GPU for both --fp16 and regular 32-bit training. On 2 GPUs, 32-bit training still works fine, but 16-bit training is broken.

Training becomes unstable or learns more slowly, and the validation loss is often NaN.

I tested several setups, including 1 and 2 Titan Vs, with CUDA 9.1 on driver 390.xx and CUDA 9.0 on driver 384.xx.

I tried adding torch.cuda.synchronize() around the fp16-specific lines, and also casting the half-precision output back to float before passing it to the criterion. Neither helped.
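For context on the cast-to-float idea: the largest finite fp16 value is 65504, so a loss computed naively in half precision can overflow to inf even when the logits are moderate. The sketch below is a minimal illustration using only the standard library. It emulates fp16 rounding with struct's "e" (binary16) format, lets Python's double precision stand in for fp32, and deliberately uses a naive softmax without max subtraction. Real criterion implementations are more numerically careful, so this only illustrates the range problem, not what PyTorch does internally. The helper names are hypothetical.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to the nearest IEEE 754 binary16 (fp16) value."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses finite values beyond the fp16 range; fp16 arithmetic gives inf
        return math.copysign(math.inf, x)


def identity(x: float) -> float:
    """No rounding: Python float stands in for fp32 here."""
    return x


def naive_cross_entropy(logits, target, rnd):
    """Hypothetical helper: log(sum(exp(z))) - z[target], rounding each step with rnd."""
    exps = [rnd(math.exp(rnd(z))) for z in logits]
    total = rnd(sum(exps))
    return rnd(rnd(math.log(total)) - rnd(logits[target]))


logits = [12.0, 1.0, 0.5]  # exp(12) is about 162755, above the fp16 max of 65504
print("fp16 loss:", naive_cross_entropy(logits, 0, to_half))  # inf
print("full-precision loss:", naive_cross_entropy(logits, 0, identity))  # small, finite
```

This is why computing the loss in fp32 is a common precaution, even though, as noted above, it did not fix the multi-GPU problem here.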

Any help would be appreciated.

csarofeen commented 6 years ago

Could you please walk me through what exactly you're running? The example I have up is https://github.com/csarofeen/examples/tree/dist_fp16/imagenet I would recommend running with python -m multiproc main.py -a resnet50 -b 256 -j 5 -p 10 --fp16 --dist-backend nccl

tstandley commented 6 years ago

When I run this, it works: (16 bit, 1 gpu) CUDA_VISIBLE_DEVICES=0 python3 main.py /data/ilsvrc -a resnet50 -b 128 --fp16 --loss-scale=256 --lr=.01

This works too: (32 bit, 2 gpus) CUDA_VISIBLE_DEVICES=0,1 python3 main.py /data/ilsvrc -a resnet50 -b 128 --lr=.01

This is broken: (16 bit, 2 gpus) CUDA_VISIBLE_DEVICES=0,1 python3 main.py /data/ilsvrc -a resnet50 -b 128 --fp16 --loss-scale=256 --lr=.01

It converges much more slowly and the loss stays high. If I run the same three commands with a higher learning rate, the first two again work, but the last one diverges.

My understanding is that --dist-backend is not functional for fp16. It's also not what I want: both my GPUs are local, so I want DataParallel.

Thanks!

csarofeen commented 6 years ago

--dist-backend nccl is the recommended method for fp16. Why --lr 0.01 instead of the default?

csarofeen commented 6 years ago

For NCCL distributed training you need to build PyTorch from source with NCCL 2.1.2 installed locally, and make sure the build picks up that version of NCCL. In fact, when you don't run it as the README specifies, it does not run multi-GPU at all.

csarofeen commented 6 years ago

Also, where did you get a loss scale of 256? This example doesn't need any loss scaling; it's only included to demonstrate how to use a loss scale when one is needed.
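As background on what a static loss scale does: very small gradient values flush to zero in fp16, and multiplying the loss (and therefore every gradient) by a constant such as 256 shifts them back into the representable range. The gradients are then divided by the same constant in fp32 before the weight update. Whether a network needs this is model-dependent; per the comment above, ResNet-50 converges without it. A minimal stdlib-only sketch, emulating fp16 with struct's "e" format; the gradient value and helper names are made up for illustration.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to the nearest fp16 value (inf if it overflows)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def fp16_grad(true_grad: float, loss_scale: float) -> float:
    """Hypothetical helper: gradient after backprop of loss * loss_scale, unscaled in fp32."""
    scaled = to_half(true_grad * loss_scale)  # the backward pass stores fp16 gradients
    return scaled / loss_scale  # unscale in fp32 before the optimizer step


g = 1e-8  # below the smallest fp16 subnormal (2**-24, about 6e-8)
print(fp16_grad(g, 1.0))  # 0.0: the gradient underflows
print(fp16_grad(g, 256.0))  # close to 1e-8: the scaled value survives in fp16
```

Too large a scale has the opposite failure mode: scaled gradients overflow to inf, which is one reason a loss scale is normally paired with checks for non-finite gradients.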

tstandley commented 6 years ago

The learning rate and the loss scale were attempts to make fp16 training on both GPUs stable, and they help somewhat.

What is dist-backend? What is the difference between nccl and the default (gloo)? Why doesn't gloo work with fp16? Are you saying I need to recompile pytorch to use nccl? I don't have root privileges on the machine I'm running on, so I'm not sure if I can do that.

Thanks, Trevor

tstandley commented 6 years ago

It also occurs to me that dist-backend has no effect because world_size is not changed from its default of 1. I'm not trying to train on multiple separate computers; things aren't working on a single computer with 2 GPUs in it.

csarofeen commented 6 years ago

gloo is very slow for fp16 communication; you need NCCL to get good performance. The point of python -m multiproc is to fill in world-size and rank automatically. Distributed is also intended for single-machine, multi-GPU runs. ResNet doesn't need a loss scale to reach final convergence; we've converged it many times without any loss scaling.
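Roughly, a launcher like multiproc starts one worker process per GPU on the machine and appends a distinct rank plus the shared world size to each worker's arguments, so a single box needs no extra flags. The sketch below only builds the per-worker command lines and does not launch anything. The flag names --world-size and --rank and the helper build_worker_commands are assumptions for illustration, not taken from the example's source.

```python
import sys


def build_worker_commands(script_args, num_gpus):
    """Hypothetical helper: one argv list per GPU, each with --world-size and its own --rank.

    Assumes the training script accepts --world-size and --rank arguments.
    """
    commands = []
    for rank in range(num_gpus):
        argv = [sys.executable] + list(script_args)
        argv += ["--world-size", str(num_gpus), "--rank", str(rank)]
        commands.append(argv)
    return commands


cmds = build_worker_commands(
    ["main.py", "-a", "resnet50", "--fp16", "--dist-backend", "nccl"], num_gpus=2
)
for cmd in cmds:
    print(" ".join(cmd[1:]))
# A real launcher would then start each command, e.g. with subprocess.Popen(cmd).
```

Inside each worker, the rank would typically also select the GPU (for example with torch.cuda.set_device(rank)) before the process group is initialized, which is why world_size stays at 1 unless something like this fills it in.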

tstandley commented 6 years ago

So are you saying that fp16 training doesn't work on multiple GPUs with torch.nn.DataParallel() and only works with torch.nn.parallel.DistributedDataParallel()?

I know that I don't need the loss scale. I only added it because I was tweaking parameters to try to get convergence on 2 GPUs. On 1 GPU, it converges without it. In other words, the loss scale is not actually related to my problem.

tstandley commented 6 years ago

Also, can we clarify which version of your imagenet example we are talking about? I was basing my question on the branch "fp16_examples_cuDNN_ATen"

Thanks again.

csarofeen commented 6 years ago

Training in mixed-precision is a bit more complicated than throwing a .half() on the model and input. That's why I wrote https://github.com/csarofeen/examples/tree/dist_fp16/imagenet

Mixed-precision training can work with DataParallel, but I do not have it in my example, mainly because DataParallel doesn't scale well across multiple GPUs. My recommendation is to get NCCL 2.1.2, build PyTorch from source with it, and use my example as written. If you need to modify it, try to understand exactly what is being done with the master_parameter copy in fp32, and stay consistent with the example. If you're coming to GTC next month, I will be giving a few mixed-precision presentations, including a short tutorial on PyTorch.
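Why an fp32 master copy matters can be seen on a single scalar weight. Just below 1.0, adjacent fp16 values are 2**-11 (about 0.0005) apart, so an update of 1e-4 applied directly to an fp16 weight rounds away every step. The same update accumulates in an fp32 master weight, which is copied back to fp16 for the forward and backward passes. A minimal stdlib-only sketch with made-up numbers; Python floats stand in for fp32 and struct's "e" format emulates fp16.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to the nearest fp16 value (inf if it overflows)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


update = 1e-4  # hypothetical lr * grad per step
steps = 100

fp16_weight = to_half(1.0)  # pure-fp16 training: the only copy is fp16
master = 1.0  # mixed precision: fp32 master copy
model_weight = to_half(master)  # fp16 copy used in forward/backward

for _ in range(steps):
    # pure fp16: the update is rounded to fp16 immediately
    fp16_weight = to_half(fp16_weight - update)
    # mixed precision: update the fp32 master, then refresh the fp16 model copy
    master -= update
    model_weight = to_half(master)

print(fp16_weight)  # 1.0: every update was rounded away
print(model_weight)  # about 0.99: the updates accumulated in the master copy
```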

csarofeen commented 6 years ago

As with the default branch and all the links I've been posting, the correct branch is dist_fp16.

tstandley commented 6 years ago

I'm not just throwing .half() on my models and inputs. I'm using the procedure outlined in http://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html as you coded in https://github.com/csarofeen/examples/tree/fp16_examples_cuDNN-ATen/imagenet . This works, but only on a single GPU. Problems only appear when I use 2 GPUs AND fp16.

Unfortunately I can't use NCCL because I can't recompile and install PyTorch (no root). For what it's worth, scaling isn't a problem: with both GPUs, training runs about 2x as fast; it just doesn't converge.

I just wish I knew what was going on, i.e., what is happening mathematically that makes it diverge, and why reducing the learning rate prevents divergence but results in poor model performance, while the same code works perfectly on a single GPU or on two GPUs with fp32. This definitely looks like a bug.

tstandley commented 6 years ago

Looking into this more, the problem is with torch.nn.parallel.replicate(): it only replicates an fp16 model to the first GPU. The other GPUs don't get the correct weights.

As a partial workaround, you can keep your model an fp32 model but then run:

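```python
import torch

def new_replicate(self, module, device_ids):
    # Replicate the fp32 model across GPUs, then cast each replica to half.
    # convert_to_half is the poster's own helper, not a PyTorch function.
    replicas = torch.nn.parallel.replicate(module, device_ids)
    replicas = [convert_to_half(r) for r in replicas]
    return replicas

# Monkey-patch DataParallel so every forward pass uses this replicate.
torch.nn.DataParallel.replicate = new_replicate
```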

(where convert_to_half converts everything in the model to half except for batch norm layers)

Though this breaks eval mode.

This seems like a PyTorch bug that should be fixed.
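The thread doesn't show convert_to_half. A minimal sketch of what such a helper might look like, assuming a PyTorch version with nn.Module.half() and nn.Module.float(): cast the whole module to fp16, then cast batch-norm layers (parameters and running statistics) back to fp32, matching the description of the workaround. The batch-norm check uses torch.nn.modules.batchnorm._BatchNorm, a private base class that may change between releases.

```python
import torch
import torch.nn as nn


def convert_to_half(module: nn.Module) -> nn.Module:
    """Hypothetical helper: cast module to fp16 in place, keeping batch norm in fp32."""
    module.half()
    for m in module.modules():
        # _BatchNorm is the (private) common base of BatchNorm1d/2d/3d
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.float()
    return module


model = convert_to_half(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()))
print(model[0].weight.dtype)  # torch.float16
print(model[1].weight.dtype)  # torch.float32
print(model[1].running_mean.dtype)  # torch.float32
```

Keeping batch norm in fp32 is a common choice because its running statistics and normalization are sensitive to fp16 precision.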

tstandley commented 6 years ago

@csarofeen, dist_fp16/imagenet doesn't work either. With the gloo or tcp backends, I get: AttributeError: module 'torch.distributed' has no attribute '_backend'. With the nccl backend, I get: RuntimeError: _Map_base::at, which might be what you get when NCCL isn't installed correctly. In any event, it's a confusing error message.

BTW, is there a tutorial for correctly compiling PyTorch with NCCL? Am I supposed to use NCCL 2?

@soumith Will the next pytorch release be compiled with nccl? Thanks

teng-li commented 6 years ago

@tstandley yes, 0.4 has NCCL backend included

yaceben commented 6 years ago

Hi, just wanted to chime in that you don't need root to install/compile. You can check out https://www.osc.edu/resources/getting_started/howto/howto_add_python_packages_using_the_conda_package_manager for more information.