konstin closed this 3 years ago
The latter is a great improvement and is probably relevant for many people working on clusters, where a job may be assigned less RAM than the GPU has vRAM. The former might be an upstream problem, and I wouldn't spend too much time on it. From my (very limited) testing, it seems that the model in .half() is not parallelized and is therefore slower. It isn't slower because it uses a different floating-point precision, but simply because there's a parallelization issue at some point in the tree.
Regardless: thanks :)
I think we can close this for now as it's handled.
See https://github.com/pytorch/pytorch/issues/48245#issuecomment-730967335. I'll also look for a rule of thumb for RAM usage so we don't overflow people's RAM when falling back to the CPU.
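As a starting point for that rule of thumb, here is a rough sketch. It assumes RAM use is dominated by the weights (4 bytes per parameter in float32, 2 in float16) times an overhead factor for activations and runtime buffers. The helper names and the default `overhead=1.5` are hypothetical placeholders that would need calibrating against real measurements. In PyTorch the parameter count can be taken from `sum(p.numel() for p in model.parameters())`.

```python
# Hypothetical helpers: a rough rule of thumb, not a measured figure.
BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def estimate_ram_bytes(num_params: int, dtype: str = "float32", overhead: float = 1.5) -> int:
    """Estimate RAM needed for a model's weights plus working memory.

    `overhead` is an assumed multiplier covering activations, buffers and
    the runtime footprint; tune it from real measurements.
    """
    if dtype not in BYTES_PER_PARAM:
        raise ValueError(f"unsupported dtype: {dtype}")
    if overhead < 1.0:
        raise ValueError("overhead must be >= 1.0")
    return int(num_params * BYTES_PER_PARAM[dtype] * overhead)


def can_fall_back_to_cpu(
    num_params: int,
    available_ram_bytes: int,
    dtype: str = "float32",
    overhead: float = 1.5,
) -> bool:
    """Return True if the estimated footprint fits into the available RAM."""
    return estimate_ram_bytes(num_params, dtype, overhead) <= available_ram_bytes


if __name__ == "__main__":
    # A 3-billion-parameter model in float32 with the assumed 1.5x overhead.
    print(estimate_ram_bytes(3_000_000_000))  # 18000000000 bytes, i.e. ~18 GB
```

The check is intentionally conservative. If it returns False, it is probably better to fail with a clear error than to let a CPU fallback run the job out of memory.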