havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Time to event prediction using Graph Neural Networks #92

Open paulamartingonzalez opened 2 years ago

paulamartingonzalez commented 2 years ago

Hi! I am trying to extend tutorials 3 and 4 to use the Logistic Hazard loss in a Graph Neural Network for graph-level survival prediction. This is the example I am following.

Nevertheless, I am getting errors when training the network, specifically in model.fit_dataloader, where I get:

AttributeError                            Traceback (most recent call last)
<ipython-input-66-7ae9e531dd45> in <module>()
      2 epochs = 50
      3 verbose = True
----> 4 log = model.fit_dataloader(dl_train, epochs, callbacks, verbose, val_dataloader=dl_test)

13 frames
/usr/local/lib/python3.7/dist-packages/torchtuples/tupletree.py in shapes_of(data)
    109 def shapes_of(data):
    110     """Apply x.shape to elemnts in data."""
--> 111     return data.shape
    112 
    113 @apply_leaf

AttributeError: 'DGLHeteroGraph' object has no attribute 'shape'

Do you have any suggestions on how to fix this error and/or approach this issue?

Thanks!

havakv commented 2 years ago

Hi! I'm not that familiar with Graph Neural Networks, so I'm not really sure what the issue with the dataloader is. Could you post the full example so I can run it myself and reproduce the error?

That said, the API's dependence on torchtuples makes it hard to use for complicated work, so we've been working on a new version that makes it simpler to use plain PyTorch. If you're interested in trying that instead, check out the refactor_out_torchtuples branch and look at the Logistic Hazard example there: https://github.com/havakv/pycox/blob/refactor_out_torchtuples/examples/torch_logistic_hazard.py

havakv commented 2 years ago

After looking into the Graph classification example you linked to, I'm guessing you have defined a collate_fn like the one in the pycox tutorial 4:

def collate_fn(batch):
    """Stacks the entries of a nested tuple"""
    return tt.tuplefy(batch).stack()

that gives you the error. If that is the case, this is because each batch from the train_dataloader consists of a dgl.heterograph.DGLHeteroGraph. The pycox tutorial was written expecting the input to be torch.Tensor (or a list/tuple of such).

So what I think you need to do is write a dataloader for which next(iter(train_dataloader)) returns a tuple of types (dgl.heterograph.DGLHeteroGraph, (torch.Tensor, torch.Tensor)) containing the relevant data.

paulamartingonzalez commented 2 years ago

Hi @havakv , many thanks for your help!

I couldn't share the above code so I decided to switch to public data and make it easier for us to debug 😺

I found that pytorch geometric (another library for graph learning) has the MNIST dataset and I tried to expand your tutorial as follows. I get an error on the last line (i.e. model = LogisticHazard(GCN, tt.optim.Adam(0.01), duration_index=labtrans.cuts)) that says:

TypeError                                 Traceback (most recent call last)
<ipython-input-27-c2d6a04a7324> in <module>()
----> 1 model = LogisticHazard(GCN, tt.optim.Adam(0.01), duration_index=labtrans.cuts)

2 frames
/usr/local/lib/python3.7/dist-packages/torchtuples/base.py in optimizer(self, optimizer)
    106         self._optimizer = optimizer
    107         if callable(self._optimizer):
--> 108             self._optimizer = self._optimizer(params=self.net.parameters())
    109         if not isinstance(self._optimizer, OptimWrap):
    110             self._optimizer = OptimWrap(self._optimizer)

TypeError: parameters() missing 1 required positional argument: 'self'

Do you have any hint of what may be happening? I am happy to help debugging the graph learning part if it helps in any way. Many thanks in advance!


code:

import numpy as np
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader 

# MNIST is part of torchvision
from torchvision import datasets, transforms

import torchtuples as tt
from pycox.models import LogisticHazard
from pycox.utils import kaplan_meier
from pycox.evaluation import EvalSurv
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np 
import os
from tqdm import tqdm
import pandas as pd
from scipy.spatial.distance import squareform
from scipy.spatial.distance import pdist
import time
from scipy.sparse import csr_matrix
import torch 
from torch_geometric.datasets import TUDataset,MNISTSuperpixels

mnist_train = MNISTSuperpixels(root='data/MNISTSuperpixels', train=True)
mnist_test = MNISTSuperpixels(root='data/MNISTSuperpixels', train=False)

def sim_event_times(mnist, max_time=700):
    digits = []
    for i in range(len(mnist)):
      digits.append(mnist.get(i).y.numpy()[0])
    digits = np.asarray(digits)
    betas = 365 * np.exp(-0.6 * digits) / np.log(1.2)
    event_times = np.random.exponential(betas)
    censored = event_times > max_time
    event_times[censored] = max_time
    return tt.tuplefy(event_times, ~censored)

sim_train = sim_event_times(mnist_train)
sim_test = sim_event_times(mnist_test)

labtrans = LogisticHazard.label_transform(20)
target_train = labtrans.fit_transform(*sim_train)
target_test = labtrans.transform(*sim_test)

class MnistSimDatasetSingle(Dataset):
    """Simulatied data from MNIST. Read a single entry at a time.
    """
    def __init__(self, mnist_dataset, time, event, root, test = False, transform=None, pre_transform=None):
        self.mnist_dataset = mnist_dataset
        self.time, self.event = tt.tuplefy(time, event).to_tensor()
        self.test = test
        self.root = root
        #self.num_node_features = mnist_dataset.num_node_features

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, index):

        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        if self.test:
            data = torch.load(os.path.join(self.processed_dir, 
                                 f'test.pt'))
        else:
            data = torch.load(os.path.join(self.processed_dir, 
                                 f'training.pt'))   
        return data

dataset_train = MnistSimDatasetSingle(mnist_train, *target_train,root='data/MNISTSuperpixels', test=False)
dataset_test = MnistSimDatasetSingle(mnist_test, *target_test,root='data/MNISTSuperpixels', test=True)
batch_size = 128

def collate_fn(batch):
    """Stacks the entries of a nested tuple"""
    return tt.tuplefy(batch).stack()

dl_train = DataLoader(dataset_train, batch_size, shuffle=True, collate_fn=collate_fn)
dl_test = DataLoader(dataset_test, batch_size, shuffle=False, collate_fn=collate_fn)

from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

model = GCN(hidden_channels=64)
print(model)
model = LogisticHazard(GCN, tt.optim.Adam(0.01), duration_index=labtrans.cuts)

havakv commented 2 years ago

Hi again! Happy to help!

So, first of all, make sure you clear your environment and rerun the code before you post it here, so you know it produces the same error. Here you have a dataset variable that is not defined, which means that self.conv1 = GCNConv(dataset.num_node_features, hidden_channels) and self.lin = Linear(hidden_channels, dataset.num_classes) don't work.

The error you get, TypeError: parameters() missing 1 required positional argument: 'self', is just because you pass the class GCN to LogisticHazard rather than a GCN object (an instance of the class).

Try this instead:

net = GCN(hidden_channels=64)
print(net)
model = LogisticHazard(net, tt.optim.Adam(0.01), duration_index=labtrans.cuts)
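
To make the class-versus-instance point concrete, here is a minimal sketch in plain Python. TinyNet and build_optimizer are hypothetical stand-ins (not pycox or torchtuples code); the only assumption is that the model wrapper calls net.parameters(), as the traceback shows.

```python
class TinyNet:
    """Hypothetical stand-in for a torch.nn.Module subclass such as GCN."""

    def __init__(self, hidden_channels):
        self.hidden_channels = hidden_channels

    def parameters(self):
        # A real module would yield its weight tensors here.
        return [self.hidden_channels]


def build_optimizer(net):
    """Mimics the failing call from the traceback: net.parameters()."""
    return list(net.parameters())


# Passing the class itself: parameters() has no instance to bind to.
try:
    build_optimizer(TinyNet)
    error_message = None
except TypeError as err:
    error_message = str(err)

# Passing an instance works.
params = build_optimizer(TinyNet(hidden_channels=64))
```

The first call fails with the same kind of TypeError as in the traceback, while the second one succeeds because parameters() is bound to an object.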

paulamartingonzalez commented 2 years ago

Thanks again @havakv ! I have added here the python notebook with the cleaned code that I am using. I am running these trials in Google Colab. Does it work for you now?

I fixed the issues above, but now the session crashes without an apparent error while executing the last line (log = model.fit_dataloader(dl_train, epochs, callbacks, verbose, val_dataloader=dl_test)). Does the code look good to you? I am afraid that the dimensions or something else may be wrong. Any other suggestions are welcome!

havakv commented 2 years ago

Hmm, to me it looks like you're loading the whole graph dataset in MnistSimDatasetSingle.__getitem__, which should instead return a single part of the dataset. I.e., dataset_train[index] should give you a single element, but now you get the whole thing. MnistSimDatasetSingle.__getitem__ should also return the relevant time and event (which it currently doesn't).

So my suggestion for moving forward is to check each subpart separately to get a better understanding of where the errors are. I would suggest you do the following:

  1. Make sure you can call dataset_train[index] for some index such as 0, 1, or 5, and that it returns a tuple of size two. For (x, y) = dataset_train[0], you should get a y that contains your (time, event) and an x that contains the data you want to pass to GCN.forward.
  2. Fix the collate_fn such that it works the way you want. You should be able to call
    batch = [dataset_train[0], dataset_train[1]]
    collate_fn(batch)

    and it should return something reasonable.

  3. Make your DataLoader and test that it works with batch = next(iter(dl_train)).
  4. Test your net with
    net = GCN(hidden_channels=10)
    batch = next(iter(dl_train))
    x, y = batch
    net(*x)
  5. Try to fit the model with the LogisticHazard.
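
As a rough illustration of step 1, a dataset along these lines returns one (x, (time, event)) pair per index. This is only a sketch: GraphSurvDataset is a hypothetical name, and plain dicts stand in for the real DGL or PyG graph objects.

```python
import torch


class GraphSurvDataset:
    """Map-style dataset: one (x, (time, event)) pair per index.

    `graphs` is any sequence of per-sample graph objects; plain dicts are
    used below as hypothetical placeholders for DGL or PyG graphs.
    """

    def __init__(self, graphs, time, event):
        self.graphs = graphs
        self.time = torch.as_tensor(time)
        self.event = torch.as_tensor(event)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, index):
        # Return a single sample, never the whole dataset.
        return self.graphs[index], (self.time[index], self.event[index])


graphs = [{"num_nodes": n} for n in (3, 5, 4)]
dataset_train = GraphSurvDataset(graphs, time=[2, 0, 1], event=[1.0, 0.0, 1.0])

# Step 1: indexing returns a tuple of size two.
x, y = dataset_train[0]
time, event = y
```

Once this check passes for a few indices, the same object can be handed to a DataLoader together with a collate_fn that knows how to batch the graph type.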

paulamartingonzalez commented 2 years ago

Thanks! I have updated the Dataloader accordingly to return one slice at a time and a tuple with (graph data object,(time, event)).

I am having issues now with the collate function. I am using the same one that you had in the tutorial as I think we want to stack the tuples but it gives me:

 AttributeError                            Traceback (most recent call last)
<ipython-input-80-43870e61dafc> in <module>()
      1 batch = [dataset_train[0], dataset_train[1]]
----> 2 collate_fn(batch)

9 frames
/usr/local/lib/python3.7/dist-packages/torchtuples/tupletree.py in shapes_of(data)
    109 def shapes_of(data):
    110     """Apply x.shape to elemnts in data."""
--> 111     return data.shape
    112 
    113 @apply_leaf

AttributeError: 'Data' object has no attribute 'shape'

Any other hint? Apologies as I am not very familiar using tuples in this setting!

havakv commented 2 years ago

No need to apologise. The reason it doesn't work is that your Dataset doesn't return tuples of torch.Tensor objects. This is likely because you are working with a graph dataset. The job of the collate_fn is to take a list of data and combine it. Typically you have a list batch = [a_1, a_2, a_3] where each a_i contains your input_i and target_i. Your collate_fn should then return a single tuple (input, target). I'll illustrate with an example where input_i and target_i are numpy arrays:

import numpy as np
import torch

input_i = np.arange(5)
target_i = np.array([0])

# each a_i represents what you would get with dataset_train[i]
a_1 = (input_i, target_i)
a_2 = (input_i + 2, target_i + 2)
a_3 = (input_i + 3, target_i + 3)

# the batch represents what your dataloader passes to the collate_fn if batch_size=3
batch = [a_1, a_2, a_3]

input = np.stack([a[0] for a in batch])
target = np.stack([a[1] for a in batch])

# your collate_fn can then return something along the lines of
(torch.from_numpy(input), torch.from_numpy(target))

In your case, you don't have numpy arrays or torch tensors, but some graph object. This complicates things, as you can't use any of the standard collate_fn implementations. You need to figure out how to combine several of these graph objects into a batch called input, and you can then create a collate_fn based on that.
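
A minimal sketch of what such a collate_fn could look like, assuming each sample is (graph, (time, event)). The batch_graphs helper is a hypothetical placeholder for whatever batching routine the graph library provides; here it simply keeps the graphs in a list.

```python
import torch
from torch.utils.data import DataLoader


def batch_graphs(graphs):
    # Hypothetical stand-in for the library's own batching routine
    # (for example dgl.batch for DGL graphs); here the graphs stay in a list.
    return list(graphs)


def collate_fn(samples):
    """Combine [(graph, (time, event)), ...] into (input, (time, event))."""
    graphs, targets = zip(*samples)
    times, events = zip(*targets)
    return batch_graphs(graphs), (torch.stack(times), torch.stack(events))


# Toy samples: dicts as placeholder graphs, 0-dim tensors as targets.
samples = [
    ({"num_nodes": n}, (torch.tensor(t), torch.tensor(e)))
    for n, t, e in [(3, 2, 1.0), (5, 0, 0.0), (4, 1, 1.0), (6, 3, 0.0)]
]

loader = DataLoader(samples, batch_size=3, shuffle=False, collate_fn=collate_fn)
batches = list(loader)
x, (time, event) = batches[0]
```

A quick sanity check is that the number of graphs and the number of targets in each batch both follow batch_size when you change it.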

paulamartingonzalez commented 2 years ago

Hi @havakv! Sorry for my radio silence, I was looking for the best way to proceed with this! I decided to go back to the DGL graph library since I feel I have more control over the datasets and the collate function. I generated some random data to be able to share this with you. You can find the notebook here.

I think now the data loaders work as expected with the custom collate function. I've indeed done the test of the GNN as you indicated above and it works!

Nevertheless, when I try to fit the LogisticHazard model, it raises some errors in the last two cells:

with model = LogisticHazard(net, tt.optim.Adam(0.01), duration_index=labtrans.cuts)

pred = model.predict(batch[0]) gives:

/usr/local/lib/python3.7/dist-packages/torchtuples/base.py in _predict_func(self, func, input, batch_size, numpy, eval_, grads, to_cpu, num_workers, is_dataloader, **kwargs)
    441         else:
    442             raise ValueError("Did not recognize data type. You can set "is_dataloader to Ture" +
--> 443                              + " or False to force usage.")
    444 
    445         to_cpu = numpy or to_cpu

TypeError: bad operand type for unary +: 'str'

and log = model.fit_dataloader(data_loader, epochs, callbacks, verbose, val_dataloader=test_data_loader) gives

/usr/local/lib/python3.7/dist-packages/torchtuples/tupletree.py in to_device(data, device)
    287     """
    288     if type(data) is not torch.Tensor:
--> 289         raise RuntimeError(f"Need 'data' to be tensors, not {type(data)}.")
    290     return data.to(device)
    291 

RuntimeError: Need 'data' to be tensors, not <class 'dgl.heterograph.DGLHeteroGraph'>.

Do you have any hint or thought into how to proceed? Thanks in advance!

havakv commented 2 years ago

Hi again! No need to apologise, these things take some time.

First of all, you've definitely found a bug in torchtuples, so thank you for that! The fix is in https://github.com/havakv/torchtuples/pull/22, so you need to reinstall or update torchtuples to the newest version.

So, in your code, it doesn't look like your collate function works correctly. If you try different batch sizes in your dataloader, you will always get the same size out of it. By changing to the following, you'll get the correct number of labels:

def collate(samples):
    graphs, labels = map(list, zip(*samples)) # you could just do zip(*samples)
    batched_graph = dgl.batch(graphs)
    return batched_graph, tt.tuplefy(labels).stack()

Next, given the way you build your net, everything will be simpler if you use this instead:

    def forward(self, g):
        in_feat = g.ndata['h_n'].float()
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')

because you can now call

net = GCN(4, 16, 1)
batch = next(iter(data_loader))
x, y = batch
net(x)

If you don't like this and want to keep your old call signature of forward(g, in_feat), you need to change your dataloader to output an x = (g, in_feat).

Finally, your network has the wrong number of outputs (it only has one). You can fix this with

net = GCN(4, 16, labtrans.out_features)
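
To see why a single output is not enough, here is a rough numpy sketch of how a discrete-time hazard model in general turns network outputs into a survival curve: one logit per time interval, a sigmoid to get the hazard, and a cumulative product of (1 - hazard). The function name and the random logits are made up for illustration, and this is not pycox's own code.

```python
import numpy as np


def hazard_logits_to_survival(logits):
    """Map logits of shape (batch, num_intervals) to survival curves."""
    hazards = 1.0 / (1.0 + np.exp(-logits))  # sigmoid per interval
    return np.cumprod(1.0 - hazards, axis=1)  # S(t_j) = prod_{k<=j} (1 - h_k)


num_intervals = 20  # matches label_transform(20) in the thread
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, num_intervals))  # made-up network outputs
surv = hazard_logits_to_survival(logits)
```

Each row of surv is a non-increasing curve with one value per interval, which is why the last layer needs as many outputs as there are intervals.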

If you implement these changes, you should be able to train your networks!

Note, however, that model.predict(x) still won't work, because you have quite a complicated dataloader and collate function. The way to approach this is to use a dataloader that outputs only x (let's call it dataloader_x) and call model.predict(dataloader_x).
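
One possible way to build such a dataloader_x is to wrap the existing dataset so that it drops the target. This is a sketch only: InputOnly and collate_x are hypothetical names, and dicts again stand in for the real graphs.

```python
from torch.utils.data import DataLoader


class InputOnly:
    """Wraps a dataset of (x, y) pairs and exposes only x."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        x, _ = self.dataset[index]
        return x


def collate_x(graphs):
    # Hypothetical placeholder: with DGL this would be dgl.batch(graphs).
    return list(graphs)


# Toy (x, y) pairs with dicts standing in for graphs.
pairs = [({"num_nodes": n}, (n % 3, 1.0)) for n in range(5)]
dataloader_x = DataLoader(InputOnly(pairs), batch_size=2, shuffle=False,
                          collate_fn=collate_x)
batches = list(dataloader_x)
```

Keeping shuffle=False here matters, since otherwise the predictions would no longer line up with the order of the targets used for evaluation.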

paulamartingonzalez commented 2 years ago

Dear @havakv many thanks again for your help! The training works now perfectly :) I think we're almost there!

I am trying to do the prediction and evaluation on the test set as you mentioned above, with a data loader that only has the inputs, but it triggers the following error in torchtuples:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-31-33c6fb4815aa> in <module>()
----> 1 model.predict(test_data_loader_x)

3 frames
/usr/local/lib/python3.7/dist-packages/torchtuples/base.py in _predict_func_dl(self, func, dataloader, numpy, eval_, grads, to_cpu)
    471             if data is not None:
    472                 input = tuplefy(data)
--> 473                 input_train = self.fit_info["input"]
    474                 if input.to_levels() != input_train["levels"]:
    475                     warnings.warn(

KeyError: 'input'

I've added all the new code in the notebook in case you want to replicate this yourself! Do you see something odd? Thanks!

havakv commented 2 years ago

This looks like another bug in torchtuples! Should be fixed in https://github.com/havakv/torchtuples/pull/23. I'll deploy the changes, so you need to reinstall and make sure you have version 0.2.2 of torchtuples. Then your notebook should run fine.

paulamartingonzalez commented 2 years ago

This is amazing, thanks again @havakv! I confirm everything works smoothly now :)

I will be implementing this with my own data, so I might reopen this or open a new issue if I run into new problems! The notebook is accessible here if someone else wants to do something similar!

paulamartingonzalez commented 2 years ago

Hi @havakv ! I have a small question again 😸 I am trying to implement the continuous version so that I have one less hyperparameter to tune when moving to my analysis, and unfortunately I am getting an error when trying to run it.

I get:

/usr/local/lib/python3.7/dist-packages/torchtuples/base.py in fit_dataloader(self, dataloader, epochs, callbacks, verbose, metrics, val_dataloader)
    234                     break
    235                 self.optimizer.zero_grad()
--> 236                 self.batch_metrics = self.compute_metrics(data, self.metrics)
    237                 self.batch_loss = self.batch_metrics["loss"]
    238                 self.batch_loss.backward()

/usr/local/lib/python3.7/dist-packages/pycox/models/cox_cc.py in compute_metrics(self, input, metrics)
     48         batch_size = input.lens().flatten().get_if_all_equal()
     49         if batch_size is None:
---> 50            raise RuntimeError("All elements in input does not have the same length.")
     51         case, control = input # both are TupleTree
     52         input_all = tt.TupleTree((case,) + control).cat()

RuntimeError: All elements in input does not have the same length.

when executing the line log = model.fit_dataloader(data_loader,epochs, callbacks, verbose, val_dataloader=test_data_loader)

In the tutorial you use the model.fit method so maybe that is the issue? It does appear that there is a fit_dataloader method but it doesn't accept any batch_size input I think while that is triggering the error. Or perhaps there is an issue in my dataloader?

I've also changed the way I create the target, as described in the Cox-Time tutorial. I've added the analysis here in case you want to reproduce the error.

havakv commented 2 years ago

Hi again @paulamartingonzalez . So the CoxTime model has a somewhat complicated dataloader, as it requires some ordering and sampling. If you want to get it done, you would have to start by understanding the current dataloader https://github.com/havakv/pycox/blob/master/pycox/models/data.py#L81 and how the CoxTime model uses it. However, I'm not sure I would advise you to go down that road. I'm sure it is possible, but it is much more complicated than any of the other models in pycox.

paulamartingonzalez commented 2 years ago

Thanks @havakv ! That makes a lot of sense, I'll try with the discrete method for now 😄 Is there any rule of thumb for discretising the data? Or should I treat this as an additional hyperparameter in the models? Thanks for everything!

havakv commented 2 years ago

There isn't really any rule of thumb that I know of, so you should just treat it as a hyperparameter, as you suggest. Happy to help!
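
For a feel of what this hyperparameter controls, here is a rough numpy sketch that builds equidistant grids of different sizes over simulated durations. The discretise helper and the simulated data are assumptions for illustration only, and this is not pycox's label_transform.

```python
import numpy as np


def discretise(durations, num_intervals):
    """Equidistant grid on [0, max duration] and each duration's grid index."""
    cuts = np.linspace(0.0, durations.max(), num_intervals)
    # Index of the first cut point that is >= the duration.
    idx = np.searchsorted(cuts, durations, side="left")
    return cuts, idx


rng = np.random.default_rng(0)
durations = rng.exponential(scale=200.0, size=1000).clip(max=700.0)

# Candidate grid sizes to compare on a validation set.
grids = {m: discretise(durations, m) for m in (10, 20, 50)}
```

A coarse grid loses time resolution, while a fine grid leaves few events per interval, so in practice you would fit one model per candidate size and compare them on validation data.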

paulamartingonzalez commented 2 years ago

Hi @havakv ! I am back with this project and I was wondering what's the best way to save a trained model (the "model" variable in the notebook). I've tried pickle but it doesn't work, thanks in advance!

havakv commented 2 years ago

Hi. The simplest way to save a network is to use model.save_net('mynet.pt'). To load the net, you can either use the model.load_net('mynet.pt') method, or simply specify it in the constructor, e.g. model = LogisticHazard('mynet.pt').

This should at least work for regular PyTorch. Alternatively, you can look at the model.save_model_weights and model.load_model_weights methods, which only save/load the weights of the network (and not the actual network architecture).
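
For comparison, the weights-only route looks roughly like this in plain PyTorch. This is a sketch using state_dict with a made-up architecture and an in-memory buffer instead of a file; whether the pycox methods do exactly this internally is an assumption.

```python
import io

import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 20))

# Save only the weights; an in-memory buffer stands in for a file path.
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

# Loading weights requires rebuilding the same architecture first.
restored = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 20))
restored.load_state_dict(torch.load(buffer))

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(net(x), restored(x))
```

The trade-off is that the code defining the network (for example the GCN class) must be available when loading, but the saved file does not depend on pickling the class itself.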