mehta-lab / microDL

3D virtual staining with 2D and 2.5D U-Nets
BSD 3-Clause "New" or "Revised" License

Convert existing tensorflow models to pytorch #208

Closed: Soorya19Pradeep closed this issue 1 year ago

Soorya19Pradeep commented 1 year ago

As we convert the microDL pipeline to work with pytorch, we need models for the infected cell project that predict nucleus and membrane in HEK cells. It would be more efficient to convert the existing HEK nucleus and membrane prediction models that @JohannaRahm trained so that they work with pytorch, rather than retraining them. We could convert these models from tensorflow to pytorch using onnx. @ziw-liu, do you have any suggestions or pointers?
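Whatever route is taken, one detail any tensorflow-to-pytorch weight port has to get right is the layout of convolution kernels. Keras stores them channels-last as (*spatial, in_ch, out_ch), while PyTorch's Conv1d/Conv2d/Conv3d expect (out_ch, in_ch, *spatial). The following is a minimal sketch, assuming weights were copied layer by layer by hand rather than through onnx. `keras_to_torch_kernel`, `conv2d_nhwc` and `conv2d_nchw` are hypothetical helpers; the two naive convolutions exist only to check numerically that the reordered kernel produces the same output.

```python
import numpy as np


def keras_to_torch_kernel(kernel: np.ndarray) -> np.ndarray:
    """Reorder a Keras conv kernel (*spatial, in_ch, out_ch) into the
    PyTorch layout (out_ch, in_ch, *spatial). Works for 1D, 2D and 3D kernels."""
    nd = kernel.ndim
    if nd < 3:
        raise ValueError(f"expected a conv kernel with >= 3 dims, got {kernel.shape}")
    perm = (nd - 1, nd - 2) + tuple(range(nd - 2))
    return np.ascontiguousarray(kernel.transpose(perm))


def conv2d_nhwc(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Naive 'valid' cross-correlation in the Keras layout:
    x is (H, W, C_in), kernel is (kh, kw, C_in, C_out)."""
    kh, kw, _, c_out = kernel.shape
    h_out, w_out = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((h_out, w_out, c_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[i:i + kh, j:j + kw, :]  # (kh, kw, C_in)
            out[i, j] = np.tensordot(patch, kernel, axes=([0, 1, 2], [0, 1, 2]))
    return out


def conv2d_nchw(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Naive 'valid' cross-correlation in the PyTorch layout:
    x is (C_in, H, W), weight is (C_out, C_in, kh, kw)."""
    c_out, _, kh, kw = weight.shape
    h_out, w_out = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.zeros((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[:, i:i + kh, j:j + kw]  # (C_in, kh, kw)
            out[:, i, j] = np.tensordot(weight, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out


# Check: same input and weights, two layouts, identical result.
rng = np.random.default_rng(0)
x_nhwc = rng.standard_normal((6, 7, 3))
k_keras = rng.standard_normal((3, 3, 3, 4))
y_keras = conv2d_nhwc(x_nhwc, k_keras)
y_torch = conv2d_nchw(x_nhwc.transpose(2, 0, 1), keras_to_torch_kernel(k_keras))
assert np.allclose(y_keras.transpose(2, 0, 1), y_torch)
```

Both frameworks compute cross-correlation, so only the axes move and no spatial flip of the kernel is needed. In a real port the reordered array would be wrapped with `torch.from_numpy` before being copied into the module. This sketch does not cover other differences a port would need to check, such as padding behaviour or normalization layers.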

mattersoflight commented 1 year ago

@Soorya19Pradeep @ziw-liu I suggest you try two paths:

Convert working tensorflow models to a deployable format:

(All of the following links are to the release 1.0.0 commit, which we have tested extensively):

As @Soorya19Pradeep pointed out, the current tensorflow implementation saves only the weights, not the architecture. However, there is a load_model method in the inference module that can rebuild a model from its config and weights. It is used in the training method.
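Once load_model has rebuilt the Keras network, its per-layer weights could be collected (for example with Keras' `layer.get_weights()`) and renamed into a PyTorch-style state dict. The following is a minimal sketch under that assumption. `keras_to_torch_state_dict` is a hypothetical helper, it only handles convolution layers with a kernel and a bias, and the Keras-to-PyTorch name mapping is also hypothetical: it would have to be written by hand for whichever pytorch U-Net the weights are loaded into.

```python
import numpy as np


def keras_to_torch_state_dict(keras_weights, name_map):
    """Build a PyTorch-style state dict from per-layer Keras conv weights.

    keras_weights: {keras_layer_name: [kernel, bias]}, e.g. collected with
        {layer.name: layer.get_weights() for layer in model.layers}.
    name_map: {keras_layer_name: torch_module_name}. HYPOTHETICAL mapping
        that must be written by hand for the target PyTorch model.
    """
    # Refuse to silently drop layers that carry weights but have no mapping
    # (e.g. batch normalization, which would need its own conversion).
    unmapped = [n for n, w in keras_weights.items() if len(w) > 0 and n not in name_map]
    if unmapped:
        raise KeyError(f"no PyTorch name for Keras layers: {unmapped}")

    state = {}
    for keras_name, torch_name in name_map.items():
        weights = keras_weights[keras_name]
        if len(weights) != 2:
            raise ValueError(f"{keras_name}: expected [kernel, bias], got {len(weights)} arrays")
        kernel, bias = weights
        nd = kernel.ndim
        # channels-last (*spatial, in, out) -> channels-first (out, in, *spatial)
        perm = (nd - 1, nd - 2) + tuple(range(nd - 2))
        state[f"{torch_name}.weight"] = np.ascontiguousarray(kernel.transpose(perm))
        state[f"{torch_name}.bias"] = np.array(bias, copy=True)
    return state


# Toy stand-in for weights pulled from a rebuilt Keras model (names are made up).
rng = np.random.default_rng(0)
keras_weights = {
    "conv2d": [rng.standard_normal((3, 3, 1, 16)), np.zeros(16)],
    "conv2d_1": [rng.standard_normal((3, 3, 16, 2)), np.zeros(2)],
    "activation": [],  # weightless layers are skipped
}
state = keras_to_torch_state_dict(keras_weights, {"conv2d": "enc.conv1", "conv2d_1": "head"})
for key, value in state.items():
    print(key, value.shape)
```

The arrays would then be converted with `torch.from_numpy` and passed to the PyTorch model's `load_state_dict`. Raising on unmapped layers is a deliberate choice, so that a missing mapping fails loudly instead of leaving randomly initialized weights in the pytorch model.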

I have not used this feature before, but @jennyfolkesson may remember whether resuming training from a saved model ever worked. If it did, the load_model method can be adapted to export the model so that it can be deployed with pytorch.

This project on model management and conversion tools may also be handy: https://github.com/Microsoft/MMdnn

It will be great to reuse these models.

@ziw-liu can drive this. @Soorya19Pradeep, please note the paths of the two tensorflow models you recently used for inference, so that @ziw-liu can work with them.

Train models with pytorch pipeline and existing data:

@Soorya19Pradeep and @Christianfoley can drive this, and we can discuss it in a separate issue.

Soorya19Pradeep commented 1 year ago

@ziw-liu, you can use these model weights, generated and saved by @JohannaRahm:

/hpc/projects/CompMicro/projects/virtualstaining/2022_microDL_nuc_mem/models/2022_03_15_nuc_mem/loss_functions/heavy_augmentation_z25-60_mae/Model_2022-09-14-10-18-31.hdf5

/hpc/projects/CompMicro/projects/virtualstaining/2022_microDL_nuc_mem/models/2022_03_15_nuc_mem/loss_functions/heavy_augmentation_z12-74_mae/Model_2022-05-11-13-46-08.hdf5

The models were trained to predict membrane and nucleus from phase in HEK cells.

Soorya19Pradeep commented 1 year ago

After discussion with @mattersoflight and @ziw-liu, we decided not to port the weights of the models trained with the tensorflow version of microDL to pytorch or onnx. Porting was needed primarily to compare the prediction quality of the tensorflow model with that of the pytorch model.

Instead, we will run the tensorflow models on fry2 to produce the predictions required for the comparison.