Open mbelouso opened 1 month ago
Hey Matt,
Unfortunately we are presently blocked on this by the implementation (or lack thereof) of 16-bit floats in the `torch` package used for our neural net models; a good rundown is here:
https://github.com/pytorch/pytorch/issues/70664
In particular:
> Supports torch.half on CUDA with GPU Architecture SM53 or greater. However it only supports powers of 2 signal length in every transformed dimensions.
Although powers of 2 are quite common in our use case, since we often downsample particle stacks to 128x128 or 256x256, there is not much we can do about the lack of support for older architectures. On our own compute cluster, we thus see `RuntimeError: Unsupported dtype Half` errors when using the `float16` format, even with PyTorch v2.4.
However, we can look into some kind of hackier solution in the meantime, such as casting back and forth between `float16` and `float32` just for these `torch` operations, or automatically converting to `float32` whenever the `float16` format is encountered. Also note that we do currently use `torch.cuda.amp`, which exploits the benefits of lower precision where possible even when the original input is `float32`.
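The round-trip cast could look something like the following sketch (the `fft2_safe` helper name is hypothetical, not cryoDRGN's actual API):

```python
import torch

def fft2_safe(x: torch.Tensor) -> torch.Tensor:
    """Apply a 2D FFT, casting float16 inputs up to float32 first.

    torch.fft rejects torch.half on CPU and on pre-SM53 GPUs (and
    restricts it to power-of-2 lengths even where supported), so we
    cast up before transforming; the result comes back as complex64.
    """
    if x.dtype == torch.float16:
        # hypothetical workaround cast for unsupported-half backends
        x = x.to(torch.float32)
    return torch.fft.fft2(x)
```

Casting back down afterwards is optional; keeping the transform result in full precision avoids a second rounding step.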
I will try to have some kind of preliminary solution to this in an upcoming release for the sake of smoother interoperability with RELION!
Best, Mike
Also, as @ryanfeathers has pointed out, we already cast to `float32` in the case of `.star` files pointing to `float16` `.mrcs` files, so we should just extend this behaviour to all `float16` inputs, with suitable warning messages!
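A load-time warn-and-cast along those lines might look like this (the `ensure_float32` name is made up for illustration):

```python
import warnings
import numpy as np

def ensure_float32(data: np.ndarray) -> np.ndarray:
    """Cast float16 particle data to float32, warning the user."""
    if data.dtype == np.float16:
        warnings.warn(
            "float16 input detected; casting to float32 for torch support",
            stacklevel=2,
        )
        return data.astype(np.float32)
    return data
```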
This is also important for interoperability with cryoSPARC as well as RELION: `float16` output for motion correction and extraction has been available for the last couple of versions, and we use it for everything due to the substantial reduction in disk space requirements.
I added a patch to cast `float16` inputs to `float32` when needed for reconstruction methods:
https://github.com/ml-struct-bio/cryodrgn/blob/7cb8f78cc23b1601c47c8108e40e44e5bdc90632/cryodrgn/fft.py#L31-L37
This allows commands such as `train_vae`, `abinit_homo`, etc. to accept `float16` `.mrcs` files as input, in addition to the existing conversion of `float16` formats to `float32` in the case of `.star` files noted above. After running some tests, I see that runtimes are the same across both formats: the casting isn't costing us anything substantial, but reading and performing (other) machine learning operations using these formats does not result in any gains either.
These changes will be made part of the v3.4.1 release and can be accessed now through our beta release channel:
```
pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ 'cryodrgn<=3.4.1' --pre
```
Please let us know if you run into any more problems with `float16` input!
Hey guys,
As we write most of our particle stacks out in RELION, which allows for 16-bit float `.mrcs` files (MRC mode 12), it's a bit of a pain converting them back to 32-bit float (MRC mode 2) for ingest into cryodrgn. What are the chances of the addition of MRC mode 12 support?
cheers
matt B