lasso-net / lassonet

Feature selection in neural networks
MIT License
215 stars 52 forks source link

Using the Cox loss and methods with a custom model #51

Open SalvatoreRa opened 11 months ago

SalvatoreRa commented 11 months ago

Great work,

I found your approach very interesting and I was trying to generalize it to different pytorch architectures

I wanted to test your approach with custom models and other PyTorch models. The idea is basically to take a PyTorch model (with an arbitrary architecture) and test its ability to predict survival.

for example, I wanted to test with a simple pytorch model.

let's say:

Now, to explain this better, see the example below:

What I am trying to understand is, considering this case:

taking this dataset and starting from your example:

from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from lassonet import LassoNetCoxRegressor
from lassonet import  plot_path
res_dir = './survival/'
# Build the file paths with pathlib instead of string concatenation, so the
# code keeps working when `res_dir` is given without a trailing slash.
# `res_dir` itself stays a plain string.
data_dir = Path(res_dir)
# The first line of each CSV file is skipped (skip_header=1).
X = np.genfromtxt(data_dir / "hnscc_x.csv", delimiter=",", skip_header=1)
y = np.genfromtxt(data_dir / "hnscc_y.csv", delimiter=",", skip_header=1)

this is a simple version of the approach modelling the survival as a simple binary classification approach:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, random_split, SubsetRandomSampler, ConcatDataset, Dataset
import pandas as pd
import seaborn as sns
# creating a simple MLP
class FCNNC(nn.Module):
    """Fully connected network with three linear layers.

    The first layer is followed by tanh, the second by sigmoid, and the
    third is returned without any activation.
    """

    def __init__(self, input_size, constraint_size, hidden_size, num_classes):
        super().__init__()
        # Layer widths: input -> constraint -> hidden -> num_classes.
        self.fc1 = nn.Linear(in_features=input_size, out_features=constraint_size)
        self.fc2 = nn.Linear(in_features=constraint_size, out_features=hidden_size)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return the raw (un-activated) outputs of the last layer."""
        constrained = self.fc1(x).tanh()
        hidden = self.fc2(constrained).sigmoid()
        return self.fc3(hidden)

# simple class for the dataset
class DataClassifier(Dataset):
    """Dataset pairing float32 feature rows with 64-bit integer labels."""

    def __init__(self, X_train, y_train):
        features = X_train.astype(np.float32)
        self.X = torch.from_numpy(features)
        # Labels are cast to 64-bit integers.
        self.y = torch.from_numpy(y_train).long()
        # Number of samples, i.e. the size of the first dimension of X.
        self.len = self.X.size(0)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample, target = self.X[index], self.y[index]
        return sample, target
# binary accuracy
def multi_acc(y_pred, y_test):
    """Return the accuracy as a rounded percentage (0-dim tensor).

    The predicted class of each row of ``y_pred`` is the index of its largest
    entry along dimension 1; predictions are compared element-wise with
    ``y_test``.
    """
    predicted = torch.max(y_pred, dim=1).indices
    hits = (predicted == y_test).float()
    fraction = hits.sum() / len(hits)
    # Scale to percent and round to the nearest integer value.
    return torch.round(fraction * 100)

# transforming in binary classification
batch_size = 2048
X_train, X_test, Y_train, Y_test = train_test_split(X, y[:,1], random_state=0)
traindata = DataClassifier(X_train, Y_train)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True)

valdata = DataClassifier(X_test,Y_test)
valloader = torch.utils.data.DataLoader(valdata, batch_size=X_test.shape[0], shuffle=False)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()
model = FCNNC(X.shape[1],20,20,2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

n_epochs =1000
%matplotlib inline

# simple training loop to store results and plotting
accuracy_stats = {
        'train': [],
        "val": []
    }
loss_stats = {
        'train': [],
        "val": []
    }
for epoch in range(n_epochs):
    running_loss = 0.0
    train_epoch_acc = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        model.to(device)
            # set optimizer to zero grad to remove previous epoch gradients
        optimizer.zero_grad()

            # forward propagation
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        acc = multi_acc(outputs, labels)

            # backward propagation
        loss.backward()
            # optimize
        optimizer.step()

        running_loss += loss.item()
        train_epoch_acc += acc.item()

    with torch.no_grad():
        val_epoch_loss = 0
        val_epoch_acc = 0
        model.eval()
        for X_val_batch, y_val_batch in valloader:
            X_val_batch = X_val_batch.to(device)
            y_val_batch = y_val_batch.to(device)

            y_val_pred = model(X_val_batch)

            val_loss = criterion(y_val_pred, y_val_batch)
            val_acc = multi_acc(y_val_pred, y_val_batch)

            val_epoch_loss += val_loss.item()
            val_epoch_acc += val_acc.item()

        loss_stats['train'].append(running_loss/len(trainloader))
        loss_stats['val'].append(val_epoch_loss/len(valloader))
        accuracy_stats['train'].append(train_epoch_acc/len(trainloader))
        accuracy_stats['val'].append(val_epoch_acc/len(valloader))

    if epoch % 50 == True:
        print(f'Epoch {epoch+0:03}: | Train Loss: {running_loss/len(trainloader):.5f} | Val Loss: {val_epoch_loss/len(valloader):.5f} | Train Acc: {train_epoch_acc/len(trainloader):.3f}| Val Acc: {val_epoch_acc/len(valloader):.3f}')

train_val_acc_df = pd.DataFrame.from_dict(accuracy_stats).reset_index().melt(id_vars=['index']).rename(columns={"index":"epochs"})
train_val_loss_df = pd.DataFrame.from_dict(loss_stats).reset_index().melt(id_vars=['index']).rename(columns={"index":"epochs"})
    # Plot the dataframes
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20,7))
sns.lineplot(data=train_val_acc_df, x = "epochs", y="value", hue="variable",  ax=axes[0]).set_title('Train-Val Accuracy/Epoch')
sns.lineplot(data=train_val_loss_df, x = "epochs", y="value", hue="variable", ax=axes[1]).set_title('Train-Val Loss/Epoch')

The idea is to start from a very simple example and transform the model so that it is able to handle censored data.

I was highlighting this code from your repository:

import torch
from sortedcontainers import SortedList

def log_substract(x, y):
    """log(exp(x) - exp(y))"""
    # exp(x) - exp(y) == exp(x) * (1 - exp(y - x)), hence the log1p form.
    gap = y - x
    return x + torch.log1p(-torch.exp(gap))

def scatter_logsumexp(input, index, *, dim=-1, output_size=None):
    """Inspired by torch_scatter.logsumexp
    Uses torch.scatter_reduce for performance

    The values are shifted by a per-index maximum before exponentiating and
    the shift is added back after taking the log.
    NOTE(review): `scatter_reduce` is not defined in this excerpt; presumably
    it reduces `input` per `index` value along `dim` -- confirm.
    """
    # "amax" reduction per index, used as the shift.
    group_max = scatter_reduce(
        input, dim=dim, index=index, output_size=output_size, reduce="amax"
    )
    # Broadcast each group's shift back to its source elements.
    shifted = input - group_max.gather(dim, index)
    group_sum = scatter_reduce(
        shifted.exp(),
        dim=dim,
        index=index,
        output_size=output_size,
        reduce="sum",
    )
    # Undo the shift in log space.
    return group_max + group_sum.log()

class CoxPHLoss(torch.nn.Module):
    """Loss for CoxPH model.

    ``forward`` returns ``log_den - log_num``, both averaged over the samples
    whose event indicator is non-zero. Tied event durations are handled with
    either the "breslow" or the "efron" method.
    """

    # tie-handling methods accepted by ``__init__``
    allowed = ("breslow", "efron")

    def __init__(self, method):
        """Store the tie-handling method; it must be one of ``allowed``."""
        super().__init__()
        assert method in self.allowed, f"Method must be one of {self.allowed}"
        self.method = method

    def forward(self, log_h, y):
        """Compute the loss for a batch.

        Args:
            log_h: model outputs; flattened to 1-D before use.
            y: tensor whose transpose unpacks into ``(durations, events)``
                (assumes a 2-D tensor with two columns -- TODO confirm
                against callers).

        Returns:
            The scalar ``log_den - log_num``.
        """
        log_h = log_h.flatten()

        durations, events = y.T

        # sort input
        # (decreasing duration, so the cumulative log-sum-exp below at
        # position i accumulates log_h over sorted positions 0..i)
        durations, idx = durations.sort(descending=True)
        log_h = log_h[idx]
        events = events[idx]

        # sorted positions of the samples with a non-zero event indicator
        event_ind = events.nonzero().flatten()

        # numerator
        log_num = log_h[event_ind].mean()

        # logcumsumexp of events
        event_lcse = torch.logcumsumexp(log_h, dim=0)[event_ind]

        # number of events for each unique risk set
        _, tie_inverses, tie_count = torch.unique_consecutive(
            durations[event_ind], return_counts=True, return_inverse=True
        )

        # position of last event (lowest duration) of each unique risk set
        tie_pos = tie_count.cumsum(axis=0) - 1

        # logcumsumexp by tie for each event
        # NOTE(review): a sample that shares a duration with an event but is
        # sorted after the last tied event is not included in that event's
        # cumulative sum -- confirm this is intended.
        event_tie_lcse = event_lcse[tie_pos][tie_inverses]

        if self.method == "breslow":
            log_den = event_tie_lcse.mean()

        elif self.method == "efron":
            # based on https://bydmitry.github.io/efron-tensorflow.html

            # logsumexp of ties, duplicated within tie set
            tie_lse = scatter_logsumexp(log_h[event_ind], tie_inverses, dim=0)[
                tie_inverses
            ]
            # multiply (add in log space) with corrective factor
            # aux is 1 everywhere except at the first position of every tie
            # group after the first, where the previous group's size is
            # subtracted; its cumulative sum therefore restarts in each group
            # and, after the -1, counts 0, 1, 2, ... within the group.
            aux = torch.ones_like(tie_inverses)
            aux[tie_pos[:-1] + 1] -= tie_count[:-1]
            event_id_in_tie = torch.cumsum(aux, dim=0) - 1
            # log(0) = -inf for the first event of each group, which makes the
            # subtraction below a no-op for that event.
            discounted_tie_lse = (
                tie_lse
                + torch.log(event_id_in_tie)
                - torch.log(tie_count[tie_inverses])
            )

            # denominator
            log_den = log_substract(event_tie_lcse, discounted_tie_lse).mean()

        # loss is negative log likelihood
        return log_den - log_num

def concordance_index(risk, time, event):
    """
    O(n log n) implementation of https://square.github.io/pysurvival/metrics/c_index.html

    Samples are visited by increasing ``time``. Each one is compared with the
    risks of the previously visited samples whose ``event`` is truthy; a pair
    counts as concordant when the earlier risk is strictly greater.
    """
    assert len(risk) == len(time) == len(event)
    by_time = sorted(range(len(risk)), key=lambda j: time[j])
    earlier_event_risks = SortedList()
    concordant = 0
    comparable = 0
    for j in by_time:
        current = risk[j]
        n_earlier = len(earlier_event_risks)
        # earlier events whose risk is strictly greater than the current one
        concordant += n_earlier - earlier_event_risks.bisect_right(current)
        comparable += n_earlier
        if event[j]:
            earlier_event_risks.add(current)
    return concordant / comparable

Thank you very much

Salvatore

louisabraham commented 11 months ago

Any model can use the CoxPHLoss. The loss will be:

# CoxPHLoss requires the tie-handling method: "breslow" or "efron".
criterion = CoxPHLoss(method="breslow")
loss = criterion(model(X_train[batch]), y_train[batch])

To transform the data, please look at this example: https://github.com/lasso-net/lassonet/blob/master/examples/cox_experiments.py

SalvatoreRa commented 11 months ago

Thank you for your reply,

If I have understood correctly, for a dataset one should do:

# Load the targets, skipping the first line of the CSV file.
y = np.genfromtxt(path_y, delimiter=",", skip_header=1)
# Load the features the same way and standardise them in one step.
X = preprocessing.StandardScaler().fit_transform(
    np.genfromtxt(path_x, delimiter=",", skip_header=1)
)

For instance, this should work if the dataset is saved in the same format as the provided data (hnscc).

Once done that, you can split in training and test set:

X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=random_state, stratify=y[:, 1], test_size=0.20
)

Should I do anything for the data loader? and for the NN architecture (like which should be the last layer)?

louisabraham commented 11 months ago

For y you should have both a duration and a boolean column for events. For the data loader, just use mini batches. The last layer should just output a real number, so a Linear layer is good.

SalvatoreRa commented 11 months ago

Thank you,

I have done so:

from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from lassonet import LassoNetCoxRegressor
from lassonet import  plot_path
res_dir = './survival/'
# Build the file paths with pathlib instead of string concatenation, so the
# code keeps working when `res_dir` is given without a trailing slash.
# `res_dir` itself stays a plain string.
data_dir = Path(res_dir)
# The first line of each CSV file is skipped (skip_header=1).
X = np.genfromtxt(data_dir / "hnscc_x.csv", delimiter=",", skip_header=1)
y = np.genfromtxt(data_dir / "hnscc_y.csv", delimiter=",", skip_header=1)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, random_split, SubsetRandomSampler, ConcatDataset, Dataset
import pandas as pd
import seaborn as sns

class FCNNC(nn.Module):
    """Fully connected network with three linear layers.

    The first layer is followed by tanh, the second by ReLU, and the third is
    returned without any activation.
    """

    def __init__(self, input_size, constraint_size, hidden_size, num_classes):
        super().__init__()
        # Layer widths: input -> constraint -> hidden -> num_classes.
        self.fc1 = nn.Linear(in_features=input_size, out_features=constraint_size)
        self.fc2 = nn.Linear(in_features=constraint_size, out_features=hidden_size)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return the raw (un-activated) outputs of the last layer."""
        constrained = self.fc1(x).tanh()
        hidden = self.fc2(constrained).relu()
        return self.fc3(hidden)

