loft-br / xgboost-survival-embeddings

Improving XGBoost survival analysis with embeddings and debiased estimators
https://loft-br.github.io/xgboost-survival-embeddings/
Apache License 2.0

Missing self.enable_categorical in python scripts #76

krittaprot closed this issue 1 month ago.

krittaprot commented 2 months ago

### Code sample

# Create and initialize the XGBSEDebiasedBCE model
xgbse_model = XGBSEDebiasedBCE(
    xgb_params=None,
    lr_params=lr_params,  # Pass the learning rate parameters
    enable_categorical=True,
    n_jobs=8,
)

# Fit the XGBSE survival model with additional parameters
xgbse_model.fit(
    X_train,
    y_train,
    time_bins=TIME_BINS,
    early_stopping_rounds=20,
)

### Problem description
When I tried to fit the model with categorical variables, I found that enabling them causes an error because self.enable_categorical is never declared in the relevant Python scripts.

### Expected behavior
There should be no error, and the model should fit successfully.

### Possible solutions
Declare self.enable_categorical = enable_categorical in the relevant Python scripts (e.g., inside xgbse/_debiased_bce.py).

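
The proposed fix relies on the scikit-learn convention that every __init__ argument is stored as an attribute with the same name. As a minimal sketch (DebiasedBCESketch is a hypothetical stand-in that only uses the constructor arguments from the code sample above; it is not xgbse's actual _debiased_bce.py), the hypothetical helper missing_init_attributes lists constructor parameters that have no matching attribute, which is the kind of regression check that could catch this bug:

```python
import inspect


def missing_init_attributes(estimator):
    """Return __init__ parameter names that are not stored as attributes.

    scikit-learn-style get_params() looks up one attribute per __init__
    parameter, so any name returned here would raise AttributeError there.
    """
    params = inspect.signature(type(estimator).__init__).parameters
    return [
        name
        for name, p in params.items()
        if name != "self"
        and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        and not hasattr(estimator, name)
    ]


class DebiasedBCESketch:
    """Hypothetical stand-in for the XGBSEDebiasedBCE constructor."""

    def __init__(self, xgb_params=None, lr_params=None, n_jobs=-1,
                 enable_categorical=False):
        self.xgb_params = xgb_params
        self.lr_params = lr_params
        self.n_jobs = n_jobs
        # The proposed fix: keep the argument under the same name.
        self.enable_categorical = enable_categorical


print(missing_init_attributes(DebiasedBCESketch(enable_categorical=True)))
# → []
```

An empty list means a get_params()-style lookup can read every parameter; deleting the self.enable_categorical line would make the helper return ['enable_categorical'].
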
GabrielGimenez commented 1 month ago

Added in the latest version.

krittaprot commented 2 weeks ago

Do you plan to release the new version soon? I'm looking forward to updating the package from the main branch.

GabrielGimenez commented 2 weeks ago

@krittaprot It's already released. Did you have problems with the categorical feature support?

krittaprot commented 2 weeks ago

@GabrielGimenez Yes, I reported this issue because the categorical feature support was not working. I forked the main branch, added "self.enable_categorical = enable_categorical" in a few places, and it worked.

You can refer to the lines of code I added here: Link

I appreciate your time looking into this, as I've found the xgbse package works really well for my use case (time-to-event modeling)!

GabrielGimenez commented 2 weeks ago

@krittaprot I'm still not sure what exact issue you are experiencing. Here's an example using categorical features that runs on the currently published xgbse version:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from xgbse import XGBSEDebiasedBCE

def sample_data():
    np.random.seed(42)
    n_samples = 1000

    # Generate features
    X = pd.DataFrame(
        {
            "numeric1": np.random.normal(0, 1, n_samples),
            "numeric2": np.random.normal(0, 1, n_samples),
            "categorical1": pd.Categorical(
                np.random.choice(["A", "B", "C"], n_samples)
            ),
            "categorical2": pd.Categorical(
                np.random.choice(["X", "Y", "Z"], n_samples)
            ),
        }
    )

    # Generate survival times and events
    T = np.random.exponential(scale=1, size=n_samples)
    E = np.random.binomial(n=1, p=0.7, size=n_samples)

    y = np.array(list(zip(E, T)), dtype=[("E", bool), ("T", float)])

    return train_test_split(X, y, test_size=0.2, random_state=42)

def test_xgbse_debiased_bce_with_categorical():
    X_train, X_test, y_train, y_test = sample_data()

    model = XGBSEDebiasedBCE(n_jobs=-1, enable_categorical=True)
    model.fit(X_train, y_train)

    preds = model.predict(X_test)

    assert isinstance(preds, pd.DataFrame)
    assert preds.shape[0] == X_test.shape[0]
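    assert (preds.values >= 0).all() and (preds.values <= 1).all()

if __name__ == "__main__":
    # Run the example directly as a script (pytest also collects the test)
    test_xgbse_debiased_bce_with_categorical()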

If you're still having problems, can you give me a reproducible example to debug?

krittaprot commented 2 weeks ago

@GabrielGimenez I created a new venv with Python 3.9 and installed xgbse 0.3.1 for testing on a MacBook Pro M3. I ran the code you provided in a Jupyter notebook and was able to generate the synthetic data, but got this error after fitting the model.

def sample_data():
    np.random.seed(42)
    n_samples = 1000

    # Generate features
    X = pd.DataFrame(
        {
            "numeric1": np.random.normal(0, 1, n_samples),
            "numeric2": np.random.normal(0, 1, n_samples),
            "categorical1": pd.Categorical(
                np.random.choice(["A", "B", "C"], n_samples)
            ),
            "categorical2": pd.Categorical(
                np.random.choice(["X", "Y", "Z"], n_samples)
            ),
        }
    )

    # Generate survival times and events
    T = np.random.exponential(scale=1, size=n_samples)
    E = np.random.binomial(n=1, p=0.7, size=n_samples)

    y = np.array(list(zip(E, T)), dtype=[("E", bool), ("T", float)])

    return train_test_split(X, y, test_size=0.2, random_state=42)

X_train, X_test, y_train, y_test = sample_data()

print(X_train.head())

print('Below are the dimensions of the synthetic dataset')
print(f'X_train shape: {X_train.shape}, X_test shape: {X_test.shape}, y_train shape: {y_train.shape}, y_test shape: {y_test.shape}')

     numeric1  numeric2 categorical1 categorical2
29  -0.291694 -1.022793            C            Y
535  0.047399 -1.594703            C            Z
695 -0.309546  1.938929            A            Y
557 -0.432558 -0.803179            C            X
836  1.550500 -1.143726            A            Y
Below are the dimensions of the synthetic dataset
X_train shape: (800, 4), X_test shape: (200, 4), y_train shape: (800,), y_test shape: (200,)

model = XGBSEDebiasedBCE(n_jobs=-1, enable_categorical=True)
model.fit(X_train, y_train)

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/Documents/7_Tutorials/xgbse/.surv/lib/python3.9/site-packages/IPython/core/formatters.py:974, in MimeBundleFormatter.__call__(self, obj, include, exclude)
    971     method = get_real_method(obj, self.print_method)
    973     if method is not None:
--> 974         return method(include=include, exclude=exclude)
    975     return None
    976 else:

File ~/Documents/7_Tutorials/xgbse/.surv/lib/python3.9/site-packages/sklearn/base.py:697, in BaseEstimator._repr_mimebundle_(self, **kwargs)
    695 def _repr_mimebundle_(self, **kwargs):
    696     """Mime bundle used by jupyter kernels to display estimator"""
--> 697     output = {"text/plain": repr(self)}
    698     if get_config()["display"] == "diagram":
    699         output["text/html"] = estimator_html_repr(self)

File ~/Documents/7_Tutorials/xgbse/.surv/lib/python3.9/site-packages/sklearn/base.py:315, in BaseEstimator.__repr__(self, N_CHAR_MAX)
    307 # use ellipsis for sequences with a lot of elements
    308 pp = _EstimatorPrettyPrinter(
    309     compact=True,
    310     indent=1,
    311     indent_at_name=True,
    312     n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
    313 )
--> 315 repr_ = pp.pformat(self)
    317 # Use bruteforce ellipsis when there are a lot of non-blank characters
    318 n_nonblank = len("".join(repr_.split()))

File /Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/pprint.py:153, in PrettyPrinter.pformat(self, object)
    151 def pformat(self, object):
    152     sio = _StringIO()
--> 153     self._format(object, sio, 0, 0, {}, 0)
    154     return sio.getvalue()

File /Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/pprint.py:170, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
    168     self._readable = False
    169     return
--> 170 rep = self._repr(object, context, level)
    171 max_width = self._width - indent - allowance
    172 if len(rep) > max_width:

File /Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/pprint.py:431, in PrettyPrinter._repr(self, object, context, level)
    430 def _repr(self, object, context, level):
--> 431     repr, readable, recursive = self.format(object, context.copy(),
    432                                             self._depth, level)
    433     if not readable:
    434         self._readable = False

File ~/Documents/7_Tutorials/xgbse/.surv/lib/python3.9/site-packages/sklearn/utils/_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
    188 def format(self, object, context, maxlevels, level):
--> 189     return _safe_repr(
    190         object, context, maxlevels, level, changed_only=self._changed_only
    191     )

File ~/Documents/7_Tutorials/xgbse/.surv/lib/python3.9/site-packages/sklearn/utils/_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
    438 recursive = False
    439 if changed_only:
--> 440     params = _changed_params(object)
    441 else:
    442     params = object.get_params(deep=False)

File ~/Documents/7_Tutorials/xgbse/.surv/lib/python3.9/site-packages/sklearn/utils/_pprint.py:93, in _changed_params(estimator)
     89 def _changed_params(estimator):
     90     """Return dict (param_name: value) of parameters that were given to
     91     estimator with non-default values."""
---> 93     params = estimator.get_params(deep=False)
     94     init_func = getattr(estimator.__init__, "deprecated_original", estimator.__init__)
     95     init_params = inspect.signature(init_func).parameters

File ~/Documents/7_Tutorials/xgbse/.surv/lib/python3.9/site-packages/sklearn/base.py:244, in BaseEstimator.get_params(self, deep)
    242 out = dict()
    243 for key in self._get_param_names():
--> 244     value = getattr(self, key)
    245     if deep and hasattr(value, "get_params") and not isinstance(value, type):
    246         deep_items = value.get_params().items()

AttributeError: 'XGBSEDebiasedBCE' object has no attribute 'enable_categorical'

It seems that even after this error the model is still able to predict; I am not sure why.

GabrielGimenez commented 2 weeks ago

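OK, I found the issue. The way Jupyter notebooks display scikit-learn models relies on the estimator's attributes: the display calls get_params(), which looks up an attribute for every constructor parameter, so the missing enable_categorical attribute raises an AttributeError.

This should be fixed in the 0.3.3 release.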
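
To illustrate why fit and predict succeed while the notebook display fails, here is a stdlib-only sketch. MiniBaseEstimator is a simplified, assumed stand-in for scikit-learn's BaseEstimator (not its real code) that mimics the chain visible in the traceback: displaying the object calls repr(), which calls get_params(), which does one getattr per __init__ parameter. BuggyModel is hypothetical and, like the behavior reported for xgbse 0.3.1, accepts enable_categorical without storing it.

```python
import inspect


class MiniBaseEstimator:
    """Simplified stand-in (an assumption) for scikit-learn's BaseEstimator."""

    @classmethod
    def _get_param_names(cls):
        # Parameter names come from the __init__ signature, not from attributes.
        params = inspect.signature(cls.__init__).parameters
        return sorted(name for name in params if name != "self")

    def get_params(self):
        # One getattr per constructor parameter; a missing one raises here.
        return {name: getattr(self, name) for name in self._get_param_names()}

    def __repr__(self):
        args = ", ".join(f"{k}={v!r}" for k, v in self.get_params().items())
        return f"{type(self).__name__}({args})"


class BuggyModel(MiniBaseEstimator):
    """Hypothetical model reproducing the reported symptom."""

    def __init__(self, n_jobs=-1, enable_categorical=False):
        self.n_jobs = n_jobs  # enable_categorical is accepted but not stored

    def fit(self, X, y):
        self.n_features_ = len(X[0])  # training never calls get_params()
        return self  # scikit-learn convention: fit returns the estimator

    def predict(self, X):
        return [0.5] * len(X)  # placeholder prediction


model = BuggyModel(enable_categorical=True).fit([[1.0, 2.0]], [1])
print(model.predict([[3.0, 4.0]]))  # prints [0.5]: fit and predict both work

try:
    repr(model)  # roughly what Jupyter does to display a cell's last value
except AttributeError as err:
    print(err)  # same kind of error as in the traceback above
```

Because fit() returns the estimator, a notebook cell ending in model.fit(...) hands the model to Jupyter for display, which is where the traceback starts; training and prediction never go through get_params(), which would explain why predictions still worked. Upgrading to 0.3.3 is the fix confirmed in this thread; on 0.3.1, ending that cell with a semicolon should suppress IPython's automatic display as a stopgap.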

krittaprot commented 2 weeks ago

For your information, I updated to 0.3.3 and it works fine now. Thank you!