havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Customized evaluation metric or loss function for imbalanced dataset #49

Karenou closed this issue 4 years ago

Karenou commented 4 years ago

Hi, I have encountered another problem, which is specific to the dataset I am using. Since the dataset is very imbalanced, I undersample the negative class (censored data with event = 0) in the training set but not in the validation set. But when I use the negative log-likelihood loss, e.g., NLLMTLRLoss, the validation loss is far lower than the training loss. Is there another way to add customized evaluation metrics to model.fit, aside from the negative log-likelihood? Or should I just write another _Loss class that takes the class weights into account?

Many thanks!

havakv commented 4 years ago

So, I guess the simplest way to do this would be to include your desired metric using model.fit(..., metrics={'your_metric': your_metric}), as described here https://nbviewer.jupyter.org/github/havakv/torchtuples/blob/master/examples/02_general_usage.ipynb#Trainig

As most of pycox builds on torchtuples, you can find more information in the examples there.
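A minimal sketch of such a metric, here a class-weighted loss. It assumes torchtuples calls each metric the same way as the loss, with the network output followed by the target tensors (for MTLR: phi, idx_durations, events). make_weighted_metric and censored_weight are hypothetical names, and the per-sample loss is passed in so the weighting itself stays model-agnostic.

```python
import numpy as np


def make_weighted_metric(per_sample_loss, censored_weight):
    """Build a metric that re-weights censored (event == 0) samples.

    per_sample_loss: callable (out, idx_durations, events) -> one loss per sample.
    censored_weight: weight given to censored samples; samples with an event get 1.
    Assumption: torchtuples passes (*net_output, *targets) to each metric.
    Only elementwise ops are used, so torch tensors and numpy arrays both work.
    """
    def weighted_metric(out, idx_durations, events):
        losses = per_sample_loss(out, idx_durations, events)
        # Weight 1 for events, censored_weight for censored samples.
        weights = events + censored_weight * (1 - events)
        # Weighted mean of the per-sample losses.
        return (weights * losses).sum() / weights.sum()

    return weighted_metric


# Toy check with a stand-in per-sample loss that simply returns `out`.
toy_metric = make_weighted_metric(lambda out, idx, ev: out, censored_weight=2.0)
print(toy_metric(np.array([1.0, 2.0, 3.0]), None, np.array([1.0, 0.0, 0.0])))
```

With pycox, per_sample_loss could for instance be `lambda phi, idx, ev: nll_mtlr(phi, idx, ev, reduction='none')` using nll_mtlr from pycox.models.loss (check that your pycox version has the reduction argument), and the result goes into model.fit(..., metrics={'weighted_nll': ...}). As far as I can tell, metrics are reported for both the training and the validation data, so choose censored_weight with the split you want to correct in mind.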

Karenou commented 4 years ago

Thanks for your help!

Karenou commented 4 years ago

Hi, may I know if I could use concordance-index as the evaluation metrics in model.fit? Thanks!

havakv commented 4 years ago

Happy to help.

The answer to whether you can use the concordance index as an evaluation metric is both yes and no. If you really want to use the concordance for this purpose, I would suggest another approach, but first, let me state the reasons why it is not straightforward to use the concordance in this manner.

My suggested approach is to instead write a callback that computes the concordance. You can either write your own callback from scratch by inheriting from tt.cb.Callback or you can use the tt.cb.MonitorMetrics base class for some added functionality:

import torchtuples as tt
from pycox.evaluation import EvalSurv

class Concordance(tt.cb.MonitorMetrics):
    def __init__(self, x, durations, events, per_epoch=1, verbose=True):
        super().__init__(per_epoch)
        self.x = x
        self.durations = durations
        self.events = events
        self.verbose = verbose

    def on_epoch_end(self):
        super().on_epoch_end()
        if self.epoch % self.per_epoch == 0:
            surv = self.model.interpolate(10).predict_surv_df(self.x)
            ev = EvalSurv(surv, self.durations, self.events)
            concordance = ev.concordance_td()
            self.append_score('concordance', concordance)
            if self.verbose:
                print('concordance:', concordance)

Now, in the MTLR example notebook you can simply add the following to compute (and print) the concordance of the test set every 5th epoch:

epochs = 100
callbacks = [tt.callbacks.EarlyStopping(), Concordance(x_test, durations_test, events_test, per_epoch=5)]
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)

If you want to plot the concordance progress, you can get the scores as a pandas DataFrame and plot them:

test_concordance = model.callbacks['Concordance'].to_pandas()
test_concordance.plot()
havakv commented 4 years ago

Btw, if you want to do early stopping based on the concordance, I would suggest supplying the EarlyStopping callback with the argument get_score=get_test_concordance, where get_test_concordance() returns the concordance on the test set. This might be a bit unintuitive, so just let me know if you want me to write a short example of this. You can also take a look at the EarlyStopping code.

crcastillo commented 4 years ago

That would be much appreciated and would come in handy for optimizing DeepHitSingle, as the loss function's relation to the concordance_td metric depends on the alpha and sigma parameters.

Karenou commented 4 years ago

Thank you for the example code. It works well on my dataset.

I also have another question about the LRScheduler class in torchtuples' callbacks.py. The class requires two inputs: a torch.optim.lr_scheduler object and a Monitor object. Is there any explanation of the Monitor object? I did not see its implementation in the code. Many thanks!

I saw that tt.optim.AdamWR and LRCosineAnnealing (in callbacks.py) make it possible to adjust the learning rate per batch. Are there any methods to schedule the learning-rate decay by epoch?

havakv commented 4 years ago

@crcastillo Here is one way to do early stopping based on the concordance:

Continuing with the Concordance callback above, we add the method get_last_score, which returns the latest concordance score:

class Concordance(tt.cb.MonitorMetrics):
    def __init__(self, x, durations, events, per_epoch=1, verbose=True):
        super().__init__(per_epoch)
        self.x = x
        self.durations = durations
        self.events = events
        self.verbose = verbose

    def on_epoch_end(self):
        super().on_epoch_end()
        if self.epoch % self.per_epoch == 0:
            surv = self.model.interpolate(10).predict_surv_df(self.x)
            ev = EvalSurv(surv, self.durations, self.events)
            concordance = ev.concordance_td()
            self.append_score('concordance', concordance)
            if self.verbose:
                print('concordance:', concordance)

    def get_last_score(self):
        return self.scores['concordance']['score'][-1]

Now, we can tell the EarlyStopping callback to use this new method for early stopping (minimize=False because a higher concordance is better):

concordance = Concordance(x_test, durations_test, events_test)
early_stopping = tt.callbacks.EarlyStopping(get_score=concordance.get_last_score,
                                            minimize=False)
callbacks = [concordance, early_stopping]
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)

Two things to note here:

An alternative would be to inherit from the EarlyStopping callback and implement the concordance logic there, i.e., create a single class EarlyStoppingConcordance(tt.callbacks.EarlyStopping).
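The idea behind minimize=False can be sketched in plain Python. The class below is hypothetical and much simpler than torchtuples' EarlyStopping (it does not save or restore model weights); it only shows how a get_score function whose value should be maximized drives the stop decision.

```python
class MaxScoreEarlyStopping:
    """Hypothetical, simplified early stopping for a score to be maximized.

    Not torchtuples' EarlyStopping: it only tracks the best score and signals
    a stop after `patience` consecutive epochs without improvement.
    """

    def __init__(self, get_score, patience=10):
        self.get_score = get_score
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def on_epoch_end(self):
        score = self.get_score()
        if score > self.best:  # higher is better, e.g. concordance
            self.best = score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Returning True signals the training loop to stop.
        return self.bad_epochs >= self.patience


# Made-up concordance values, one per epoch.
scores = iter([0.60, 0.65, 0.64, 0.63, 0.66, 0.62, 0.61])
stopper = MaxScoreEarlyStopping(get_score=lambda: next(scores), patience=2)
for epoch in range(7):
    if stopper.on_epoch_end():
        break
```

Here training stops at epoch 3, after two epochs without beating the best score of 0.65.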

havakv commented 4 years ago

@Karenou So, the LRScheduler seems to be an old piece of code that no longer works. Feel free to open an issue in the torchtuples repo about this.

If you want a quick fix, something along these lines should work:

from torchtuples.callbacks import Callback

class LRScheduler(Callback):
    '''Wrapper for torch.optim.lr_scheduler objects.
    Parameters:
        scheduler: A torch.optim.lr_scheduler object.
        get_score: Function that returns a score when called.
    '''
    def __init__(self, scheduler, get_score):
        self.scheduler = scheduler
        self.get_score = get_score

    def on_epoch_end(self):
        score = self.get_score()
        self.scheduler.step(score)
        stop_signal = False
        return stop_signal

where get_score is a function such as the one in the concordance example above. If you want to use the training loss as the score for your scheduler, you can instead use self.model.train_metrics.scores['loss']['score'][-1], and if you want to use the validation loss, you can use self.model.val_metrics.scores['loss']['score'][-1].
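As a sketch, those two expressions can be wrapped in a small factory so that the LRScheduler callback receives a plain function. make_loss_getter is a hypothetical helper, and the stand-in model below only mimics the train_metrics/val_metrics.scores['loss']['score'] layout described in this thread.

```python
from types import SimpleNamespace


def make_loss_getter(model, split="val"):
    """Return a zero-argument function giving the latest loss of `split`."""
    attr = "val_metrics" if split == "val" else "train_metrics"

    def get_score():
        # Read lazily, so every call returns the newest epoch's loss.
        return getattr(model, attr).scores["loss"]["score"][-1]

    return get_score


# Stand-in objects mimicking the nested layout, for illustration only.
def fake_metrics(losses):
    return SimpleNamespace(scores={"loss": {"score": losses}})


fake_model = SimpleNamespace(train_metrics=fake_metrics([0.9, 0.7]),
                             val_metrics=fake_metrics([1.1, 1.0]))
get_val_loss = make_loss_getter(fake_model, "val")
# Simulate a new epoch being logged after the getter was created.
fake_model.val_metrics.scores["loss"]["score"].append(0.95)
```

With a real pycox model, this would be used as LRScheduler(scheduler, get_score=make_loss_getter(model, "val")) with the quick-fix class above.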

So as you can see, it is often simpler to just write the Callback you want to use, rather than relying on the existing implementations.

havakv commented 4 years ago

@Karenou As for tt.optim.AdamWR and LRCosineAnnealing, there are currently no implementations that schedule the learning rate by epoch instead of by batch. However, in AdamWR you can just set a higher number for cycle_len, which should achieve roughly the same thing. E.g., if you want one cycle to take 50 epochs, just set cycle_len=50.
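To see why a long cycle updated per batch behaves much like an epoch-wise schedule, here is a plain-Python sketch of cosine annealing with warm restarts (the SGDR formula) with equal cycle lengths. That AdamWR follows exactly this curve is an assumption; torchtuples may differ in details such as cycle multipliers or how batches map to epochs. cosine_lr is a hypothetical helper.

```python
import math


def cosine_lr(epoch_float, cycle_len, lr_max, lr_min=0.0):
    """Cosine annealing with warm restarts, all cycles of equal length.

    epoch_float: training progress in epochs, e.g. 3.25 = a quarter into epoch 4.
    cycle_len: length of one cycle in epochs.
    """
    t_cur = epoch_float % cycle_len  # position inside the current cycle
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_cur / cycle_len))


batches_per_epoch = 100
cycle_len = 50
# Per-batch schedule: the learning rate changes slightly after every batch.
per_batch = [cosine_lr(b / batches_per_epoch, cycle_len, lr_max=1e-3)
             for b in range(cycle_len * batches_per_epoch)]
# Per-epoch schedule: the same curve, sampled once at the start of each epoch.
per_epoch = [cosine_lr(e, cycle_len, lr_max=1e-3) for e in range(cycle_len)]
```

With 50-epoch cycles, the per-batch curve only moves a tiny amount within each epoch, so sampling it at epoch boundaries gives the per-epoch schedule.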

If you really only want to change the learning rate at each epoch, you can probably use the functionality available in pytorch for this, and decay the learning rate either in a callback or by rewriting the training loop in model.fit_dataloader.

Also note that you don't need to use the optimizers in torchtuples. You can just use regular pytorch optimizers instead. torchtuples essentially wraps some of the pytorch objects to remove some boilerplate code.

havakv commented 4 years ago

In many ways, I think reimplementing existing pytorch optimizers was a bad choice for torchtuples, as they require a lot of maintenance. A better approach would probably have been to rely more on existing pytorch objects.

Anyway, let me know if this is what you were looking for.

havakv commented 4 years ago

Again, @Karenou, if you have problems understanding or implementing what I'm writing, don't hesitate to ask for an example. Writing custom callbacks requires you to be familiar with regular pytorch, as it is essentially just writing the pytorch training loop. torchtuples has not simplified using the features of pytorch; it has just removed some boilerplate code.
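Roughly, what a fit loop with callbacks does can be sketched in plain Python. This is a simplified, hypothetical loop, not torchtuples' actual fit; it only shows that each callback's on_epoch_end runs after every epoch and that a truthy return value acts as the stop signal, which is how the LRScheduler and EarlyStopping callbacks in this thread plug in.

```python
class Callback:
    """Hypothetical minimal callback base with a single epoch hook."""

    def on_epoch_end(self):
        return False  # False means "do not stop"


class RecordEpochs(Callback):
    def __init__(self):
        self.seen = []

    def on_epoch_end(self):
        self.seen.append(len(self.seen))  # log the epoch index
        return False


class StopAfter(Callback):
    def __init__(self, n):
        self.n = n
        self.calls = 0

    def on_epoch_end(self):
        self.calls += 1
        return self.calls >= self.n  # stop signal once n epochs are done


def fit(train_one_epoch, epochs, callbacks):
    """Run epochs, calling every callback's on_epoch_end after each one."""
    for epoch in range(epochs):
        train_one_epoch(epoch)  # in pytorch: loop over batches, backward(), optimizer.step()
        # A list (not a generator) so that every callback runs, even after a stop signal.
        stop_signals = [cb.on_epoch_end() for cb in callbacks]
        if any(stop_signals):
            return epoch + 1  # number of epochs actually trained
    return epochs


recorder, stopper, trained = RecordEpochs(), StopAfter(3), []
n_epochs = fit(trained.append, 10, [recorder, stopper])
```

Here the loop stops after 3 of the 10 requested epochs, and the recorder has seen every one of them.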

crcastillo commented 4 years ago

@havakv This is fantastic. Thank you!

Karenou commented 4 years ago

@havakv, Thanks for your help. Could you give an example of how to add the get_last_score function? I have tried the following code, but it doesn't work: it says that the MonitorFitMetrics object is not subscriptable. It may be due to the way I added the get_last_score function to the LRScheduler class.

class LRScheduler(Callback):
    '''Wrapper for pytorch.optim.lr_scheduler objects.
    Parameters:
        scheduler: A pytorch.optim.lr_scheduler object.
       get_score: Function that returns a score when called.
    '''
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def on_epoch_end(self):
        score = self.get_last_score()
        self.scheduler.step(score)
        stop_signal = False
        return stop_signal

    def get_last_score(self):
        return self.model.val_metrics['loss']['score'][-1]

from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ReduceLROnPlateau

lambda1 = lambda epoch: 0.9 ** (epoch // 50)

lr_scheduler = LambdaLR(torch.optim.Adam(model.net.parameters(), lr=params["lr"]), lr_lambda=[lambda1])

callbacks = [LRScheduler(lr_scheduler)]

log = model.fit(x_train, y_train, params["batch_size"], epochs=1, callbacks=callbacks, num_workers=5,
                        val_data=val, val_batch_size=8224)
havakv commented 4 years ago

@Karenou it should just be self.model.val_metrics.scores['loss']['score'][-1], then it should work. Sorry for giving you the wrong expression before.

In the same way, it should be self.model.train_metrics.scores['loss']['score'][-1] if you want to access the last training loss.

havakv commented 4 years ago

However, there are a couple of other issues with your code:

  1. The model and the scheduler use two different optimizers, so the scheduler updates the wrong optimizer's learning rate. Instead, you should use this:
    optimizer = torch.optim.Adam(net.parameters(), lr=params["lr"])
    model = MTLR(net, optimizer, duration_index=labtrans.cuts)
    lr_scheduler = LambdaLR(optimizer, lr_lambda=[lambda1])
    callbacks = [LRScheduler(lr_scheduler)]
    log = model.fit(x_train, y_train, params["batch_size"], epochs=1, callbacks=callbacks, num_workers=5,
                val_data=val, val_batch_size=8224)
  2. The schedulers LambdaLR, StepLR and MultiStepLR all use the epoch number (which they keep track of themselves) and do not use the score. So when you pass the validation loss to LambdaLR, it doesn't update the learning rates the way you intend. Instead, you could use:

    class LRScheduler(tt.cb.Callback):
        def __init__(self, scheduler):
            self.scheduler = scheduler

        def on_epoch_end(self):
            self.scheduler.step()
            stop_signal = False
            return stop_signal

    But if you plan to use this callback with ReduceLROnPlateau, whose step() does take a score, you should keep your score-based callback as it is.
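The difference in point 2 can be sketched in plain Python: a LambdaLR-style rule computes the learning rate from the epoch counter alone (here with the same lambda1 as in the code above), while a ReduceLROnPlateau-style rule only changes it in response to the score you pass in. Both are simplified stand-ins, not PyTorch's classes; PlateauLR ignores threshold, cooldown and min_lr.

```python
def lambda_lr(base_lr, epoch, lr_lambda):
    """LambdaLR-style rule: the learning rate depends only on the epoch counter."""
    return base_lr * lr_lambda(epoch)


lambda1 = lambda epoch: 0.9 ** (epoch // 50)  # decay by 10% every 50 epochs


class PlateauLR:
    """Simplified ReduceLROnPlateau-style rule for a score to be minimized.

    The learning rate is multiplied by `factor` once the score has failed
    to improve for more than `patience` consecutive steps.
    """

    def __init__(self, lr, factor=0.1, patience=10):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.num_bad = 0

    def step(self, score):
        if score < self.best:
            self.best = score
            self.num_bad = 0
        else:
            self.num_bad += 1
        if self.num_bad > self.patience:
            self.lr *= self.factor
            self.num_bad = 0
        return self.lr


plateau = PlateauLR(lr=1.0, factor=0.5, patience=1)
lrs = [plateau.step(val_loss) for val_loss in [1.0, 0.9, 0.95, 0.95, 0.8]]
```

With patience=1, the learning rate is halved only after the second consecutive non-improving validation loss, whereas lambda_lr never looks at the loss at all.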

Karenou commented 4 years ago

@havakv Thank you so much! It works well now.

Karenou commented 4 years ago

@havakv Given the existing implementation of the negative log-likelihood loss for MTLR (or other models that parameterize the PMF), is it reasonable to include additional regularization terms in the loss function? Or should I perform regularization through other methods, such as dropout layers? Many thanks!

havakv commented 4 years ago

@Karenou To my understanding, it is not that common to include regularisation terms in the loss functions of neural networks. If you want an L2 penalty (as in ridge regression), you can achieve (essentially) the same by including weight decay in the optimizer (for vanilla SGD, weight decay is in fact equivalent to L2). This is typically specified in the docs of each torch optimizer (e.g., Adam https://pytorch.org/docs/stable/optim.html#torch.optim.Adam).
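A small numeric sketch of the SGD equivalence: a plain-SGD step with weight decay wd adds wd * w to the gradient, which is exactly the gradient of the L2 penalty (wd / 2) * w**2. The toy 1-D loss and the function names are made up for illustration. With adaptive optimizers the picture changes: PyTorch's Adam weight_decay argument adds this same L2 gradient term, while AdamW decouples the decay from the gradient, and those two are not equivalent.

```python
def sgd_step_weight_decay(w, grad_fn, lr, wd):
    """Plain SGD with weight decay: add wd * w to the loss gradient."""
    return w - lr * (grad_fn(w) + wd * w)


def sgd_step_l2_penalty(w, loss_fn, lr, wd, h=1e-6):
    """Plain SGD on loss + (wd / 2) * w**2, gradient by central differences."""
    penalized = lambda v: loss_fn(v) + 0.5 * wd * v ** 2
    grad = (penalized(w + h) - penalized(w - h)) / (2 * h)
    return w - lr * grad


loss = lambda w: (w - 3.0) ** 2   # toy 1-D loss
grad = lambda w: 2.0 * (w - 3.0)  # its exact gradient

w_decay = w_l2 = 0.5
for _ in range(20):
    w_decay = sgd_step_weight_decay(w_decay, grad, lr=0.1, wd=0.01)
    w_l2 = sgd_step_l2_penalty(w_l2, loss, lr=0.1, wd=0.01)
```

After 20 steps both parameters agree up to finite-difference rounding error.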

L1 penalties (lasso) are less common, as a sparse set of parameters does not make as much sense in a network as it does in a linear model.

Dropout is a very reasonable approach to regularisation, and personally I've had the most success with this.

This is just my take on it, and I'm sure you would get different answers from other people.