mthorrell / gboost_module

Gradient Boosting Modules for PyTorch
MIT License

Add Survival Support #10

Open mthorrell opened 1 week ago

mthorrell commented 1 week ago

Raw code to be optimized

import numpy as np
import pandas as pd
import torch

from gboost_module import xgbmodule

def meas_df_to_long_df(
    meas_df: pd.DataFrame,
    resolution: int
):
    assert 'time' in meas_df
    assert 'id' in meas_df
    assert 'event' in meas_df  # conceptually this is somewhat optional
    assert 'max_obs_time' not in meas_df  # don't have columns that would conflict with working columns

    meas_df = meas_df.sort_values(['id', 'time']).copy()

    non_zero_ids = [
        uid for uid in meas_df['id'].drop_duplicates()
        if not (meas_df[meas_df['id'] == uid]['time'] == 0).any()
    ]
    new_rows = []
    for uid in non_zero_ids:
        new_row = meas_df[meas_df['id'] == uid].iloc[0].copy()
        new_row['time'] = 0
        new_row['event'] = False
        new_rows.append(new_row.to_frame().T)
    meas_df = pd.concat(new_rows + [meas_df]).sort_values(['id', 'time']).copy()

    meas_df = meas_df.merge(
        meas_df.groupby('id')['time'].max().rename('max_obs_time').reset_index(),
        on='id',
        how='inner',
        validate='many_to_one'
    )

    meas_df = meas_df.sort_values(['id', 'time'])

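    # Evaluation grid: a uniform linspace over the observed time range merged
    # with every observed time, so event times land exactly on grid points
    # (forward() reads the hazard at each id's last observed time).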
    grid = pd.concat([
        pd.Series(np.linspace(
            meas_df['time'].min(),
            meas_df['time'].max(),
            resolution
        )),
        meas_df['time'].copy()
    ]).drop_duplicates().sort_values().to_list()

    long_df = []
    for t in grid:
        sub_df = meas_df[
            (meas_df['time'] <= t) & (meas_df['max_obs_time'] >= t)
        ].sort_values('time', ascending=False).drop_duplicates('id').copy()
        sub_df['measured_time'] = t
        long_df.append(sub_df)
    long_df = pd.concat(long_df).reset_index(drop=True)

    return long_df.sort_values(['id', 'measured_time']).reset_index(drop=True).copy()
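
# A note on the output (no new API assumed): for the example dataset at the
# bottom of this issue, each id is expanded to one row per grid time t with
# 0 <= t <= that id's max_obs_time, carrying the most recent covariate values
# observed at or before t (last observation carried forward), with the grid
# time recorded in 'measured_time'.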

class IntegrateSurv(torch.nn.Module):
    def __init__(self, meas_df, resolution, covariates=None, params=None):
        super().__init__()

        self.meas_df = meas_df
        self.resolution = resolution
        self.model_cols = (covariates or []) + ['measured_time']

        self.long_df, self.id_stats = self.get_input_info_from_meas_df(meas_df)

        self.gb = xgbmodule.XGBModule(
            self.long_df.shape[0], len(self.model_cols), 1,
            params=params or {}
        )
        self.base_lambda = torch.nn.Parameter(torch.Tensor([0.0]))

        self.test_long_df = None
        self.test_id_stats = None

    def get_input_info_from_meas_df(self, meas_df):
        long_df = meas_df_to_long_df(meas_df, self.resolution).sort_values(
            ['id', 'measured_time']
        ).reset_index(drop=True).copy()

        id_stats = long_df.groupby('id').apply(
            lambda x: pd.Series({
                'n_preds': x.shape[0],
                'event': x['event'].any()
            })
        ).reset_index().copy()
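        # cumulative row counts give the exclusive end offset of each id's
        # contiguous block of rows in long_df; forward() slices with these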
        id_stats['preds_idx'] = id_stats['n_preds'].cumsum()

        return long_df, id_stats

    def forward(self, meas_df=None, use_test_cache=True):
        # meas_df holds raw measurement rows (same schema as self.meas_df),
        # not an already-expanded long_df; it is only allowed in eval mode
        assert meas_df is None or not self.training
        if meas_df is not None:
            if not use_test_cache or self.test_long_df is None or self.test_id_stats is None:
                self.test_long_df, self.test_id_stats = self.get_input_info_from_meas_df(
                    meas_df
                )

        gb_input = (
            self.long_df[self.model_cols].astype(float) if self.training
            else self.test_long_df[self.model_cols].astype(float)
        )
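        # exp() keeps the hazard strictly positive; base_lambda is a learned
        # scalar log-baseline shared by every row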
        hazard_flat = torch.exp(
            self.gb(gb_input)
            + self.base_lambda
        )

        # I wonder if we can strategically insert zeros here so we can
        # vectorize this better
        hazard_by_id = {}
        prev_idx = 0
        all_loss = torch.Tensor([0.0])

        id_info = (
            self.id_stats if self.training
            else self.test_id_stats
        )

        for i, row in id_info.iterrows():
            low_idx = prev_idx
            high_idx = row['preds_idx']

            row_data = {
                'raw_hazard': hazard_flat[low_idx:high_idx, :].flatten(),
                'integrate_hazard': torch.trapezoid(
                    hazard_flat[low_idx:high_idx, :].flatten(),
                    torch.Tensor(
                        gb_input['measured_time'].iloc[low_idx:high_idx].to_numpy()
                    ).flatten()
                )
            }
            prev_idx = high_idx

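            # hazard evaluated at this id's last observed time; enters the
            # likelihood only when an event occurred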
            row_data['prob_hit'] = row_data['raw_hazard'][-1]

            row_data['loss'] = row_data['integrate_hazard']
            if row['event']:
                row_data['loss'] = row_data['loss'] - torch.log(row_data['prob_hit'])
            all_loss = all_loss + row_data['loss']
            hazard_by_id[row['id']] = row_data

        all_loss = all_loss / id_info.shape[0]
        return hazard_by_id, all_loss

    def gb_step(self):
        self.gb.gb_step(self.long_df[self.model_cols].astype(float))
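
For reference, the per-id loss assembled in forward() is the standard negative log-likelihood for right-censored survival data, with the cumulative hazard approximated by torch.trapezoid over the grid:

$$\ell_i = \int_0^{T_i} \lambda_i(t)\,dt - \delta_i \log \lambda_i(T_i)$$

where $T_i$ is id $i$'s last observed time, $\lambda_i$ is the fitted hazard, and $\delta_i$ is the event indicator (row['event']).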

Example dataset

import pandas as pd
input_df = pd.DataFrame([
    {'id': 1, 'time': 0, 'event': False, 'covariate_1': 1},
    {'id': 1, 'time': 4, 'event': False, 'covariate_1': 3},
    {'id': 1, 'time': 6, 'event': True, 'covariate_1': None},
    {'id': 2, 'time': 0, 'event': False, 'covariate_1': 1},
    {'id': 2, 'time': 3, 'event': False, 'covariate_1': None},
    {'id': 2, 'time': 5, 'event': False, 'covariate_1': 2},
])
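
A minimal end-to-end sketch on this dataset (illustrative only: the resolution and round count are arbitrary, the backward(create_graph=True) pattern follows gboost_module's usual training loop, and base_lambda would need its own torch optimizer to move off zero):

model = IntegrateSurv(input_df, resolution=20, covariates=['covariate_1'])

model.train()
for _ in range(50):
    model.zero_grad()
    _, loss = model()
    # create_graph=True so XGBModule can form the second-order info it needs
    loss.backward(create_graph=True)
    model.gb_step()

# scoring new measurements reuses the same meas_df schema
model.eval()
hazard_by_id, eval_loss = model(input_df, use_test_cache=False)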

gboost_module vs. Kaplan-Meier when using no covariates on fake data:

[image: comparison plot of the fitted survival curve vs. the Kaplan-Meier estimate]