soda-inria / hazardous

Competing Risks and Survival Analysis
https://soda-inria.github.io/hazardous/
MIT License

FEAT Add SurvTRACE #15

Closed Vincent-Maladiere closed 8 months ago

Vincent-Maladiere commented 11 months ago

This PR aims to refactor and package the SurvTRACE model. This effort focuses on:

cc @ogrisel

Vincent-Maladiere commented 10 months ago

This is a "working WIP": upon running fit, I get training and validation errors similar to those of the original SurvTRACE version.

Next steps:

Usage (with the SEER dataset file named "seer_cancer_cardio_raw_data.txt"):

from hazardous.data._seer import (
    load_seer,
    CATEGORICAL_FEATURES,
    NUMERICAL_FEATURES,
)
from hazardous.survtrace._model import SurvTRACE

# Load the SEER features X and the target y from the raw text file.
X, y = load_seer("hazardous/data/seer_cancer_cardio_raw_data.txt")
print(X.shape, y.shape)  # (476746, 28), (476746, 2)

# Tell the model which columns are numerical and which are categorical.
model = SurvTRACE(
    numerical_features=NUMERICAL_FEATURES,
    categorical_features=CATEGORICAL_FEATURES,
)
model.fit(X, y)
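
To compare training and validation errors on your own split, the rows need to be separated before fitting. Below is a minimal, library-free sketch of a split that keeps the share of each event type the same on both sides. It assumes the target carries an integer event code per row (0 for censoring, 1 and 2 for competing events); the helper name `stratified_split` and the toy data are hypothetical and not part of hazardous.

```python
import random
from collections import defaultdict


def stratified_split(events, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) with each event code split in the same ratio."""
    rng = random.Random(seed)

    # Group row indices by their event code.
    by_event = defaultdict(list)
    for idx, event in enumerate(events):
        by_event[event].append(idx)

    train_idx, val_idx = [], []
    for indices in by_event.values():
        # Shuffle within each group, then carve out the validation share.
        rng.shuffle(indices)
        n_val = round(len(indices) * val_fraction)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


# Toy event column (assumed coding): 0 = censored, 1 and 2 = competing events.
events = [0] * 50 + [1] * 30 + [2] * 20
train_idx, val_idx = stratified_split(events, val_fraction=0.2)
print(len(train_idx), len(val_idx))
```

The resulting index lists can then be used to select rows of X and y before calling fit on the training part only.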

cc @ogrisel, I think this is reviewable :)

Vincent-Maladiere commented 9 months ago

This PR is being split:

ogrisel commented 8 months ago

Let's close this and later open a new PR dedicated to the model then.