havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Calculating the gradients and loss of model #94

Closed · ShadiRahimian closed this 2 years ago

ShadiRahimian commented 2 years ago

Hi,

Thanks for the wonderful package! I was wondering how I can access the gradient values during the training of the models (e.g. DeepHitSingle) and how I can calculate the loss for a given input. I am not familiar with torchtuples and the training seems to be coupled with this package.

Many thanks

havakv commented 2 years ago

Hi, and thank you for the kind words.

There are two approaches to this. You could write code that fits DeepHit without depending on torchtuples (which would require you to read through and extract the relevant parts from pycox and torchtuples), or you could make a torchtuples Callback. The latter is probably simpler in this case, so here is how it works:

The training loop looks like this (code here):

stop = self.callbacks.on_fit_start()
for _ in range(epochs):
    if stop: break
    stop = self.callbacks.on_epoch_start()
    if stop: break
    for data in dataloader:
        stop = self.callbacks.on_batch_start()
        if stop: break
        self.optimizer.zero_grad()
        self.batch_metrics = self.compute_metrics(data, self.metrics)
        self.batch_loss = self.batch_metrics['loss']
        self.batch_loss.backward()
        # gradients are now populated; before_step fires before optimizer.step()
        stop = self.callbacks.before_step()
        if stop: break
        self.optimizer.step()
        stop = self.callbacks.on_batch_end()
        if stop: break
    else:
        stop = self.callbacks.on_epoch_end()
self.callbacks.on_fit_end()

meaning it calls methods on callbacks (a collection of callbacks) at different points in the loop. Each callback has a model attribute that refers to the DeepHit model, so from inside a callback you can use, e.g., self.model.batch_loss to access the batch_loss above.
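For instance, here is a minimal sketch of a callback that records the loss of every batch (the name BatchLossLogger is just illustrative):

import torchtuples as tt

class BatchLossLogger(tt.cb.Callback):
    def on_fit_start(self) -> None:
        self.losses = []

    def on_batch_end(self) -> None:
        # self.model is the DeepHit model, and batch_loss was set in the loop above
        self.losses.append(self.model.batch_loss.item())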

You can create your own callback and implement the methods at the points where you want it to be called. So, if you want the gradients at before_step, you can adapt the following example:

import torch
import torchtuples as tt

class Gradients(tt.cb.Callback):
    def before_step(self) -> None:
        # flatten and concatenate the gradient of every parameter in the net
        grads = []
        for param in self.model.net.parameters():
            grads.append(param.grad.view(-1))
        self.grads = torch.cat(grads)
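Here before_step is the natural hook: it runs after backward() has populated param.grad but before optimizer.step() applies the update, so the captured values are exactly the gradients used for that step.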

To test this, you can do something like the following (based on the deephit.ipynb example notebook):

# net, labtrans, x_train, y_train, and batch_size are defined in the notebook
model = DeepHitSingle(net, tt.optim.Adam, alpha=0.2, sigma=0.1, duration_index=labtrans.cuts)
gradients = Gradients()
callbacks = [gradients]
log = model.fit(x_train, y_train, batch_size, epochs=1, callbacks=callbacks)
print(gradients.grads)
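As written, grads only holds the gradients from the last processed batch. If you want the full history, a small variation (GradientHistory is just an illustrative name) is to append a detached copy for every batch:

class GradientHistory(tt.cb.Callback):
    def on_fit_start(self) -> None:
        self.grads = []

    def before_step(self) -> None:
        # store a detached copy so later zero_grad()/step() calls cannot touch it
        flat = torch.cat([p.grad.detach().clone().view(-1)
                          for p in self.model.net.parameters()])
        self.grads.append(flat)
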
havakv commented 2 years ago

For the second part of your question, "how I can calculate the loss for a given input?", the simplest way is to call

model.score_in_batches(x_train, y_train)

Note that the default behaviour is to set the net in eval mode, meaning dropout is deterministic. Also, gradients are not computed with this call.
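If you do need the loss with gradients attached, you can mimic one step of the training loop above. This is only a sketch: it reuses make_dataloader and compute_metrics, the torchtuples internals that fit uses, so treat the exact calls as assumptions, and batch_size=256 is arbitrary:

dataloader = model.make_dataloader((x_train, y_train), batch_size=256, shuffle=False)
model.net.train()  # train mode, unlike score_in_batches
data = next(iter(dataloader))  # just the first batch
loss = model.compute_metrics(data, model.metrics)['loss']
loss.backward()  # gradients are now available on model.net.parameters()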

ShadiRahimian commented 2 years ago

@havakv thanks a lot for the detailed and clear explanation :) works perfectly now!