havakv / pycox

Survival analysis with PyTorch
BSD 2-Clause "Simplified" License

Use clinical data in addition to image dataset for survival estimates #135

Open · mahootiha-maryam opened this issue 2 years ago

mahootiha-maryam commented 2 years ago

Hi Havard. If I want to add age to the image when predicting survival times, what should I do? First, in the class below, should I pass an age array to `__init__`, index it in `__getitem__`, and return it together with `img`? Should I use `tt.tuplefy` for them?

```python
import torchtuples as tt
from torch.utils.data import Dataset


class MnistSimDatasetSingle(Dataset):
    """Simulated data from MNIST. Read a single entry at a time."""

    def __init__(self, mnist_dataset, time, event):
        self.mnist_dataset = mnist_dataset
        # Convert the survival targets (durations and event indicators) to tensors
        self.time, self.event = tt.tuplefy(time, event).to_tensor()

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, index):
        if type(index) is not int:
            raise ValueError(f"Need `index` to be `int`. Got {type(index)}.")
        img = self.mnist_dataset[index][0]
        return img, (self.time[index], self.event[index])
```
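
A minimal sketch of the first idea in the question (pass an age array to `__init__`, index it in `__getitem__`, and return `(img, age)` as the input part of each sample), assuming `age` is a 1-D array aligned by index with the images. `ImageAgeDataset` is a hypothetical name, and the sketch uses `torch.as_tensor` instead of `tt.tuplefy(...).to_tensor()` only to avoid a torchtuples dependency here.

```python
import torch
from torch.utils.data import Dataset


class ImageAgeDataset(Dataset):
    """Hypothetical variant of MnistSimDatasetSingle that also returns age.

    Each sample is ((img, age), (time, event)): the first pair is the
    network input, the second pair is the survival target.
    """

    def __init__(self, image_dataset, age, time, event):
        if not (len(image_dataset) == len(age) == len(time) == len(event)):
            raise ValueError("image_dataset, age, time and event must have equal length")
        self.image_dataset = image_dataset
        # Plain torch conversion (assumption: tt.tuplefy(age, time, event).to_tensor()
        # as in the original class should serve the same purpose).
        self.age = torch.as_tensor(age, dtype=torch.float32)
        self.time = torch.as_tensor(time, dtype=torch.float32)
        self.event = torch.as_tensor(event, dtype=torch.float32)

    def __len__(self):
        return len(self.image_dataset)

    def __getitem__(self, index):
        if type(index) is not int:
            raise ValueError(f"Need `index` to be `int`. Got {type(index)}.")
        img = self.image_dataset[index][0]
        # Input pair first, survival target pair second
        return (img, self.age[index]), (self.time[index], self.event[index])
```

With PyTorch's default `collate_fn`, a `DataLoader` over this dataset yields nested lists, `[[imgs, ages], [times, events]]`, where `imgs` has shape `(batch, 1, 28, 28)` for MNIST-sized images and `ages` has shape `(batch,)`.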

In this part, should I make a tuple of image and age and use it instead of `mnist_train`?

```python
# Unpack the (time, event) targets into the two constructor arguments
dataset_train = MnistSimDatasetBatch(mnist_train, *target_train)
dataset_test = MnistSimDatasetBatch(mnist_test, *target_test)
```
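
The alternative raised here, wrapping images and ages together so that the wrapper replaces `mnist_train`, can be sketched as follows. It assumes the base dataset returns `(img, label)` pairs, as torchvision's `MNIST` does, and that `age` is aligned with it by index; `WithAge` is a hypothetical name.

```python
import torch
from torch.utils.data import Dataset


class WithAge(Dataset):
    """Hypothetical wrapper: pairs each image with its age, so that
    wrapper[i][0] is the tuple (img, age) instead of img alone."""

    def __init__(self, image_dataset, age):
        if len(image_dataset) != len(age):
            raise ValueError("image_dataset and age must have the same length")
        self.image_dataset = image_dataset
        self.age = torch.as_tensor(age, dtype=torch.float32)

    def __len__(self):
        return len(self.image_dataset)

    def __getitem__(self, index):
        img, label = self.image_dataset[index]
        # Keep the original label in position 1 so the wrapper is a drop-in replacement
        return (img, self.age[index]), label
```

Because `MnistSimDatasetSingle.__getitem__` takes `self.mnist_dataset[index][0]` as the image, passing a hypothetical `WithAge(mnist_train, age_train)` in place of `mnist_train` would make each sample `((img, age), (time, event))` without changing that class. A batch variant that combines images with `torch.stack` would need adapting, since each entry is now a pair.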

And when building the network, should I pass age as a new input to `forward` and concatenate it with `x` after flattening?

```python
import torch
import torch.nn.functional as F
from torch import nn


class Net(nn.Module):
    def __init__(self, out_features):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, 1)
        self.max_pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 16, 5, 1)
        self.glob_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, out_features)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.max_pool(x)
        x = F.relu(self.conv2(x))
        x = self.glob_avg_pool(x)
        x = torch.flatten(x, 1)  # (batch, 16) pooled image features
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
```
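
A minimal sketch of the network change, assuming one or more clinical covariates per image (here just age): `forward` takes the image and the covariates as separate arguments, and the covariates are concatenated with the 16 pooled image features right after `torch.flatten`, so `fc1` grows to `16 + n_clinical` inputs. `NetWithAge` and `n_clinical` are hypothetical names. Scaling age (for example, standardizing it) before feeding it in is generally advisable so that it sits on a range comparable to the pooled activations.

```python
import torch
import torch.nn.functional as F
from torch import nn


class NetWithAge(nn.Module):
    """Hypothetical late-fusion variant of Net: CNN features + clinical covariates."""

    def __init__(self, out_features, n_clinical=1):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, 1)
        self.max_pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 16, 5, 1)
        self.glob_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # fc1 now sees the 16 image features plus the clinical features
        self.fc1 = nn.Linear(16 + n_clinical, 16)
        self.fc2 = nn.Linear(16, out_features)

    def forward(self, img, clinical):
        x = F.relu(self.conv1(img))
        x = self.max_pool(x)
        x = F.relu(self.conv2(x))
        x = self.glob_avg_pool(x)
        x = torch.flatten(x, 1)  # (batch, 16)
        # Accept age as (batch,) or (batch, n_clinical); match the feature dtype
        clinical = clinical.to(x.dtype).reshape(x.shape[0], -1)
        x = torch.cat([x, clinical], dim=1)  # (batch, 16 + n_clinical)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
```

Concatenating after global pooling is one common late-fusion choice; whether torchtuples passes a tuple input `(img, age)` to `forward` as two positional arguments is an assumption worth confirming in its documentation.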