Closed Vincent-Maladiere closed 8 months ago
This is a "working WIP": when running fit, I get training and validation errors similar to those of the original SurvTRACE version.
Next steps:
Usage (with the seer dataset named "seer_cancer_cardio_raw_data.txt"):
from hazardous.data._seer import (
    load_seer,
    CATEGORICAL_FEATURES,
    NUMERICAL_FEATURES,
)
from hazardous.survtrace._model import SurvTRACE

X, y = load_seer("hazardous/data/seer_cancer_cardio_raw_data.txt")
print(X.shape, y.shape)  # (476746, 28), (476746, 2)

model = SurvTRACE(
    numerical_features=NUMERICAL_FEATURES,
    categorical_features=CATEGORICAL_FEATURES,
)
model.fit(X, y)
cc @ogrisel, I think this is reviewable :)
This PR is being split:
icml-2024
Let's close this, then open a new PR dedicated to the model later.
This PR aims to refactor and package the SurvTRACE model. This effort focuses on:
STConfig global dictionary. We want to remove this global config and use hyper-parameters the scikit-learn way.
cc @ogrisel
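To illustrate what "the scikit-learn way" could look like, here is a minimal sketch, not the actual SurvTRACE code. The class name `SketchEstimator` and the hyper-parameters `hidden_size` and `learning_rate` are hypothetical, since the real STConfig keys are not listed in this thread. The idea is that each hyper-parameter is an explicit keyword argument of `__init__`, stored unchanged on the instance, instead of being read from a module-level dictionary. The `get_params` and `set_params` methods below only mimic the convention so the sketch runs without dependencies; in the package itself, subclassing `sklearn.base.BaseEstimator` would provide them.

```python
import inspect


class SketchEstimator:
    """Toy estimator following the scikit-learn constructor convention.

    Hypothetical example: the hyper-parameter names are placeholders,
    not the real SurvTRACE settings.
    """

    def __init__(self, hidden_size=16, learning_rate=1e-3):
        # Store every hyper-parameter unchanged, under the same name,
        # with no validation or derived state in the constructor.
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

    def get_params(self):
        # Discover the hyper-parameter names from the __init__ signature.
        names = inspect.signature(type(self).__init__).parameters
        return {name: getattr(self, name) for name in names if name != "self"}

    def set_params(self, **params):
        # Reject names that are not constructor arguments.
        valid = self.get_params()
        for name, value in params.items():
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r}")
            setattr(self, name, value)
        return self


model = SketchEstimator(hidden_size=32)
model.set_params(learning_rate=1e-2)
print(model.get_params())
```

With this layout, two models with different settings can coexist in the same process, which a shared global config makes error-prone, and tools that clone or grid-search estimators can read and override the settings through `get_params` and `set_params`.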