Good find! Thanks for the detailed example; I'll have a look at it on Monday. It appears to be a straightforward fix.
@ibro45 This issue is a lot trickier than I initially anticipated. The problem is very similar to using `torch.cuda.amp` with checkpointing, see here: https://github.com/pytorch/pytorch/issues/37730

At the moment checkpointing and `torch.cuda.amp` don't work together nicely, and since the `memcnn.InvertibleModuleWrapper` works in a similar way to checkpointing, it suffers from similar issues. Sadly, some of the solutions proposed in the above-mentioned ticket do not work for the `memcnn.InvertibleModuleWrapper`, since it is even a bit more complicated than checkpointing.
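To make the analogy concrete, here is a minimal sketch (my own illustration, not code from the linked ticket) of the problematic pattern: the checkpointed segment is re-run during `backward()`, and the autocast state at that point can differ from the state the segment saw during the original forward pass.

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16).cuda()
x = torch.randn(4, 16, device="cuda", requires_grad=True)

with torch.cuda.amp.autocast():
    y = checkpoint(layer, x)  # the segment's forward runs in float16 here
    loss = y.float().sum()

# The segment is recomputed during this call, outside the autocast region, so
# the recomputed activations may not have the dtypes the saved graph expects.
loss.backward()
```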
In the end, I was able to get your example code working by:

1) Enforcing float32 inputs. I did this by adding `@custom_fwd(cast_inputs=torch.float32)` and `@custom_bwd` before the `forward` and `backward` methods of the `InvertibleCheckpointFunction`, respectively. (I would recommend this for now; a sketch follows below.)
2) Enforcing float16 model weights with `model.half()`, which converts all model weights to float16, and ditching the `GradScaler` (it doesn't like scaling float16 gradients and will complain). This saves even more memory, but I can't make any claims about the quality of the gradients or the results of the training process.

Both solutions are not really satisfactory at the moment. Here is a link to the commit implementing 1: https://github.com/silvandeleemput/memcnn/commit/881fd35df0d6bd6214810e3170107da884a98f49

Check test_amp.py for how I used this: https://github.com/silvandeleemput/memcnn/commit/881fd35df0d6bd6214810e3170107da884a98f49#diff-30ad4023a6eeedb230c16ee579c7a86b7c738fa5bd6ab34f939386d7969c2568
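For illustration, here is a minimal sketch of solution 1 (not the actual memcnn code) applied to a toy `torch.autograd.Function`: `@custom_fwd(cast_inputs=torch.float32)` casts the floating-point inputs to float32 and disables autocast inside `forward`, while `@custom_bwd` makes `backward` run under the same autocast state as the forward pass.

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class Float32MatMul(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x, weight):
        # Inputs arrive here as float32, even when called under autocast.
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        return grad_output @ weight.t(), x.t() @ grad_output
```

Solution 2, in the same sketchy spirit, boils down to something like the following (again, no claims about gradient quality):

```python
import torch

model = torch.nn.Linear(16, 16).cuda().half()   # all weights in float16
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(4, 16, device="cuda", dtype=torch.float16)
loss = model(x).sum()
loss.backward()      # float16 gradients, no GradScaler involved
optimizer.step()
```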
@ibro45 I have chosen to accept solution 1 as the current best "fix" for mixed-precision training. I am currently pushing a new release (1.5.0) of memcnn that includes the fix. As you recommended, I have also included a note describing the limitations of mixed-precision training (float32 inputs) with the `InvertibleModuleWrapper`. I hope this solves your problem.
Thanks a lot for the fix, and sorry for not getting back to you earlier, I got sidetracked.
Is the current solution not the best fix because it requires `scaler.backward()` inside of `autocast`, or is that not the only reason? Also, is it necessary that the `backward()` is under `autocast` and, if so, do you know why? I'm currently running it with `backward()` outside of `autocast` in my code and there are no errors, but is it safe to do that?
@ibro45

> Thanks a lot for the fix, and sorry for not getting back to you earlier, I got sidetracked.

No problem.

> Is the current solution not the best fix because it requires `scaler.backward()` inside of `autocast`, or is that not the only reason? Also, is it necessary that the `backward()` is under `autocast` and, if so, do you know why?

You can have `scaler.backward()` outside of the `autocast` section; I did this mainly to support some tests with checkpointing.

> I'm currently running it with `backward()` outside of `autocast` in my code and there are no errors, but is it safe to do that?

Yes, I think it should be.
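For reference, the pattern recommended in the PyTorch AMP documentation keeps only the forward pass and loss computation under `autocast`, with the backward pass and optimizer step outside of it; a minimal sketch:

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(4, 16, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(x).sum()        # forward + loss under autocast
    scaler.scale(loss).backward()    # backward outside the autocast region
    scaler.step(optimizer)
    scaler.update()
```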
Description
When using the native PyTorch Automatic Mixed Precision, `torch.cuda.amp.autocast` casts the inputs to the invertible block to `float16`, while the weights are `float32`, which causes an error. It is a simple fix with a decorator specifying that all the operations in it should be done in `float32`, as mentioned here.

What I Did
Minimal example:
- `USE_AMP` defines whether AMP will be used
- Setting `MANUALLY_CAST_INPUT_TO_INV_BLOCK` to `True` fixes the error by casting the input to the invertible block manually
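The original minimal example is not reproduced above, but a sketch of the setup it describes (my own reconstruction following the README-style memcnn API, not the reporter's exact script) could look like this; with the fix from release 1.5.0 the wrapper casts its inputs to float32 itself, so the manual cast flag is mainly relevant for older versions:

```python
import torch
import torch.nn as nn
import memcnn

USE_AMP = True                            # whether autocast/GradScaler are used
MANUALLY_CAST_INPUT_TO_INV_BLOCK = False  # True works around the dtype error

class Coupling(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

# Invertible block: 16 channels are split into two halves of 8 by the coupling.
inv_block = memcnn.InvertibleModuleWrapper(
    fn=memcnn.AdditiveCoupling(Fm=Coupling(8), Gm=Coupling(8)),
    keep_input=True, keep_input_inverse=True,
).cuda()
stem = nn.Conv2d(3, 16, kernel_size=3, padding=1).cuda()

params = list(stem.parameters()) + list(inv_block.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

x = torch.randn(2, 3, 32, 32, device="cuda")
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=USE_AMP):
    h = stem(x)                           # float16 activations under autocast
    if MANUALLY_CAST_INPUT_TO_INV_BLOCK:
        h = h.float()                     # manual workaround: cast back to float32
    out = inv_block(h)                    # block weights are float32
    loss = out.mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```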