bis-med-it / gingado

A machine learning library for economics and finance
https://bis-med-it.github.io/gingado/
Apache License 2.0
12 stars 4 forks source link

`gingado.utils.Lag` fails in `fit_transform` (raising `AttributeError` in `transform`) when `keep_contemporaneous_X=True` #27

Open dkgaraujo opened 1 week ago

dkgaraujo commented 1 week ago
# Minimal reproduction: `Lag` inside a scikit-learn `Pipeline` raises
# AttributeError when `keep_contemporaneous_X=True` and the input is a NumPy array.
import numpy as np
from gingado.utils import Lag
from sklearn.pipeline import Pipeline

# Plain NumPy inputs (15 rows, 2 feature columns; 15 targets), so X carries no
# column names.
# NOTE(review): presumably this is why `feature_names_in_` is never set on the
# fitted `Lag` -- confirm against `Lag.fit`.
X = np.random.rand(15, 2)
y = np.random.rand(15)

# Values forwarded to the `Lag` constructor below.
lags = 3
jump = 2

# Case 1: keep_contemporaneous_X=False -- runs without error.
# Note that `pipe` is bound to the output of `fit_transform`, not to the
# Pipeline object itself.
pipe = Pipeline([('lagger', Lag(lags=lags, jump=jump, keep_contemporaneous_X=False))]).fit_transform(X, y)
# the above works well, but:
# Case 2: same call with keep_contemporaneous_X=True -- raises
# "AttributeError: 'Lag' object has no attribute 'feature_names_in_'".
# Per the traceback, the failure is in `Lag.transform` (gingado/utils.py line 90),
# which reads `self.feature_names_in_` only when `keep_contemporaneous_X` is truthy.
pipe = Pipeline([('lagger', Lag(lags=lags, jump=jump, keep_contemporaneous_X=True))]).fit_transform(X, y)
"""
  Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    File "/usr/local/lib64/python3.11/site-packages/sklearn/base.py", line 1473, in wrapper
      return fit_method(estimator, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/usr/local/lib64/python3.11/site-packages/sklearn/pipeline.py", line 544, in fit_transform
      return last_step.fit_transform(
             ^^^^^^^^^^^^^^^^^^^^^^^^
    File "/usr/local/lib64/python3.11/site-packages/sklearn/utils/_set_output.py", line 313, in wrapped
      data_to_wrap = f(self, X, *args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/usr/local/lib64/python3.11/site-packages/sklearn/base.py", line 1101, in fit_transform
      return self.fit(X, y, **fit_params).transform(X)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/usr/local/lib64/python3.11/site-packages/sklearn/utils/_set_output.py", line 313, in wrapped
      data_to_wrap = f(self, X, *args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/douglas-araujo/.local/lib/python3.11/site-packages/gingado/utils.py", line 90, in transform
      X_colnames = list(self.feature_names_in_) if self.keep_contemporaneous_X else []
                        ^^^^^^^^^^^^^^^^^^^^^^
  AttributeError: 'Lag' object has no attribute 'feature_names_in_'
"""