sktime / skbase

Base classes for creating scikit-learn-like parametric objects, and tools for working with them.
BSD 3-Clause "New" or "Revised" License

[BUG] `skbase`'s pretty printing fails with an `AttributeError: '_VisualBlock' object has no attribute 'objs'` during a nested `Pipeline` call #270

Closed: achieveordie closed this issue 5 months ago

achieveordie commented 9 months ago

Describe the bug

I was experimenting with combining sklearn and sktime Pipeline objects, nesting a sklearn Pipeline inside an sktime Pipeline. When I try to display the final composite estimator in the IPython console, I get an AttributeError.

To Reproduce

In a Jupyter notebook cell:

from sklearn.pipeline import Pipeline as sklearnPipeline
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer

from sktime.pipeline import make_pipeline
from sktime.classification.interval_based import TimeSeriesForestClassifier

sklearn_pipeline = sklearnPipeline([
    ('scaler', StandardScaler()),
    ('imputer', SimpleImputer()),
])
final_pipeline = make_pipeline(sklearn_pipeline, TimeSeriesForestClassifier())

final_pipeline

Output

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File c:\Users\sagar\miniconda3\envs\sktime-dev\lib\site-packages\IPython\core\formatters.py:344, in BaseFormatter.__call__(self, obj)
    342     method = get_real_method(obj, self.print_method)
    343     if method is not None:
--> 344         return method()
    345     return None
    346 else:

File c:\Users\sagar\miniconda3\envs\sktime-dev\lib\site-packages\skbase\base\_base.py:877, in BaseObject._repr_html_inner(self)
    870 def _repr_html_inner(self):
    871     """Return HTML representation of class.
    872 
    873     This function is returned by the @property `_repr_html_` to make
    874     `hasattr(BaseObject, "_repr_html_") return `True` or `False` depending
    875     on `self.get_config()["display"]`.
    876     """
--> 877     return _object_html_repr(self)

File c:\Users\sagar\miniconda3\envs\sktime-dev\lib\site-packages\skbase\base\_pretty_printing\_object_html_repr.py:382, in _object_html_repr(base_object)
    371 fallback_msg = (
    372     "Please rerun this cell to show the HTML repr or trust the notebook."
    373 )
    374 out.write(
    375     f"<style>{style_with_id}</style>"
    376     f'<div id={container_id!r} class="sk-top-container">'
   (...)
    380     '<div class="sk-container" hidden>'
    381 )
--> 382 _write_base_object_html(
    383     out,
    384     base_object,
    385     base_object.__class__.__name__,
    386     base_object_str,
    387     first_call=True,
    388 )
    389 out.write("</div></div>")
    391 html_output = out.getvalue()

File c:\Users\sagar\miniconda3\envs\sktime-dev\lib\site-packages\skbase\base\_pretty_printing\_object_html_repr.py:147, in _write_base_object_html(out, base_object, base_object_label, base_object_label_details, first_call)
    145         # wrap element in a serial visualblock
    146         serial_block = _VisualBlock("serial", [est], dash_wrapped=False)
--> 147         _write_base_object_html(out, serial_block, name, name_details)
    148         out.write("</div>")  # sk-parallel-item
    150 out.write("</div></div>")

File c:\Users\sagar\miniconda3\envs\sktime-dev\lib\site-packages\skbase\base\_pretty_printing\_object_html_repr.py:142, in _write_base_object_html(out, base_object, base_object_label, base_object_label_details, first_call)
    140 for est, name, name_details in est_infos:
    141     if kind == "serial":
--> 142         _write_base_object_html(out, est, name, name_details)
    143     else:  # parallel
    144         out.write('<div class="sk-parallel-item">')

File c:\Users\sagar\miniconda3\envs\sktime-dev\lib\site-packages\skbase\base\_pretty_printing\_object_html_repr.py:147, in _write_base_object_html(out, base_object, base_object_label, base_object_label_details, first_call)
    145         # wrap element in a serial visualblock
    146         serial_block = _VisualBlock("serial", [est], dash_wrapped=False)
--> 147         _write_base_object_html(out, serial_block, name, name_details)
    148         out.write("</div>")  # sk-parallel-item
    150 out.write("</div></div>")

File c:\Users\sagar\miniconda3\envs\sktime-dev\lib\site-packages\skbase\base\_pretty_printing\_object_html_repr.py:142, in _write_base_object_html(out, base_object, base_object_label, base_object_label_details, first_call)
    140 for est, name, name_details in est_infos:
    141     if kind == "serial":
--> 142         _write_base_object_html(out, est, name, name_details)
    143     else:  # parallel
    144         out.write('<div class="sk-parallel-item">')

File c:\Users\sagar\miniconda3\envs\sktime-dev\lib\site-packages\skbase\base\_pretty_printing\_object_html_repr.py:138, in _write_base_object_html(out, base_object, base_object_label, base_object_label_details, first_call)
    136 kind = est_block.kind
    137 out.write(f'<div class="sk-{kind}">')
--> 138 est_infos = zip(est_block.objs, est_block.names, est_block.name_details)
    140 for est, name, name_details in est_infos:
    141     if kind == "serial":

AttributeError: '_VisualBlock' object has no attribute 'objs'

ClassifierPipeline(classifier=TimeSeriesForestClassifier(),
                   transformers=[TabularToSeriesAdaptor(transformer=Pipeline(steps=[('scaler', StandardScaler()), ('imputer', SimpleImputer())])),
                                 ColumnConcatenator()])
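The thread does not spell out the root cause, so the following is an assumption. The last frame reads `est_block.objs`, which is the attribute name skbase's `_VisualBlock` uses for its children. The nested object here is a scikit-learn `Pipeline`, which supplies its own visual block through `_sk_visual_block_`, and scikit-learn's `_VisualBlock` stores its children under `estimators` instead. Both classes share the name `_VisualBlock`, which would explain the confusing error message. The sketch below mimics that mismatch with two hypothetical stand-in classes (`SkbaseStyleBlock`, `SklearnStyleBlock`; not the real classes) and a hypothetical tolerant accessor, `block_children`, that reads either attribute.

```python
# Hypothetical stand-ins: these mimic only the attribute names of the two
# visual-block classes; they are NOT the real skbase / scikit-learn classes.
class SkbaseStyleBlock:
    """Mimics skbase's block: child objects are stored in ``objs``."""

    def __init__(self, kind, objs, names=None, name_details=None):
        self.kind = kind
        self.objs = list(objs)
        n = len(self.objs)
        self.names = list(names) if names is not None else [None] * n
        self.name_details = list(name_details) if name_details is not None else [None] * n


class SklearnStyleBlock:
    """Mimics scikit-learn's block: child objects are stored in ``estimators``."""

    def __init__(self, kind, estimators, names=None, name_details=None):
        self.kind = kind
        self.estimators = list(estimators)
        n = len(self.estimators)
        self.names = list(names) if names is not None else [None] * n
        self.name_details = list(name_details) if name_details is not None else [None] * n


def block_children(block):
    """Return the children of either block style (hypothetical helper)."""
    for attr in ("objs", "estimators"):
        if hasattr(block, attr):
            return getattr(block, attr)
    raise AttributeError(
        f"{type(block).__name__} has neither 'objs' nor 'estimators'"
    )


def iter_block(block):
    """Yield (child, name, name_details) triples, like the zip(...) in the traceback."""
    return zip(block_children(block), block.names, block.name_details)


# A scikit-learn-style block, as a nested sklearn Pipeline would (presumably) supply.
nested = SklearnStyleBlock(
    "serial", ["StandardScaler()", "SimpleImputer()"], names=["scaler", "imputer"]
)
try:
    _ = nested.objs  # direct access, as the failing frame does
except AttributeError as err:
    print("direct access fails:", err)

for child, name, details in iter_block(nested):
    print(name, child, details)
```

Duck-typing on the attribute name keeps the renderer agnostic to which library built the block; converting foreign blocks into the native type once, up front, would be an equally valid alternative.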

Expected behavior

If such a composite estimator has no good HTML representation, I'd expect it to fall back to printing its plain-text repr rather than raising an error when the visual representation fails.
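A minimal sketch of the requested fallback, under stated assumptions: `HtmlReprFallbackMixin`, `_build_html`, `BrokenNested`, and `HealthyObject` are hypothetical names, not skbase API, and this is not a description of how https://github.com/sktime/skbase/pull/310 resolved the issue. skbase exposes `_repr_html_` as a property (per the docstring in the traceback); the sketch uses a plain method for simplicity. It relies on IPython treating a `None` return from `_repr_html_` as "no HTML available" and falling back to the text repr.

```python
import warnings


class HtmlReprFallbackMixin:
    """Hypothetical mixin: try the HTML repr, fall back to plain text on failure."""

    def _build_html(self):
        # Placeholder for the real HTML builder; subclasses override this.
        raise NotImplementedError

    def _repr_html_(self):
        try:
            return self._build_html()
        except Exception as err:  # broad on purpose: display must not crash a cell
            warnings.warn(
                f"HTML repr failed ({type(err).__name__}: {err}); "
                "showing the plain-text repr instead."
            )
            # Returning None signals "no HTML representation", so IPython
            # displays the text/plain repr (i.e. __repr__) instead.
            return None


class BrokenNested(HtmlReprFallbackMixin):
    """Simulates the failure reported in this issue."""

    def __repr__(self):
        return "BrokenNested()"

    def _build_html(self):
        raise AttributeError("'_VisualBlock' object has no attribute 'objs'")


class HealthyObject(HtmlReprFallbackMixin):
    """An object whose HTML repr builds successfully."""

    def _build_html(self):
        return "<div>HealthyObject()</div>"
```

Catching a broad `Exception` is a deliberate trade-off: a display hook should never break a notebook cell, while the warning keeps the failure visible so the underlying bug still gets reported.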

Environment

# sktime.show_versions()

System:
    python: 3.8.17 (default, Jul  5 2023, 20:44:21) [MSC v.1916 64 bit (AMD64)]
executable: c:\Users\sagar\miniconda3\envs\sktime-dev\python.exe
   machine: Windows-10-10.0.22621-SP0

Python dependencies:
          pip: 23.1.2
       sktime: 0.25.0
      sklearn: 1.2.2
       skbase: 0.4.6
        numpy: 1.24.3
        scipy: 1.10.1
       pandas: 2.0.3
   matplotlib: 3.7.2
       joblib: 1.3.1
        numba: 0.57.1
  statsmodels: 0.14.0
     pmdarima: 2.0.3
statsforecast: 1.5.0
      tsfresh: 0.20.1
      tslearn: 0.5.3.2
        torch: None
   tensorflow: 2.13.0
tensorflow_probability: None

Additional context

fkiraly commented 9 months ago

Hm, so that's a sklearn pipeline inside an sktime pipeline, right?

fkiraly commented 9 months ago

FYI @RNKuhns, if you are still watching this, since you wrote the visual blocks interface.

fkiraly commented 5 months ago

resolved by https://github.com/sktime/skbase/pull/310