yhygao / CBIM-Medical-Image-Segmentation

A PyTorch framework for medical image segmentation
Apache License 2.0
260 stars 46 forks

Save ensembled models #20

Closed: dzzhang96 closed this issue 1 year ago

dzzhang96 commented 1 year ago

Hi yhygao!

I have trained several folds using medformer and used ensemble modelling for prediction. But every time I run inference, the ensemble has to be assembled again from the individual fold models. Would it be possible to support saving the ensemble model in a future branch? Thanks!

yhygao commented 1 year ago

Hi,

Thank you for the suggestion, but I don't quite understand your idea. What do you mean by saving the ensemble model? Suppose you have 5 models for inference: prediction.py runs inference 5 times to generate one predicted map per model, and then averages the predictions. Do you mean fusing the 5 models into a single model so that inference only needs to run once?
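The prediction-level averaging described above can be sketched in PyTorch as follows. This is a minimal sketch, not the repo's actual prediction.py: the `EnsembleSegmenter` wrapper and the 1x1x1-conv stand-in fold models are hypothetical, and it assumes each fold returns per-class logits with the class dimension at dim 1, which are averaged after softmax. Wrapping the folds in one module lets the whole ensemble be saved as a single file, but it still runs every fold model at inference time.

```python
import torch
import torch.nn as nn


class EnsembleSegmenter(nn.Module):
    """Hypothetical wrapper: runs every fold model and averages their softmax maps."""

    def __init__(self, models):
        super().__init__()
        # nn.ModuleList registers each fold, so .eval(), .to() and state_dict() cover all of them
        self.models = nn.ModuleList(models)

    @torch.no_grad()
    def forward(self, x):
        probs = None
        for model in self.models:
            p = torch.softmax(model(x), dim=1)  # class dimension assumed to be dim 1
            probs = p if probs is None else probs + p
        return probs / len(self.models)  # mean probability over folds


# Stand-in fold models (assumption): 1x1x1 conv mapping 1 input channel to 3 classes
folds = [nn.Conv3d(1, 3, kernel_size=1) for _ in range(5)]
ensemble = EnsembleSegmenter(folds).eval()

x = torch.randn(1, 1, 8, 16, 16)     # (N, C_in, D, H, W)
avg_prob = ensemble(x)               # (1, 3, 8, 16, 16) averaged probabilities
label_map = avg_prob.argmax(dim=1)   # (1, 8, 16, 16) final segmentation

# One file holds all fold weights; reloading still requires building the same fold architectures
torch.save(ensemble.state_dict(), "ensemble.pth")
```

Saving this way only removes the bookkeeping of loading several checkpoints; the cost of five forward passes remains, which is why fusing the folds into one network is a different (and harder) problem.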

If my understanding is correct, I think this is a non-trivial problem. Although we could directly average the weights of the 5 models, I doubt the result would perform as well as a 5-model ensemble. Knowledge distillation might be a solution, but it is out of the scope of this repo. Please let me know if you have any better ideas.
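For reference, the "directly average the weights" option amounts to a parameter-wise mean of the fold state dicts. The sketch below assumes all folds share exactly the same architecture and parameter names; `average_state_dicts` and `make_net` are hypothetical helpers, not part of the repo. Because the folds are trained independently from different initializations and data splits, their weights are not guaranteed to be compatible, so the merged model may perform much worse than the prediction-level ensemble, which is the concern raised above.

```python
import copy

import torch
import torch.nn as nn


def average_state_dicts(state_dicts):
    """Hypothetical helper: element-wise mean of floating-point parameters and buffers."""
    avg = copy.deepcopy(state_dicts[0])
    for key, value in avg.items():
        if value.is_floating_point():
            # Weights, biases and BatchNorm running stats are averaged (stats only approximately valid)
            avg[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
        # Integer buffers (e.g. BatchNorm's num_batches_tracked) are kept from the first fold
    return avg


def make_net():
    # Stand-in architecture (assumption); every fold must use the identical definition
    return nn.Sequential(nn.Conv3d(1, 8, 3, padding=1), nn.BatchNorm3d(8), nn.Conv3d(8, 3, 1))


folds = [make_net() for _ in range(5)]

merged = make_net()
merged.load_state_dict(average_state_dicts([m.state_dict() for m in folds]))
merged.eval()  # one forward pass instead of five, but accuracy is not guaranteed to match
```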