havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Thoughts about integration with pytorch-lightning? #60

Open rohanshad opened 3 years ago

rohanshad commented 3 years ago

Great work here, havakv; it looks really modular and well thought out.

I'm interested in using some of the time-to-event tools here with high-dimensional imaging inputs. I've been building up a medical imaging codebase using pytorch-lightning for a while, mostly because of how modular and convenient it makes iterating over multiple experiments in a cluster environment.

Do you have any ideas on how best to reorganize parts of pycox.models (PCHazard, for example) for pytorch-lightning before I start on this? What I'm really interested in is being able to use the pytorch-lightning Trainer.

havakv commented 3 years ago

Thank you for the kind words! I think pycox is way too dependent on torchtuples (which is quite limited), and this has been discussed in #25. Pycox could really benefit from working with pytorch-lightning, but I don't think any special integration is needed, just a way to decouple pycox from torchtuples. One could then give examples of how to fit models with pytorch-lightning.

In relation to #25, I made this example in this branch, where I propose a change to the LogisticHazard model so that it can be fitted with just vanilla pytorch (no torchtuples stuff). Could you take a look at that and see if you're able to make it work with pytorch-lightning?

It would really help to get some feedback on these changes before I start refactoring all the models :)
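To make the decoupling point concrete: the logistic-hazard loss is just a function of the network output and the discretised targets, so any training loop (vanilla pytorch or a Lightning `training_step`) can call it without torchtuples. The sketch below is an illustrative NumPy transcription of that likelihood; `nll_logistic_hazard_np` is a hypothetical name, and in a real Lightning module you would call pycox's own torch loss (`pycox.models.loss.nll_logistic_hazard`) instead.

```python
import numpy as np


def nll_logistic_hazard_np(phi, idx_durations, events):
    """Mean negative log-likelihood of a discrete-time logistic-hazard model.

    Hypothetical NumPy helper for illustration only (not pycox's API).
    phi:           (n, m) network outputs, i.e. logits of the per-interval hazards
    idx_durations: (n,) index of the grid interval of each observed time
    events:        (n,) 1 if the event was observed, 0 if censored
    """
    phi = np.asarray(phi, dtype=float)
    idx = np.asarray(idx_durations, dtype=int)
    ev = np.asarray(events, dtype=float)
    rows = np.arange(phi.shape[0])
    # Binary targets: 1 in the observed interval if the event happened, 0 elsewhere.
    y = np.zeros_like(phi)
    y[rows, idx] = ev
    # Numerically stable binary cross-entropy with logits.
    bce = np.maximum(phi, 0.0) - phi * y + np.log1p(np.exp(-np.abs(phi)))
    # Each individual contributes the BCE terms of intervals 0..idx (inclusive).
    return np.cumsum(bce, axis=1)[rows, idx].mean()
```

Because the loss needs nothing but tensors from the batch, the same call fits equally well in a plain `for` loop or in `LightningModule.training_step`.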

rohanshad commented 3 years ago

Perfect!

I'll take a crack at this soon (hopefully within this week) and circle back. Thanks 👍

rohanshad commented 3 years ago

Got it all working in a new conda environment, and I've successfully ported the example to pytorch-lightning. I set up the dataset within a Lightning DataModule and packaged the pre-processing functions there too. Since each model may require slightly different pre-processing steps, it might make sense to define all those preprocessing functions within the data module itself. I set up the train/test data there as well.
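As an example of the model-specific preprocessing that could live in such a DataModule, here is a minimal sketch that discretises continuous durations onto an equidistant grid, assuming a grid from 0 to the largest observed duration. `discretize_durations` is hypothetical, and its rounding rule (map each duration to the first grid point at or above it) is deliberately simpler than pycox's own label transforms such as `LogisticHazard.label_transform`.

```python
import numpy as np


def discretize_durations(durations, num_durations):
    """Map continuous durations onto an equidistant grid.

    Hypothetical, simplified helper (not pycox's label transform).
    Returns the grid ("cuts") and, for each duration, the index of the
    first grid point that is >= that duration.
    """
    durations = np.asarray(durations, dtype=float)
    # Equidistant grid from 0 to the largest observed duration (assumption).
    cuts = np.linspace(0.0, durations.max(), num_durations)
    # side="left" returns the first index i with cuts[i] >= duration.
    idx_durations = np.searchsorted(cuts, durations, side="left")
    return cuts, idx_durations
```

Keeping this in the DataModule means the fitted `cuts` can be handed to the LightningModule as its `duration_index`, so model and data stay consistent.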

The model, training logic, metrics, loss, and optimizer all fit in a surv_model LightningModule. The Trainer trains everything, shows a progress bar, tracks experiment versions, and can dump logs to CSV/TensorBoard as required:

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name      | Type                 | Params
---------------------------------------------------
0 | net       | Sequential           | 1 K   
1 | loss_func | NLLLogistiHazardLoss | 0     
Epoch 19: 100%|███████| 6/6 [00:01<00:00,  4.23it/s, loss=2.269, v_num=34, loss_step=2.14, loss_epoch=2.26]
Running in Evaluation Mode...
Concordance: 0.6252826147316648

The only thing I kept in vanilla pytorch is the testing phase, since the metrics are calculated directly on a pandas DataFrame, which obviates the need for a DataLoader. Let me know if you'd like me to open a PR on that branch so you can see what this looks like.
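Since evaluation works on a DataFrame rather than a DataLoader, the only model-specific step at test time is turning network outputs into survival curves. A NumPy sketch for the logistic-hazard case, assuming one logit per grid point; `logistic_hazard_surv` is a hypothetical helper. The result has one row per time point and one column per individual, the same layout as `pd.DataFrame(surv.transpose(), index)` in the PC-Hazard code further down this thread, so it can be wrapped as `pd.DataFrame(surv, index=cuts)` before passing it to `EvalSurv`.

```python
import numpy as np


def logistic_hazard_surv(phi):
    """Survival curves from logistic-hazard logits (hypothetical helper).

    phi: (n, m) logits for n individuals and m grid points.
    Returns an (m, n) array: rows are grid points, columns are individuals.
    """
    hazard = 1.0 / (1.0 + np.exp(-np.asarray(phi, dtype=float)))
    # S(t_j) = prod_{k <= j} (1 - h_k)
    surv = np.cumprod(1.0 - hazard, axis=1)
    return surv.T
```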

havakv commented 3 years ago

Great work @rohanshad! Sure, you can open a PR! It's much simpler to discuss when we have some concrete examples.

rohanshad commented 3 years ago

#66 Here you go ^

rohanshad commented 3 years ago

Let me know if you'd like to create an example for CoxPH; its workings and estimators seem a bit different from the logistic_hazards models. I can carry on from there and attempt to make a flexible-ish LightningModule that works with CoxPH too.
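One concrete difference from the logistic-hazard models: the Cox partial likelihood couples every sample in a batch through its risk set, so the loss cannot be computed per sample and depends on how batches are formed. A loop-based NumPy sketch for clarity, using Breslow's handling of ties; `cox_ph_nll` is hypothetical, and pycox's own torch loss (`pycox.models.loss.cox_ph_loss`) is vectorised rather than written like this.

```python
import numpy as np


def cox_ph_nll(log_risk, durations, events):
    """Negative Cox partial log-likelihood, averaged over observed events.

    Hypothetical illustration (not pycox's API); ties handled as in Breslow.
    log_risk:  (n,) network outputs g(x_i)
    durations: (n,) observed times
    events:    (n,) 1 = event, 0 = censored
    """
    g = np.asarray(log_risk, dtype=float)
    t = np.asarray(durations, dtype=float)
    d = np.asarray(events, dtype=bool)
    if not d.any():
        raise ValueError("the partial likelihood needs at least one observed event")
    total = 0.0
    for i in np.flatnonzero(d):
        # Risk set: everyone still under observation at time t_i.
        at_risk = t >= t[i]
        # log-sum-exp over the risk set, computed stably.
        total += g[i] - np.logaddexp.reduce(g[at_risk])
    return -total / d.sum()
```

Note that shuffling which samples share a batch changes the risk sets, which is why a generic per-sample Lightning loss does not carry over directly.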

havakv commented 3 years ago

I think the Cox models will be a bit harder to make work (though CoxPH is likely the simplest). Currently, the computation of the non-parametric baseline hazards is part of the CoxPH class, so it probably needs to be factored out in a similar way to what I did for the logistic-hazard model. But you are of course more than welcome to give it a go!
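The non-parametric baseline hazard is what makes CoxPH stateful: after training, it has to be estimated once from the network outputs on the training data. A NumPy sketch of the Breslow estimator, the usual choice for this; `breslow_cumulative_baseline_hazard` is a hypothetical helper, and whether it matches pycox's implementation detail for detail is an assumption.

```python
import numpy as np


def breslow_cumulative_baseline_hazard(log_risk, durations, events):
    """Breslow estimate of the cumulative baseline hazard H0(t).

    Hypothetical helper for illustration (not pycox's API).
    Returns the sorted unique durations and H0 evaluated at each of them.
    """
    expg = np.exp(np.asarray(log_risk, dtype=float))
    t = np.asarray(durations, dtype=float)
    e = np.asarray(events, dtype=float)
    times = np.unique(t)  # sorted unique durations
    # Number of events at each unique time.
    n_events = np.array([e[t == s].sum() for s in times])
    # Sum of exp(g) over everyone still at risk at that time.
    risk = np.array([expg[t >= s].sum() for s in times])
    # Baseline hazard increments d(t) / sum_{at risk} exp(g), accumulated over time.
    return times, np.cumsum(n_events / risk)
```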

Right now I have too much to do between work and revisions, so I can't prioritise this (I imagine it's quite a lot of work), but I'll get started as soon as I have the time.

havakv commented 3 years ago

I guess this can stay open until all of pycox can be used with pytorch-lightning.

havakv commented 3 years ago

@rohanshad If you want to take a crack at CoxPH in pytorch-lightning, I've now made a refactored version of CoxPH in https://github.com/havakv/pycox/blob/refactor_out_torchtuples/pycox/models/coxph.py that should be straightforward to use. Some docs and tests are still missing, but I'll add those later.

By using compute_cumulative_baseline_hazards and output2surv, I think it shouldn't be too much work. Let me know if you run into any issues!
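Once the cumulative baseline hazard is available, turning Cox network outputs into survival curves is a single formula, S(t | x) = exp(-H0(t) exp(g(x))), which is presumably what `output2surv` does. Its exact signature in that branch isn't shown here, so `cox_surv` below is a hypothetical NumPy stand-in, not the branch's API.

```python
import numpy as np


def cox_surv(cum_baseline_hazard, log_risk):
    """Survival curves S(t | x) = exp(-H0(t) * exp(g(x))).

    Hypothetical stand-in for illustration (not the branch's output2surv).
    Returns shape (len(cum_baseline_hazard), n): rows are time points,
    columns are individuals, the same layout a survival DataFrame uses.
    """
    H0 = np.asarray(cum_baseline_hazard, dtype=float)[:, None]
    expg = np.exp(np.asarray(log_risk, dtype=float))[None, :]
    # Broadcasting gives one curve per individual.
    return np.exp(-H0 * expg)
```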

yorickvanzweeden commented 3 years ago

@havakv Thank you for the refactorings of CoxPH and LogisticHazard. Do you have any plans to refactor PC-Hazard? Or should I be able to use it like the other two?

This is what I am currently doing. However, compared with LogisticHazard and CoxPH, I am not getting great performance.

import pytorch_lightning as pl
import pandas as pd
import torch
import torch.nn.functional as F

from pycox.models.loss import nll_pc_hazard_loss
from pycox.evaluation import EvalSurv
from pycox.models.utils import pad_col, make_subgrid

class DummyModel(pl.LightningModule):
    def __init__(self, net, lr, duration_index=None, sub=10):
        super().__init__()
        # net: any torch.nn.Module with one output per interval, i.e. len(duration_index) - 1
        self.net = net
        self.lr = lr
        self.loss_func = nll_pc_hazard_loss
        self.duration_index = duration_index
        self.sub = sub

    def forward(self, x):
        return self.net(x)

    def common_step(self, batch, batch_idx, stage):
        # Targets as produced by the PC-Hazard label transform: interval index, event, interval fraction
        x, duration, event, interval = batch
        preds = self(x)
        loss = self.loss_func(preds, duration, event, interval)

        if stage == "train":
            return {"loss": loss}
        else:
            return {"loss": loss, "preds": preds, "event": event, "duration": duration}

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, 'train')

    # The *_epoch_end hooks below are pytorch-lightning 1.x API; Lightning 2.0 removed them
    # in favour of on_train_epoch_end / on_validation_epoch_end.
    def training_epoch_end(self, outs):
        # logger.experiment.add_scalar assumes a TensorBoardLogger
        self.logger.experiment.add_scalar("loss/train", torch.stack([x['loss'] for x in outs]).mean(), self.current_epoch)

    def validation_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, 'val')

    def validation_epoch_end(self, outs):
        self.logger.experiment.add_scalar("loss/val", torch.stack([x['loss'] for x in outs]).mean(), self.current_epoch)

        # torch.cat (not vstack) so 1-D targets and a smaller last batch concatenate correctly
        predictions = torch.cat([x['preds'] for x in outs])
        durations = torch.cat([x['duration'] for x in outs])
        events = torch.cat([x['event'] for x in outs])

        surv_df = self.predict_surv_df(predictions, sub=self.sub, duration_index=self.duration_index)
        # EvalSurv needs durations on the same time scale as surv_df's index; if `duration`
        # is the discretised interval index from the label transform, pass the raw durations instead.
        ev = EvalSurv(surv_df, durations.cpu().numpy().reshape(-1), events.cpu().numpy().reshape(-1))
        self.logger.experiment.add_scalar("val_concordance_td", ev.concordance_td(), self.current_epoch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def predict_surv_df(self, preds, sub, duration_index):
        n = preds.shape[0]
        # Softplus hazards, each spread evenly over `sub` sub-intervals
        hazard = F.softplus(preds).view(-1, 1).repeat(1, sub).view(n, -1).div(sub)
        # Zero hazard at the first grid point so that S(t_0) = 1
        hazard = pad_col(hazard, where='start')
        surv = hazard.cumsum(1).mul(-1).exp()
        surv = surv.cpu().numpy()

        index = None
        if duration_index is not None:
            index = make_subgrid(duration_index, sub)
        # Rows are time points, columns are individuals
        return pd.DataFrame(surv.transpose(), index)

The DataLoader follows the PC-Hazard notebook, so duration_index corresponds to the cuts of the fitted label transform, PCHazard.label_transform(num_durations).cuts.
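For reference, the survival computation in `predict_surv_df` above can be written in plain NumPy, which makes the shape requirement explicit: the network should output one value per interval, i.e. `len(cuts) - 1` values, so that the zero-padded hazard matrix lines up with a sub-grid of `(len(cuts) - 1) * sub + 1` time points. `pc_hazard_surv` is hypothetical, and the grid it builds is assumed to mirror what `make_subgrid` produces (each interval split into `sub` equal pieces, plus the final cut).

```python
import numpy as np


def pc_hazard_surv(phi, cuts, sub):
    """Piecewise-constant-hazard survival curves on a sub-divided grid.

    Hypothetical NumPy mirror of the predict_surv_df logic (not pycox's API).
    phi:  (n, m) network outputs, one per interval, so m must equal len(cuts) - 1
    cuts: (m + 1,) interval boundaries, e.g. labtrans.cuts
    sub:  number of equal sub-intervals per interval
    Returns (grid, surv) with surv of shape (m * sub + 1, n).
    """
    phi = np.asarray(phi, dtype=float)
    cuts = np.asarray(cuts, dtype=float)
    n, m = phi.shape
    if m != len(cuts) - 1:
        raise ValueError(f"expected {len(cuts) - 1} outputs per individual, got {m}")
    # Softplus keeps hazards positive; each interval's hazard is spread over `sub` pieces.
    hazard = np.repeat(np.logaddexp(0.0, phi), sub, axis=1) / sub
    # Prepend zero hazard so that S(cuts[0]) = 1.
    hazard = np.concatenate([np.zeros((n, 1)), hazard], axis=1)
    surv = np.exp(-np.cumsum(hazard, axis=1))
    # Grid: each interval split into `sub` equal steps, plus the final cut.
    grid = np.concatenate(
        [np.linspace(a, b, sub + 1)[:-1] for a, b in zip(cuts[:-1], cuts[1:])] + [cuts[-1:]]
    )
    return grid, surv.T
```

If the network's output width is `len(cuts)` instead of `len(cuts) - 1`, the curves no longer line up with the grid, which is worth ruling out when results look off.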

havakv commented 3 years ago

Hi @yorickvanzweeden. I should refactor PCHazard in the same way as LogisticHazard (I'm just struggling to find the time).

From what I can see, your code should work, so I don't really know why you're not getting the results you want. Have you tried comparing these results with

model = PCHazard(net, optimizer, duration_index=labtrans.cuts)
model.fit(...)
model.predict_surv_df(...)

to check whether it is just PCHazard that doesn't perform well, or whether something is wrong with your implementation?
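Another way to separate the two possibilities is to check the loss itself against an independent reference on a fixed batch. Below is a hypothetical NumPy transcription of the PC-Hazard likelihood, with hazard eta = softplus(phi) constant within each interval. Assuming pycox uses the same parameterisation, its value should match `nll_pc_hazard_loss` on the same inputs up to floating-point error, given targets from the PC-Hazard label transform as in the notebook.

```python
import numpy as np


def pc_hazard_nll(phi, idx_durations, events, interval_frac):
    """Mean negative log-likelihood of the piecewise-constant hazard model.

    Hypothetical reference implementation for debugging (not pycox's API).
    phi:           (n, m) network outputs; eta = softplus(phi) is the hazard per interval
    idx_durations: (n,) interval index kappa_i of each observed time
    events:        (n,) 1 = event, 0 = censored
    interval_frac: (n,) fraction rho_i of interval kappa_i elapsed at the observed time
    """
    phi = np.asarray(phi, dtype=float)
    idx = np.asarray(idx_durations, dtype=int)
    d = np.asarray(events, dtype=float)
    rho = np.asarray(interval_frac, dtype=float)
    rows = np.arange(phi.shape[0])
    eta = np.logaddexp(0.0, phi)  # softplus
    eta_k = eta[rows, idx]
    # Cumulative hazard of all complete intervals before kappa_i (zero-padded cumsum).
    past = np.concatenate([np.zeros((len(rows), 1)), np.cumsum(eta, axis=1)], axis=1)[rows, idx]
    # log-likelihood: d * log(eta_k) - rho * eta_k - sum_{j < kappa} eta_j
    loglik = d * np.log(eta_k) - rho * eta_k - past
    return -loglik.mean()
```

If this reference and the torch loss agree but validation metrics stay poor, the problem is more likely in the data pipeline or hyperparameters than in the loss.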

yorickvanzweeden commented 3 years ago

Thanks @havakv for your reply. I suspect it is due to the difficulty of the problem in combination with hyperparameters that have yet to be optimized.