jramapuram / BYOL

Bootstrap Your Own Latent (BYOL) pytorch implementation using DistributedDataParallel.

Running in half-precision #2


jlindsey15 commented 4 years ago

Hi! An earlier issue I was having (the one I posted in the wrong repo) seems to have come from my hacky attempts to get the half-precision functionality working with this code. Simply setting the `half` flag raises an error due to interactions with DistributedDataParallel (it tells you that `amp.initialize` must be called before wrapping the model in DistributedDataParallel). When I tried the obvious fix (calling `amp.initialize` earlier), another error arose, this time from the `parameters_to_vector` calls, which apparently cannot handle a mix of half-precision and full-precision parameters. My workaround was to stop storing all the parameters in one vector and instead store and update them in a list; this got rid of all the runtime errors but seemed to cause functional problems.
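For reference, the ordering Apex expects looks roughly like this (a minimal sketch with a stand-in ResNet and optimizer, not this repo's actual training loop; it assumes `torch.distributed` has already been initialized and that `local_rank` comes from the launcher):

```python
import torch
import torchvision
from apex import amp
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed.init_process_group(...) has already run;
# local_rank is a stand-in for the value parsed from the launcher.
local_rank = 0
torch.cuda.set_device(local_rank)

model = torchvision.models.resnet50().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

# Apex wants amp.initialize to run BEFORE the model is wrapped in DDP,
# which is the reordering described above.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
model = DDP(model, device_ids=[local_rank])
```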

Rather than make you comb through the details of these changes, I figure I should just ask: what is the cleanest way to take the current code and make it runnable in half precision? Without it I can only run with a batch size of 64, even with 8 GPUs, and I'd like to get up to a batch size of 256.

jramapuram commented 4 years ago

Yeah, sorry about that. I haven't tested this implementation in FP16, but have you tried swapping the AMP optimization level to O2?
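Roughly something like this (untested sketch; `model`, `optimizer`, and `loss` are placeholders for whatever the training step already has):

```python
from apex import amp

# O2 casts the model to FP16 while the optimizer keeps FP32 master weights.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

# Apex's scaled-loss backward, per the standard AMP recipe.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```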

If that doesn't work, we would need to find a way to rewrite the CosEMA module so that it's compatible with how Apex handles its fp32/fp16 conversions, and I'm not quite sure how to go about that. Another thing worth trying is forcibly type-casting the parameters via `a.type(b.dtype)` for each parameter, as in the sketch below.
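Very rough sketch of what I mean by the cast (this is a generic EMA update for illustration, not the actual CosEMA code):

```python
import torch

@torch.no_grad()
def ema_update(target_params, online_params, decay):
    # Generic EMA update used only to illustrate the point: force matching
    # dtypes (the a.type(b.dtype) cast) before mixing FP16/FP32 parameters.
    for t, o in zip(target_params, online_params):
        o = o.detach().type(t.dtype)
        t.mul_(decay).add_(o, alpha=1.0 - decay)
```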

jlindsey15 commented 4 years ago

O2 doesn't seem to resolve it. For now, the simplest workaround in my particular case is probably to just use a smaller ResNet so I can increase the batch size. But if you ever come up with a good solution, I'd be interested!