michaelklachko opened this issue 5 years ago

I'm running the following on 4 GPUs:

And I get the following error:
This is expected behavior. PyTorch's DataParallel only works with opt_level O1. You may refer to #269 for more details.
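For reference, a minimal sketch of the combination that does work. O1 patches `torch.*` functions globally rather than monkey-patching this particular model's `forward`, so DataParallel's per-forward replicas are unaffected. `model`, `optimizer`, and `args.devices` are assumed to already exist:

```python
import torch
from apex import amp

# O1 inserts casts at the torch function level, not on model.forward,
# so replication by DataParallel does not bypass the casting.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
model = torch.nn.DataParallel(model, device_ids=args.devices)
```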
I found the root cause: forward must be patched after the DataParallel(...) call, because otherwise the patched method refers to the old model object rather than the dynamically created replicas. Maybe some other patching approach exists that would work fine with DP, but definitely not the straightforward one in https://github.com/NVIDIA/apex/blob/master/apex/amp/_initialize.py#L201
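For illustration, the patching there looks roughly like this (a paraphrase, not the exact apex source; `model` is assumed to exist, and the two casters below are simplified stand-ins for apex's):

```python
import torch
from apex.amp._initialize import applier  # apex's recursive tensor mapper

input_caster = lambda t: t.to(torch.float16)   # stand-in for apex's caster
output_caster = lambda t: t.to(torch.float32)  # stand-in for apex's caster

# new_fwd closes over old_fwd, which is bound to the module that existed
# when amp.initialize() ran. Replicas that DataParallel creates later
# therefore still dispatch into that original, pre-replication module.
def patch_forward(old_fwd):
    def new_fwd(*args, **kwargs):
        output = old_fwd(*applier(args, input_caster),
                         **applier(kwargs, input_caster))
        return applier(output, output_caster)
    return new_fwd

model.forward = patch_forward(model.forward)
```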
The workaround I found:
```python
# Wrapping in a throwaway Sequential and unwrapping with [0] keeps the
# weight casting done by initialize() but drops the forward patch, which
# was installed on the Sequential wrapper rather than on the model itself.
model = apex.amp.initialize(torch.nn.Sequential(model), opt_level='O2')[0]
model = torch.nn.DataParallel(model, device_ids=args.devices)

# Re-apply the input/output casting manually, on the DataParallel wrapper:
options = apex.amp._amp_state.opt_properties.options
input_caster = lambda tensor: tensor.to(options['cast_model_type'])
output_caster = lambda tensor: tensor.to(
    options['cast_model_outputs']
    if options.get('cast_model_outputs') is not None
    else torch.float32)

old_fwd = model.forward

def patched_forward(*args, **kwargs):
    output = old_fwd(*apex.amp._initialize.applier(args, input_caster),
                     **apex.amp._initialize.applier(kwargs, input_caster))
    return apex.amp._initialize.applier(output, output_caster)

model.forward = patched_forward
```
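The key point, if I'm reading this right, is that forward is now patched on the DataParallel wrapper itself, so inputs are cast once before scatter and outputs once after gather, while the per-replica forwards stay untouched.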
@jianchao-li
This is still very useful information and I haven't been ignoring it, but to be honest I'm probably not going to implement a fix in Apex soon. My absolute top priority right now is getting automatic mixed precision into PyTorch natively, which will eliminate all extension building/version matching issues. I'm taking care to ensure the native integration will support DistributedDataParallel, DataParallel, and model-parallel usage. We are targeting the 1.5 release: https://github.com/pytorch/pytorch/issues/25081

Gradient scaling and autocasting will be independently usable components. The gradient scaling PR is mature and awaiting final documentation review: https://github.com/pytorch/pytorch/pull/26512

The autocasting PR is about 3/4 done in terms of op coverage: https://github.com/pytorch/pytorch/pull/29552

Autocasting will likely be exposed via a context manager that can be used to locally enable/disable mixed precision for any desired regions of the model.
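For a concrete picture, the intended usage pattern looks roughly like this (a sketch of the `torch.cuda.amp` API those PRs introduce, which shipped in PyTorch 1.6; `model`, `optimizer`, `loss_fn`, and `data` are assumed to exist):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for input, target in data:
    optimizer.zero_grad()
    # Locally enable mixed precision: ops inside the region run in
    # fp16 where it is numerically safe, fp32 otherwise.
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
    # Scale the loss to avoid fp16 gradient underflow; step() unscales
    # gradients before calling optimizer.step().
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```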
If you are having problems with the current incarnation of Apex, my best advice is to wait for the PRs to be merged. Getting native mixed precision support as soon as possible is the best path forward for everyone IMO.
@mcarilli Btw, are O2/O3 supported in PyTorch autocast? Colleagues of mine mentioned that they saw no RAM decrease when using PyTorch core autocast, as if activations were still stored in fp32.