autonlab / auton-survival

Auton Survival - an open source package for Regression, Counterfactual Estimation, Evaluation and Phenotyping with Censored Time-to-Events
http://autonlab.github.io/auton-survival
MIT License

Increasing validation loss in RDSM with time-varying data #91

Status: Open. ryuzakizh opened this issue 1 year ago.

ryuzakizh commented 1 year ago

Hi auton-survival community (@chiragnagpal, @Jeanselme, @chufangao, @salvaRC), I appreciate your contributions to time-varying survival analysis, and thank you for making this library open to the public.

I am having an issue training RDSM on my own custom dataset. I made sure that my dataset looks like the one used in the Jupyter notebook tutorials: my T column is the remaining time to event, and E is the event indicator. However, during training I consistently see the training loss decrease while the validation loss increases.

Then I tried training on the PBC dataset that you used in the demo notebooks, and I noticed the same behavior there: decreasing training loss and increasing validation loss.

I haven't made any changes to the methodology. Is this something intrinsic to RDSM, or am I doing something wrong? These are the logs from the model:

[Image: train_val_losses]
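For readers reproducing this setup, here is one way to turn long-format longitudinal records into per-subject sequences where `t` holds the remaining time to event at each observation and `e` the event indicator, as described in this post. This is a minimal NumPy sketch under stated assumptions: the helper `to_sequences` and its argument names are hypothetical, and it is not the package's own preprocessing.

```python
import numpy as np


def to_sequences(ids, obs_times, event_times, events, features):
    """Group long-format rows into per-subject arrays (hypothetical helper).

    ids, obs_times, event_times, events: 1-D arrays, one entry per row.
    features: 2-D array with one row per observation.
    Returns lists x, t, e holding one array per subject, in time order.
    """
    x, t, e = [], [], []
    for sid in np.unique(ids):
        rows = np.where(ids == sid)[0]
        rows = rows[np.argsort(obs_times[rows], kind="stable")]  # chronological order
        x.append(features[rows])
        # Remaining time to event (or censoring) at each observation
        t.append(event_times[rows] - obs_times[rows])
        # Assumes the subject's final event indicator is stored on every row
        e.append(events[rows])
    return x, t, e


# Toy example: subject 1 has two visits (rows out of order), subject 2 has one
ids = np.array([1, 2, 1])
obs_times = np.array([1.0, 0.0, 0.0])
event_times = np.array([5.0, 3.0, 5.0])
events = np.array([1, 0, 1])
features = np.array([[0.2], [0.3], [0.1]])

x, t, e = to_sequences(ids, obs_times, event_times, events, features)
print(t)  # remaining time to event per subject, one array per subject
```

Sorting each subject's rows by observation time matters for a recurrent model, since the GRU/LSTM consumes the sequence in order.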

Jeanselme commented 1 year ago

Hello, I have just dug into this, and I am not able to reproduce the problem. Are you using the RDSM available in auton_survival/models/dsm? Do you have this issue when running examples/RDSM on PBC Dataset.ipynb? Feel free to share your code so I can help with this.

ryuzakizh commented 1 year ago

Hello @Jeanselme, thank you for the reply.

I am using the GPU-enabled RDSM version from Jeanselme/DeepSurvivalMachines, but I assume it is the same as the one in auton_survival/models/dsm.

I am facing this issue both when running examples/RDSM on PBC Dataset.ipynb and when I apply RDSM to my own longitudinal data. I modified the code slightly so that I can see the training and validation losses, but I didn't change anything else. As you can see below, I print the training and validation losses every 10 iterations.

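```python
  # Excerpt from a modified copy of the DSM training loop. The enclosing function
  # and its setup (optimizer, nbatches, costs, dics, oldcost, patience, and the
  # validation tensors t_valid_/e_valid_ already passed through
  # _reshape_tensor_with_nans) are unchanged and omitted.
  # Relies on: numpy as np, torch, gc, copy.deepcopy, tqdm.tqdm, and the
  # package's conditional_loss and _reshape_tensor_with_nans helpers.
  if cuda:
    # Move the validation tensors to the GPU once instead of at every check
    x_valid, t_valid_, e_valid_ = x_valid.cuda(), t_valid_.cuda(), e_valid_.cuda()

  for i in tqdm(range(n_iter)):  # each i is one pass over the training set
    train_loss = []
    for j in range(nbatches):

      xb = x_train[j*bs:(j+1)*bs]
      tb = t_train[j*bs:(j+1)*bs]
      eb = e_train[j*bs:(j+1)*bs]

      if xb.shape[0] == 0:
        continue

      if cuda:
        xb, tb, eb = xb.cuda(), tb.cuda(), eb.cuda()

      optimizer.zero_grad()
      loss = 0
      for r in range(model.risks):
        loss += conditional_loss(model,
                                 xb,
                                 _reshape_tensor_with_nans(tb),  # tb has no NaNs
                                 _reshape_tensor_with_nans(eb),
                                 elbo=elbo,
                                 risk=str(r+1))
      # Record the batch loss once, after summing over all risks
      train_loss.append(loss.detach().cpu().numpy())
      loss.backward()
      optimizer.step()

    if i % 10 == 0:  # evaluate on the validation set every 10 iterations
      valid_loss = 0
      with torch.no_grad():  # no gradients are needed for evaluation
        for r in range(model.risks):
          valid_loss += conditional_loss(model,
                                         x_valid,
                                         t_valid_,
                                         e_valid_,
                                         elbo=False,
                                         risk=str(r+1))
      # Convert once, after summing over all risks
      valid_loss = float(valid_loss.detach().cpu().numpy())
      costs.append(valid_loss)
      print("Training loss is:", np.mean(train_loss), "Validation loss is:", valid_loss)
      dics.append(deepcopy(model.state_dict()))
      if costs[-1] >= oldcost:  # no improvement over the previous check
        if patience == 4:  # stop at the 5th consecutive check without improvement
          minm = np.argmin(costs)
          model.load_state_dict(dics[minm])  # restore the best checkpoint

          del dics
          gc.collect()

          return model, i
        else:
          patience += 1
      else:
        patience = 0

      oldcost = costs[-1]

  minm = np.argmin(costs)
  model.load_state_dict(dics[minm])

  del dics
  gc.collect()

  return model, i
```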

These are the model parameters:

```python
from auton_survival.models.dsm import DeepRecurrentSurvivalMachines  # import path in auton-survival

model = DeepRecurrentSurvivalMachines(k=3,
                                      distribution='LogNormal',
                                      hidden=32,
                                      typ='GRU',
                                      layers=1)
model.fit(x_train, t_train, e_train, iters=10000, batch_size=128, learning_rate=1e-3)
```

I hope this explains the issue. Thanks!

Jeanselme commented 1 year ago

Hello,

I have now tried both versions of the code. I have updated the Jeanselme/ one, as its normalization was not adapted to time series and was raising an error. I still do not observe the issue with the PBC dataset. Can you please check whether the problem still occurs with the latest version? If so, please email me the whole repo and I will dig into this.

ryuzakizh commented 1 year ago

Hi @Jeanselme , thanks a lot for the update.

I trained the latest version of RDSM on the PBC dataset, and it seems fine. The validation loss decreases for the first 60 iterations and then starts to increase, but training stops shortly after that. Perhaps the model starts to overfit, since PBC is a small dataset.

I am still getting the issue with my own data, though. But now I know that it might be a dataset problem rather than something related to the methodology. Thank you!

Jeanselme commented 1 year ago

Hello,

Great to hear you no longer have the issue with PBC. The training should stop after a few iterations of increasing loss on the validation set and return the best model observed.

Concerning your dataset, are you normalizing the data?
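The stopping behavior described here can be replayed outside the model to see exactly which checkpoint is restored. The sketch below isolates the patience rule from the loop posted earlier in this thread (a check counts as "no improvement" when its loss is at least the previous check's loss; training stops once the counter has already reached the limit; the lowest-loss checkpoint is returned). The function name `early_stopping_checkpoint` and its `patience_limit` parameter are hypothetical and not part of auton-survival.

```python
def early_stopping_checkpoint(val_losses, patience_limit=4):
    """Replay a patience rule on a non-empty list of validation losses.

    A check counts as "no improvement" when its loss is >= the PREVIOUS
    check's loss. Training stops on the check where the counter has already
    reached patience_limit. Returns (stop_index, best_index); best_index is
    the checkpoint that would be restored (lowest loss seen so far).
    """
    old = float("inf")
    patience = 0
    for i, loss in enumerate(val_losses):
        if loss >= old:
            if patience == patience_limit:
                best = min(range(i + 1), key=lambda k: val_losses[k])
                return i, best
            patience += 1
        else:
            patience = 0
        old = loss
    # Never stopped early: restore the best checkpoint over all checks
    best = min(range(len(val_losses)), key=lambda k: val_losses[k])
    return len(val_losses) - 1, best


# Validation loss falls, then rises steadily: stops at check 7, restores check 2
print(early_stopping_checkpoint([1.0, 0.8, 0.7, 0.75, 0.9, 1.0, 1.1, 1.2, 1.3]))  # (7, 2)

# Oscillating but worsening loss keeps resetting the counter, so it never stops early
print(early_stopping_checkpoint([1.0, 0.9, 1.1, 1.0, 1.2, 1.1], patience_limit=1))  # (5, 1)
```

With validation checks every 10 iterations, as in the posted loop, check 7 corresponds to iteration 70. Because each check is compared with the previous check rather than with the best loss so far, a validation curve that trends upward while oscillating can postpone stopping; comparing against the running minimum is a common alternative.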

ryuzakizh commented 1 year ago

Thanks for the quick reply. I am normalizing only the continuous variables. However, I found that in PBC all variables are normalized irrespective of their type, so I decided to do the same and see what happens.
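Normalizing every column, as done for PBC, can be sketched for time-varying data by pooling the mean and standard deviation over all time steps of the training sequences only and applying those same statistics to the validation sequences. This is a minimal NumPy sketch under stated assumptions: `standardize_sequences` is a hypothetical helper, pooling over visits is a choice of this sketch, and it is not the preprocessing the package performs internally.

```python
import numpy as np


def standardize_sequences(train_seqs, other_seqs, eps=1e-8):
    """Z-score lists of (time steps x features) arrays using training statistics only.

    Mean and standard deviation are pooled over every time step of every
    training sequence, so subjects with more visits contribute more rows
    (an assumption of this sketch). Columns with ~zero spread are only centered.
    """
    stacked = np.concatenate(train_seqs, axis=0)
    mean = stacked.mean(axis=0)
    std = stacked.std(axis=0)
    std = np.where(std < eps, 1.0, std)  # avoid dividing by zero

    def scale(seqs):
        return [(s - mean) / std for s in seqs]

    return scale(train_seqs), scale(other_seqs), mean, std


# Two training subjects (2 and 1 visits) and one validation subject;
# column 0 is continuous, column 1 is a binary indicator
train = [np.array([[1.0, 0.0], [3.0, 1.0]]), np.array([[5.0, 1.0]])]
valid = [np.array([[3.0, 0.0]])]
train_z, valid_z, mean, std = standardize_sequences(train, valid)
```

Computing the statistics on the training split alone keeps validation information from leaking into preprocessing, and reusing them for validation keeps both splits on the same scale.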

ryuzakizh commented 1 year ago

Hello @Jeanselme, I am still getting the same issue; I have been digging into this for a week. Could you possibly take a look with fresh eyes if I send you the repository by email? I would very much appreciate it.

chiragnagpal commented 1 year ago

@ryuzakizh @Jeanselme Is this resolved yet? If not, can I help facilitate this?

ryuzakizh commented 1 year ago

Hello @chiragnagpal, no, the issue hasn't been resolved. This was related to my thesis, so I ended up reporting the results I got, and I still don't know what the problem was. However, we can consider this issue no longer relevant.