huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Support Conv3D #795

Open yeruoforever opened 10 months ago

yeruoforever commented 10 months ago

3D convolutions are important too!

Like 2D natural images, 3D volume data is very common, for example medical CT and MRI data, where voxels are uniformly distributed in 3D space.

The medical and health field relies heavily on computer vision methods that have migrated from 2D natural image processing, such as classification, segmentation, and detection of lesions and organs. These methods rely on convolutional neural networks. Even in Transformers of the ViT family, 3D convolution is an elegant and efficient way to implement 3DPatchEmbed.

Medicine and health care is also a promising research direction for artificial intelligence, with bright application prospects. So the 3D version of convolution should not be absent from our project.

janroden commented 9 months ago

I'd love that too :-)

Also for image registration (alignment) of 3D images (e.g. from MRI, CT, SPIM microscopy, ...), deep-learning based methods are getting better and more important (see e.g. this recent paper on brain MRI registration).

Having support for 3D images in Candle would enable us to implement such deep-learning based registration methods, and also "classical" optimization methods for 3D-image registration (see e.g. this library using PyTorch for classical optimization-based image registration).

This would bring the advantages of Candle to medical and biological image analysis, computer vision, etc., and could be used as a backbone to perform efficient (GPU) computation on 3D images in Rust in general.

Thanks a lot for this great library BTW!

EricLBuehler commented 9 months ago

Looks exciting! I wonder how complicated it would be to build this on top of Conv2d?
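As a rough way to gauge that, a 3D convolution can be written as a sum of 2D convolutions over the depth slices of the kernel. The sketch below is plain Rust with no candle dependency. It assumes a single channel, stride 1, no padding, and row-major `f32` buffers, and the function names `conv2d_accumulate` and `conv3d_via_conv2d` are made up for illustration rather than taken from candle's API.

```rust
/// Single-channel 2D cross-correlation (what deep-learning libraries call
/// "convolution"), stride 1, no padding. The result is ADDED into `out`,
/// which must hold (h - kh + 1) * (w - kw + 1) elements.
fn conv2d_accumulate(
    input: &[f32],
    (h, w): (usize, usize),
    kernel: &[f32],
    (kh, kw): (usize, usize),
    out: &mut [f32],
) {
    let (oh, ow) = (h - kh + 1, w - kw + 1);
    for y in 0..oh {
        for x in 0..ow {
            let mut acc = 0.0f32;
            for ky in 0..kh {
                for kx in 0..kw {
                    acc += input[(y + ky) * w + (x + kx)] * kernel[ky * kw + kx];
                }
            }
            out[y * ow + x] += acc;
        }
    }
}

/// Hypothetical reference conv3d: for every output depth index `z`, sum the
/// 2D convolutions of input slice `z + kz` with kernel slice `kz`.
/// Input layout is (d, h, w), kernel layout is (kd, kh, kw), both row-major.
fn conv3d_via_conv2d(
    input: &[f32],
    (d, h, w): (usize, usize, usize),
    kernel: &[f32],
    (kd, kh, kw): (usize, usize, usize),
) -> Vec<f32> {
    let (od, oh, ow) = (d - kd + 1, h - kh + 1, w - kw + 1);
    let mut out = vec![0.0f32; od * oh * ow];
    for z in 0..od {
        for kz in 0..kd {
            let in_slice = &input[(z + kz) * h * w..(z + kz + 1) * h * w];
            let k_slice = &kernel[kz * kh * kw..(kz + 1) * kh * kw];
            let out_slice = &mut out[z * oh * ow..(z + 1) * oh * ow];
            conv2d_accumulate(in_slice, (h, w), k_slice, (kh, kw), out_slice);
        }
    }
    out
}
```

A real implementation would also need batch and channel dimensions, stride, padding, dilation, and a backward pass, so this only shows the core relationship between the two ops.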

EricLBuehler commented 9 months ago

@LaurentMazare, what are your thoughts on adding Conv3d to Candle - is it something that is planned to be added already? I would be happy to contribute an implementation.

LaurentMazare commented 9 months ago

I don't think it's planned yet, so please go ahead if you want to give it a stab. What I would suggest is to start by implementing it as a custom op, together with a model that actually uses it (not sure what is a good example of how these get used). Once you have a working model, we can decide between merging it into candle-core or keeping it as a separate candle-conv3d crate, based on the resulting code complexity and your experience implementing it. This way you will be able to iterate quickly in a separate repo and also try out the custom op layer quite a bit.

EricLBuehler commented 9 months ago

That sounds like a good idea - I think I will take a stab at it. Do you think I should implement a "toy", bespoke model trained on some 3D data, or try to implement a more generic model (do you know of any that use conv3d)?

LaurentMazare commented 9 months ago

I feel that it would be much better if it's a real model. I don't think it's necessary to train it in candle, but a model that already has trained weights, which we can run on some data to check that it generates the same output as the Python equivalent, would be neat. Obviously it would be even better if it's a model that you find interesting or useful somehow.
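A comparison like that usually needs a tolerance, since floating-point results from two frameworks rarely match bit for bit. Here is a minimal sketch in plain Rust, assuming both outputs have been flattened to `f32` slices. The helper names are illustrative, and the check mirrors the `|a - e| <= atol + rtol * |e|` rule used by NumPy's `allclose`.

```rust
/// Returns true when every element of `actual` is within tolerance of the
/// matching element of `expected`: |a - e| <= atol + rtol * |e|.
/// Slices of different lengths, or any NaN, make the check fail.
fn allclose(actual: &[f32], expected: &[f32], rtol: f32, atol: f32) -> bool {
    actual.len() == expected.len()
        && actual
            .iter()
            .zip(expected)
            .all(|(a, e)| (a - e).abs() <= atol + rtol * e.abs())
}

/// Largest absolute element-wise difference, handy for reporting how far
/// off two outputs are when `allclose` fails.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0.0f32, f32::max)
}
```

Suitable tolerances depend on the model and the backend, so the values passed in are something to tune rather than fixed constants.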

janroden commented 9 months ago

Cool :) Thanks for wanting to give it a shot!

As an example, here's a relatively simple Keras model for 3D image classification from CT scans. Note that this also needs MaxPool3D layers in addition to the Conv3D layers though. Also, unfortunately, it needs a bit of preprocessing of the images. I found this crate for reading NIfTI images, but haven't used it yet.

Another example would be this more complicated, yet generally useful, PyTorch model for affine registration (alignment) of 3D images with a coarse-to-fine vision transformer, from this recent paper (also see the references in this paper for previous CNN approaches to 3D-image registration). The GitHub repo also provides links to pre-trained model weights. Such 3D affine image registration is useful for a wide range of medical and biological image analysis tasks on MRI, CT, and 3D-microscopy images.

This is another state-of-the-art transformer-based PyTorch implementation for image registration.