havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

How to feed the model multiple inputs? #109

Open kyuchoi opened 2 years ago

kyuchoi commented 2 years ago

Hi Havakv, thank you so much for your great work! I'd like to know how to feed the pycox model multiple inputs. Specifically, I want to use image and gene data, x1 and x2, of each subject as multiple inputs to one neural network model. I guess this is possible using torchtuples; however, I could not find appropriate code snippets in your GitHub repositories.

Specifically, in https://github.com/havakv/pycox/blob/master/examples/04_mnist_dataloaders_cnn.ipynb,

when I change dataset_train = MnistSimDatasetSingle(mnist_train, *target_train) into dataset_train = MnistSimDatasetSingle(*mnist_train, *target_train), where mnist_train = (img, x) and x = [x1, x2] is a vector, I get this error: TypeError: forward() takes 2 positional arguments but 3 were given


More specifically, I am asking about the case where both the input and the output are tuples, as follows: (img, x), (duration, event)

I guess it should be possible to pass the input as a list, not just a single tensor (x0), considering the torchtuples example below.

(image: torchtuples example)

Could you please give us a short example, like two or three lines of dummy code? Thank you so much! Best, Kyu

havakv commented 2 years ago

Hi, and thank you for the kind words!

I'm not really sure which part of the code your error is related to, so I won't be able to debug your example unless you post your actual code. To me it looks like you were able to get the dataloader working and might have some issues with the network. I'll post all the parts anyway.

In this example I'll just use the MNIST images twice, so that our input is (img, img2) (which doesn't make any sense, but serves as an example). We make a dataset that contains both images and returns the (input, output) tuple ((img, img2), (time, event)). If you have an image input and a vector input, you should alter the example accordingly.

import torchtuples as tt
from torch.utils.data import Dataset, DataLoader

class MnistSimDatasetSingle2(Dataset):
    """Simulated data from MNIST. Read a single entry at a time.
    """
    def __init__(self, mnist_dataset, mnist_dataset2, time, event):
        self.mnist_dataset = mnist_dataset
        self.mnist_dataset2 = mnist_dataset2
        self.time, self.event = tt.tuplefy(time, event).to_tensor()

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, index):
        if type(index) is not int:
            raise ValueError(f"Need `index` to be `int`. Got {type(index)}.")
        img = self.mnist_dataset[index][0]
        img2 = self.mnist_dataset2[index][0]
        return (img, img2), (self.time[index], self.event[index])
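
For the image-plus-vector case from the question, the same pattern might look like the sketch below. The class name `ImageVectorSurvDataset` and the random placeholder tensors are assumptions made for illustration only; they are not part of pycox or of the notebook, and the sketch uses plain PyTorch so it runs on its own.

```python
import torch
from torch.utils.data import Dataset


class ImageVectorSurvDataset(Dataset):
    """Hypothetical dataset returning ((img, x), (time, event)) per index."""

    def __init__(self, images, covariates, time, event):
        # All four arguments are assumed to be tensors with the same first dimension.
        self.images = images
        self.covariates = covariates
        self.time = time
        self.event = event

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if not isinstance(index, int):
            raise ValueError(f"Need `index` to be `int`. Got {type(index)}.")
        img = self.images[index]
        x = self.covariates[index]
        return (img, x), (self.time[index], self.event[index])


# Random placeholder data: 8 subjects, 1x28x28 images, 2 covariates (x1, x2).
images = torch.randn(8, 1, 28, 28)
covariates = torch.randn(8, 2)
time = torch.randint(0, 10, (8,))
event = torch.randint(0, 2, (8,)).float()

dataset = ImageVectorSurvDataset(images, covariates, time, event)
(img, x), (t, e) = dataset[0]
```

Each item keeps the nested (input, target) structure, so the input part is a tuple with one entry per network argument.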

We can use the collate function from before, but I'll just post it here for completeness

def collate_fn(batch):
    """Stacks the entries of a nested tuple"""
    return tt.tuplefy(batch).stack()
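
To make the effect of such a collate function concrete, here is a plain-PyTorch sketch of the same idea: turning a list of nested samples into one nested tuple of batched tensors. The helper `stack_nested` is hypothetical and is not the torchtuples implementation; it only illustrates the shapes one would expect for samples structured as ((img, x), (time, event)).

```python
import torch


def stack_nested(samples):
    """Stack a list of identically structured nested tuples of tensors."""
    first = samples[0]
    if isinstance(first, torch.Tensor):
        # Leaf level: add a new batch dimension in front.
        return torch.stack(samples)
    # Regroup the samples position by position, then recurse into each position.
    return tuple(stack_nested(list(group)) for group in zip(*samples))


# Four fake samples shaped like ((img, x), (time, event)).
samples = [
    ((torch.randn(1, 28, 28), torch.randn(2)), (torch.tensor(3), torch.tensor(1.0)))
    for _ in range(4)
]
(imgs, xs), (times, events) = stack_nested(samples)
```

After stacking, every leaf has gained a leading batch dimension while the tuple nesting is unchanged.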

We can now make a dataset and dataloader

dataset_train = MnistSimDatasetSingle2(mnist_train, mnist_train, *target_train)
dl_train = DataLoader(dataset_train, batch_size, shuffle=True, collate_fn=collate_fn)

The network in this example takes the two images as input, and I've just added them together (again a stupid example, but illustrative).

import torch
import torch.nn.functional as F
from torch import nn

class Net(nn.Module):
    def __init__(self, out_features):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, 1)
        self.max_pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 16, 5, 1)
        self.glob_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, out_features)

    def forward(self, img, img2):
        x = img + img2 # just to do something
        x = F.relu(self.conv1(x))
        x = self.max_pool(x)
        x = F.relu(self.conv2(x))
        x = self.glob_avg_pool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
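
For an image input combined with a vector input, one possible adaptation is to run the image through the convolutional layers and concatenate the flattened features with the vector before the fully connected layers. The sketch below is only an illustration: the class name `ImageVectorNet`, the parameter `n_covariates`, and the choice of concatenation are assumptions, not something prescribed by pycox.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ImageVectorNet(nn.Module):
    """Hypothetical network taking an image batch and a covariate batch."""

    def __init__(self, n_covariates, out_features):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, 1)
        self.max_pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 16, 5, 1)
        self.glob_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # 16 image features plus the covariates feed the first linear layer.
        self.fc1 = nn.Linear(16 + n_covariates, 16)
        self.fc2 = nn.Linear(16, out_features)

    def forward(self, img, x):
        h = F.relu(self.conv1(img))
        h = self.max_pool(h)
        h = F.relu(self.conv2(h))
        h = self.glob_avg_pool(h)
        h = torch.flatten(h, 1)
        h = torch.cat([h, x], dim=1)  # join image features and covariates
        h = F.relu(self.fc1(h))
        return self.fc2(h)


net = ImageVectorNet(n_covariates=2, out_features=10)
batch_input = (torch.randn(4, 1, 28, 28), torch.randn(4, 2))
out = net(*batch_input)  # the input tuple is unpacked into forward's arguments
```

The important part is that `forward` takes one argument per element of the input tuple. A `forward(self, x)` that accepts a single tensor would, when given a two-element input tuple in this way, likely fail with the kind of TypeError reported in the question.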

We make a simple LogisticHazard model

net = Net(labtrans.out_features)
model = LogisticHazard(net, tt.optim.Adam(0.01), duration_index=labtrans.cuts)

and we test that it works with a batch from our dataloader

batch = next(iter(dl_train))
input, target = batch
pred = model.predict(input)

Does this answer your question?