Novartis / torchsurv

Deep survival analysis made easy
https://opensource.nibr.com/torchsurv/
MIT License

Deep learning under log-likelihood loss does not converge easily. #40

Open Minxiangliu opened 2 days ago

Minxiangliu commented 2 days ago

Hi there, thank you for the tool. I'm currently training on 3D brain medical images using a DenseNet121 model with the neg_partial_log_likelihood loss function. Training itself runs normally, but when I look at the training and validation loss values, the validation loss does not converge effectively and even trends upward. I have tried many ways to reduce overfitting, including data augmentation, dropout, and oversampling, but the validation loss still will not converge. Is this a problem with the data itself, or have I overlooked some detail?

A total of 189 3D images

[Screenshots: training and validation loss curves]

tcoroller commented 2 days ago

Dear @Minxiangliu, thank you for your question. At first glance, this looks like an experimental issue (overfitting) rather than a TorchSurv loss function issue.

A couple of comments that may help you:

  1. Pick the right loss: with a Cox model, the network is optimized by ranking samples within a batch, so the larger the batch, the more reliable the loss estimate. With medical images, however, GPU memory limits can make it hard to use a batch size greater than 8 samples. Two suggestions: you can instead try the Weibull model, which does not rank samples but fits a distribution to each one, removing the batch-size dependency, which is good for you; or, if you want to keep the Cox model, try our Momentum loss, which uses two networks (online and target) alongside a dynamic memory bank (like MoCo for contrastive learning). We have a tutorial notebook here. Minimal sketches of both options follow this list.
  2. Data / inputs: you are using 189 3D images to predict a time-to-event target, which may be too few samples for the task. Try to reduce the covariate size (e.g., reduce the image dimensions or use embeddings; see the sketch after this list). What is your target? How many censored/non-censored patients do you have? Are there other covariates (e.g., clinical) you could use to help the model? Is there literature on the topic that has shown signal?
  3. Modeling parameters: what is your training/validation split? What is your batch size? And your learning rate (the loss profile looks "jumpy")? Your training loss is decreasing close to 0, so the model is learning something, but it is indeed very much overfitting. There are plenty of external resources on overfitting that you can read and learn from.
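
To make the Weibull option concrete, here is a minimal sketch. The weibull.neg_log_likelihood call and the two-column log_params layout follow the TorchSurv docs; the tiny backbone and the random batch are placeholders for your DenseNet121 and data loader:

```python
import torch
from torchsurv.loss import weibull

torch.manual_seed(0)

# Placeholder backbone: in practice, replace DenseNet121's classification
# head with a 2-unit linear layer producing (log scale, log shape) per sample.
backbone = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(4 * 4 * 4, 2),
)
optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-4)

# Dummy batch standing in for your volumes, event indicators, and times.
x = torch.randn(8, 1, 4, 4, 4)
event = torch.randint(0, 2, (8,)).bool()
time = torch.rand(8) * 100

log_params = backbone(x)  # shape (batch, 2): log scale and log shape
# Each sample is fit to its own distribution, so the loss estimate does not
# depend on within-batch ranking the way the Cox partial likelihood does.
loss = weibull.neg_log_likelihood(log_params, event, time)
loss.backward()
optimizer.step()
```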
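
For the Momentum option, a rough sketch of the wrapper is below. The constructor arguments follow my reading of the Momentum docstring (batchsize, n, rate), so please verify the exact signature against the tutorial notebook; the backbone and hyperparameter values are illustrative:

```python
import torch
from torchsurv.loss import cox
from torchsurv.loss.momentum import Momentum

torch.manual_seed(0)

# Placeholder backbone mapping a flattened volume to a single log hazard.
backbone = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(4 * 4 * 4, 1),
)

# Momentum keeps an online network, an EMA target network, and a memory bank
# of past estimates, so the Cox partial likelihood is evaluated over a much
# larger pool than the GPU batch. Hyperparameter values here are illustrative.
model = Momentum(
    backbone=backbone,
    loss=cox.neg_partial_log_likelihood,
    batchsize=8,   # your actual GPU batch size
    n=4,           # number of batches kept in the memory bank
    rate=0.999,    # EMA decay for the target network
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

x = torch.randn(8, 1, 4, 4, 4)
event = torch.randint(0, 2, (8,)).bool()
time = torch.rand(8) * 100

loss = model(x, event, time)  # partial likelihood against the memory bank
loss.backward()
optimizer.step()
```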
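
On reducing the covariate size (point 2), one common pattern, sketched below with placeholder names and dimensions rather than anything TorchSurv prescribes, is to freeze a pretrained encoder, cache its embeddings once, and train only a small survival head on top:

```python
import torch
from torchsurv.loss import cox

torch.manual_seed(0)

# Placeholder for a frozen pretrained 3D encoder (e.g., a network with its
# classifier removed); any fixed feature extractor fits this pattern.
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(4 * 4 * 4, 32))
encoder.requires_grad_(False)
encoder.eval()

# Cache the embeddings once: 189 volumes easily fit in memory, and the
# problem becomes fitting a tiny head on 32 features instead of a 3D CNN.
volumes = torch.randn(189, 1, 4, 4, 4)  # stand-in for your images
with torch.no_grad():
    embeddings = encoder(volumes)  # shape (189, 32)

event = torch.randint(0, 2, (189,)).bool()
time = torch.rand(189) * 100

head = torch.nn.Linear(32, 1)  # log-hazard head: the only trained part
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3, weight_decay=1e-4)

log_hz = head(embeddings).squeeze(-1)  # shape (189,)
loss = cox.neg_partial_log_likelihood(log_hz, event, time)
loss.backward()
optimizer.step()
```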

Good luck with your project, and thank you for using TorchSurv!