class DataClassifier(Dataset):
    """Dataset pairing float32 feature rows with float32 targets."""

    def __init__(self, X_train, y_train):
        features = X_train.astype(np.float32)
        targets = y_train.astype(np.float32)
        self.X = torch.from_numpy(features)
        self.y = torch.from_numpy(targets)
        # Number of samples, i.e. the size of the first dimension of X.
        self.len = self.X.size(0)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample, target = self.X[index], self.y[index]
        return sample, target

batch_size = 200
X_train, X_test, Y_train, Y_test = train_test_split(X, y, random_state=0)
traindata = DataClassifier(X_train, Y_train)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True)

valdata = DataClassifier(X_test,Y_test)
valloader = torch.utils.data.DataLoader(valdata, batch_size=X_test.shape[0], shuffle=False)

n_epochs =1000
%matplotlib inline

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = CoxPHLoss(method="breslow")
model = FCNNC(X.shape[1],20,20,1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

loss_stats = {
        'train': [],
        "val": []
    }
for epoch in range(n_epochs):
    running_loss = 0.0
    train_epoch_acc = 0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        model.to(device)
            # set optimizer to zero grad to remove previous epoch gradients
        optimizer.zero_grad()

            # forward propagation
        outputs = model(inputs)

        #loss = criterion(outputs, labels)
        loss =criterion(model(inputs), labels)

            # backward propagation
        loss.backward()
            # optimize
        optimizer.step()

        running_loss += loss.item()

    with torch.no_grad():
        val_epoch_loss = 0

        model.eval()
        for X_val_batch, y_val_batch in valloader:
            X_val_batch = X_val_batch.to(device)
            y_val_batch = y_val_batch.to(device)

            y_val_pred = model(X_val_batch)

            val_loss =criterion(model(X_val_batch), y_val_batch)

            val_epoch_loss += val_loss.item()

        loss_stats['train'].append(running_loss/len(trainloader))
        loss_stats['val'].append(val_epoch_loss/len(valloader))

    if epoch % 50 == True:
        print(f'Epoch {epoch+0:03}: | Train Loss: {running_loss/len(trainloader):.5f} | Val Loss: {val_epoch_loss/len(valloader):.5f}')

train_val_loss_df = pd.DataFrame.from_dict(loss_stats).reset_index().melt(id_vars=['index']).rename(columns={"index":"epochs"})
    # Plot the dataframes

sns.lineplot(data=train_val_loss_df, x = "epochs", y="value", hue="variable").set_title('Train-Val Loss/Epoch')

It worked and the loss is decreasing. As a last question: how do you evaluate the model? How should the concordance index (C-index) be used? Do you have any suggestions for evaluation?

This is the most basic implementation of a neural network with PyTorch, but I think it could work with any custom network (and it uses the data you provide). Would you like me to organize the script as a tutorial? Could it be useful?

louisabraham commented 2 months ago

You might want to look at the new Interval models, see the spinet example.