loft-br / xgboost-survival-embeddings

Improving XGBoost survival analysis with embeddings and debiased estimators
https://loft-br.github.io/xgboost-survival-embeddings/
Apache License 2.0
321 stars 53 forks source link

AttributeError: 'XGBSEDebiasedBCE' object has no attribute 'enable_categorical' #79

Closed pedrodicati closed 1 month ago

pedrodicati commented 1 month ago

Code sample and problem

When I was training the model `xgbse_model_Debiased = XGBSEDebiasedBCE(enable_categorical=True)`, I got an `AttributeError` saying that this class has no `enable_categorical` attribute. What I found strange is that the class accepts this parameter in its constructor. The problem seems to be related to a `getattr` call in sklearn's `base.py`.

image

The full exception:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File \.venv\Lib\site-packages\IPython\core\formatters.py:347, in BaseFormatter.__call__(self, obj)
    345     method = get_real_method(obj, self.print_method)
    346     if method is not None:
--> 347         return method()
    348     return None
    349 else:

File \.venv\Lib\site-packages\sklearn\base.py:693, in BaseEstimator._repr_html_inner(self)
    688 def _repr_html_inner(self):
    689     """This function is returned by the @property `_repr_html_` to make
    690     `hasattr(estimator, "_repr_html_") return `True` or `False` depending
    691     on `get_config()["display"]`.
    692     """
--> 693     return estimator_html_repr(self)

File \.venv\Lib\site-packages\sklearn\utils\_estimator_html_repr.py:363, in estimator_html_repr(estimator)
    361 style_template = Template(_CSS_STYLE)
    362 style_with_id = style_template.substitute(id=container_id)
--> 363 estimator_str = str(estimator)
    365 # The fallback message is shown by default and loading the CSS sets
    366 # div.sk-text-repr-fallback to display: none to hide the fallback message.
    367 #
   (...)
    372 # The reverse logic applies to HTML repr div.sk-container.
    373 # div.sk-container is hidden by default and the loading the CSS displays it.
    374 fallback_msg = (
    375     "In a Jupyter environment, please rerun this cell to show the HTML"
    376     " representation or trust the notebook. <br />On GitHub, the"
    377     " HTML representation is unable to render, please try loading this page"
    378     " with nbviewer.org."
    379 )

File \.venv\Lib\site-packages\sklearn\base.py:315, in BaseEstimator.__repr__(self, N_CHAR_MAX)
    307 # use ellipsis for sequences with a lot of elements
    308 pp = _EstimatorPrettyPrinter(
    309     compact=True,
    310     indent=1,
    311     indent_at_name=True,
    312     n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
    313 )
--> 315 repr_ = pp.pformat(self)
    317 # Use bruteforce ellipsis when there are a lot of non-blank characters
    318 n_nonblank = len("".join(repr_.split()))

File ~\AppData\Local\Programs\Python\Python311\Lib\pprint.py:161, in PrettyPrinter.pformat(self, object)
    159 def pformat(self, object):
    160     sio = _StringIO()
--> 161     self._format(object, sio, 0, 0, {}, 0)
    162     return sio.getvalue()

File ~\AppData\Local\Programs\Python\Python311\Lib\pprint.py:178, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
    176     self._readable = False
    177     return
--> 178 rep = self._repr(object, context, level)
    179 max_width = self._width - indent - allowance
    180 if len(rep) > max_width:

File ~\AppData\Local\Programs\Python\Python311\Lib\pprint.py:458, in PrettyPrinter._repr(self, object, context, level)
    457 def _repr(self, object, context, level):
--> 458     repr, readable, recursive = self.format(object, context.copy(),
    459                                             self._depth, level)
    460     if not readable:
    461         self._readable = False

File \.venv\Lib\site-packages\sklearn\utils\_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
    188 def format(self, object, context, maxlevels, level):
--> 189     return _safe_repr(
    190         object, context, maxlevels, level, changed_only=self._changed_only
    191     )

File \.venv\Lib\site-packages\sklearn\utils\_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
    438 recursive = False
    439 if changed_only:
--> 440     params = _changed_params(object)
    441 else:
    442     params = object.get_params(deep=False)

File \.venv\Lib\site-packages\sklearn\utils\_pprint.py:93, in _changed_params(estimator)
     89 def _changed_params(estimator):
     90     """Return dict (param_name: value) of parameters that were given to
     91     estimator with non-default values."""
---> 93     params = estimator.get_params(deep=False)
     94     init_func = getattr(estimator.__init__, "deprecated_original", estimator.__init__)
     95     init_params = inspect.signature(init_func).parameters

File \.venv\Lib\site-packages\sklearn\base.py:244, in BaseEstimator.get_params(self, deep)
    242 out = dict()
    243 for key in self._get_param_names():
--> 244     value = getattr(self, key)
    245     if deep and hasattr(value, "get_params") and not isinstance(value, type):
    246         deep_items = value.get_params().items()

AttributeError: 'XGBSEDebiasedBCE' object has no attribute 'enable_categorical'
GabrielGimenez commented 1 month ago

This should be fixed in the 0.3.3 release.