NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Device mismatch when using AMP with Pytorch DataParallel #503

Open michaelklachko opened 5 years ago

michaelklachko commented 5 years ago

I'm running the following on 4 GPUs:

import torch
import torch.nn as nn
from apex import amp

model = Resnet50()  # the reporter's own model (models/resnet.py in the traceback)
model = model.cuda()
criterion = nn.CrossEntropyLoss(reduction='mean').cuda()
optimizer = torch.optim.SGD(model.parameters(), 0.001)
model, optimizer = amp.initialize(model, optimizer, opt_level='O3', keep_batchnorm_fp32=False)
model = torch.nn.DataParallel(model)

And I get the following error:

Selected optimization level O3:  Pure FP16 training.
Defaults for this optimization level are:
enabled                : True
opt_level              : O3
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : False
master_weights         : False
loss_scale             : 1.0
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O3
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : False
master_weights         : False
loss_scale             : 1.0
lr: 0.1 wd 0.0001
Traceback (most recent call last):
  File "main.py", line 559, in <module>
    main()
  File "main.py", line 555, in main
    train(train_loader, val_loader, model, criterion, optimizer, start_epoch, best_acc, args)
  File "main.py", line 409, in train
    output = model(input_var, epoch=epoch, i=i)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/apex/amp/_initialize.py", line 194, in new_fwd
    **applier(kwargs, input_caster))
  File "/home/michael/noisynet/models/resnet.py", line 161, in forward
    x = self.conv1(x)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 343, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/michael/miniconda2/envs/pt/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 340, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
jianchao-li commented 5 years ago

This is expected behavior: PyTorch's DataParallel only works with O1. See #269 for more details.

vadimkantorov commented 4 years ago

I found the root cause: forward must be patched after the DataParallel(...) call, because otherwise the patched method refers to the original model object rather than to the dynamically created replica. Some other patching approach might work with DP, but the straightforward one in https://github.com/NVIDIA/apex/blob/master/apex/amp/_initialize.py#L201 definitely does not.
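The mechanism can be illustrated without torch or apex. The sketch below is a toy model, not PyTorch code: ToyModule, patch_forward and replicate are hypothetical names, and replicate only approximates how DataParallel builds replicas (by shallow-copying each module's instance __dict__). Because amp stores its patched forward in the instance __dict__ as a closure over the original bound method, every replica inherits a forward that still runs against the original module's weights, which matches the "device 1 does not equal 0" error in the traceback above.

```python
"""Toy model of why a pre-patched forward breaks under replication (no torch needed)."""


class ToyModule:
    """Hypothetical stand-in for nn.Module; `device` plays the role of its weights."""

    def __init__(self, device):
        self.device = device

    def forward(self, x_device):
        # Reports which device's "weights" this call actually uses.
        return f"input on {x_device}, weights on {self.device}"


def patch_forward(module):
    """Patch in the style of amp: wrap the *bound* method, store it on the instance."""
    old_fwd = module.forward  # bound to this specific object

    def new_fwd(x_device):
        # amp would cast inputs/outputs around this call
        return old_fwd(x_device)

    module.forward = new_fwd  # lands in module.__dict__
    return module


def replicate(module, device):
    """Shallow-copy the instance __dict__, approximating DataParallel's replication."""
    replica = type(module).__new__(type(module))
    replica.__dict__ = module.__dict__.copy()  # copies the patched `forward` too
    replica.device = device
    return replica


# Patch first (amp.initialize), replicate second (DataParallel):
# the replica's forward still calls the original module.
broken = replicate(patch_forward(ToyModule("cuda:0")), "cuda:1")
print(broken.forward("cuda:1"))  # input on cuda:1, weights on cuda:0

# Replicating an unpatched module: forward binds to the replica itself.
clean = replicate(ToyModule("cuda:0"), "cuda:1")
print(clean.forward("cuda:1"))  # input on cuda:1, weights on cuda:1
```

This is why patching has to happen on an object that is not replicated, or after replication.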

The workaround I found:

model = apex.amp.initialize(torch.nn.Sequential(model), opt_level='O2')[0]  # amp patches the Sequential wrapper's forward; [0] unwraps the cast inner model
model = torch.nn.DataParallel(model, device_ids=args.devices)
opts, applier = apex.amp._amp_state.opt_properties.options, apex.amp._initialize.applier
input_caster = lambda tensor: tensor.to(opts['cast_model_type'])
output_caster = lambda tensor: tensor.to(opts['cast_model_outputs'] if opts.get('cast_model_outputs') is not None else torch.float32)
model.forward = lambda *args, old_fwd=model.forward, **kwargs: applier(old_fwd(*applier(args, input_caster), **applier(kwargs, input_caster)), output_caster)
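The same idea can be written as a small reusable helper. This is a sketch, not an apex API: patch_wrapper_forward and cast_floats are hypothetical names, and the helper deliberately avoids apex's private _amp_state and _initialize internals. It patches the DataParallel object itself, which is never replicated, so the casts run once on the main thread before inputs are scattered and after outputs are gathered. Unlike the one-liner above, it only casts floating-point tensors (an assumption: integer inputs such as class indices should keep their dtype).

```python
import torch
import torch.nn as nn


def cast_floats(obj, dtype):
    """Recursively cast floating-point tensors inside tuples, lists and dicts."""
    if torch.is_tensor(obj):
        return obj.to(dtype) if obj.is_floating_point() else obj
    if isinstance(obj, (list, tuple)):
        return type(obj)(cast_floats(o, dtype) for o in obj)
    if isinstance(obj, dict):
        return {k: cast_floats(v, dtype) for k, v in obj.items()}
    return obj  # non-tensor leaves (ints, strings, None) pass through unchanged


def patch_wrapper_forward(dp_model, input_dtype, output_dtype=torch.float32):
    """Wrap forward on the DataParallel object itself, which is never replicated."""
    old_fwd = dp_model.forward  # bound to the wrapper, not to any replica

    def new_fwd(*args, **kwargs):
        # Cast inputs before DataParallel scatters them, and outputs after gather.
        out = old_fwd(*cast_floats(args, input_dtype), **cast_floats(kwargs, input_dtype))
        return cast_floats(out, output_dtype)

    dp_model.forward = new_fwd  # instance attribute shadows DataParallel.forward
    return dp_model


device = "cuda" if torch.cuda.is_available() else "cpu"
# float64 stands in for the fp16 weights amp.initialize would produce (so this runs on CPU too).
inner = nn.Linear(4, 2).to(device=device, dtype=torch.float64)
dp = patch_wrapper_forward(nn.DataParallel(inner), input_dtype=torch.float64)
out = dp(torch.randn(3, 4, device=device))  # float32 in, cast to float64 inside
print(out.dtype)  # torch.float32
```

With apex you would pass whatever cast_model_type your opt_level selects (torch.float16 for O2/O3, per the log above) as input_dtype.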

@jianchao-li

mcarilli commented 4 years ago

This is still very useful information and I haven't been ignoring it, but to be honest I'm probably not going to implement a fix in Apex soon. My absolute top priority right now is getting automatic mixed precision into PyTorch natively, which will eliminate all extension-building and version-matching issues. I'm taking care to ensure the native integration will support DistributedDataParallel, DataParallel, and model-parallel usage.

We are targeting the 1.5 release: https://github.com/pytorch/pytorch/issues/25081. Gradient scaling and autocasting will be independently usable components. The gradient scaling PR is mature and awaiting final documentation review: https://github.com/pytorch/pytorch/pull/26512. The autocasting PR is about 3/4 done in terms of op coverage: https://github.com/pytorch/pytorch/pull/29552. Autocasting will likely be exposed via a context manager that can be used to locally enable or disable mixed precision for any desired region of the model.
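For reference, native mixed precision eventually shipped as torch.cuda.amp in PyTorch 1.6, with autocast and GradScaler as separate pieces. Below is a minimal sketch of one DataParallel training step, assuming PyTorch 1.10 or later (for the device-generic torch.autocast); Net and the tensor shapes are hypothetical. The PyTorch AMP examples note that autocast state is thread-local and that DataParallel runs each replica's forward in its own thread, so autocast is entered inside forward rather than around the model(x) call. Without CUDA the sketch falls back to CPU autocast with bfloat16 and disables the scaler so it stays runnable.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 autocast on CUDA; bfloat16 is the lower-precision type CPU autocast uses.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16


class Net(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        # autocast is thread-local and DataParallel runs replicas in side threads,
        # so enable it here, inside forward.
        with torch.autocast(device_type=device, dtype=amp_dtype):
            return self.fc(x)


model = nn.DataParallel(Net().to(device))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Loss scaling matters for fp16 on CUDA; enabled=False turns the scaler into a pass-through.
# (Newer releases prefer the spelling torch.amp.GradScaler("cuda", ...).)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 16, device=device)
y = torch.randint(0, 4, (8,), device=device)

optimizer.zero_grad()
out = model(x)                      # autocast runs inside each replica's forward
loss = criterion(out.float(), y)    # compute the loss in fp32
scaler.scale(loss).backward()       # scaled backward (no-op scaling when disabled)
scaler.step(optimizer)              # unscales grads, skips the step on inf/NaN
scaler.update()
print(out.dtype, loss.item())
```

Note that autocast behaves roughly like Apex O1 (per-op casting with fp32 weights); the sketch does not reproduce O2/O3-style fp16 model weights.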

If you are having problems with the current incarnation of Apex, my best advice is to wait for the PRs to be merged. Getting native mixed precision support as soon as possible is the best path forward for everyone IMO.

vadimkantorov commented 3 years ago

@mcarilli By the way, are O2/O3 supported in PyTorch autocast? Colleagues of mine mentioned that they saw no RAM decrease when using PyTorch core autocast, as if activations were still stored in fp32.