tinkoff-ai / etna

ETNA – Time-Series Library
https://etna.tinkoff.ru
Apache License 2.0
856 stars, 81 forks

Add sorting by timestamp before the fit in catboost models #1337

Closed: Mr-Geekman closed this 1 year ago

Mr-Geekman commented 1 year ago

Before submitting (must do checklist)

Proposed Changes

Sort the training data by timestamp before fitting CatBoost models.
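A minimal sketch of what "sorting before the fit" means, assuming the training data is a long-format pandas DataFrame with a timestamp column plus a row-aligned target Series. sort_by_timestamp is a hypothetical helper for illustration, not etna's implementation. Row order matters for CatBoost because its has_time option makes it treat the order of the training rows as time order instead of randomly permuting them.

```python
import pandas as pd


def sort_by_timestamp(features: pd.DataFrame, target: pd.Series, timestamp_col: str = "timestamp"):
    """Reorder features and target together so that rows are in chronological order.

    Hypothetical helper for illustration; assumes features and target are row-aligned.
    """
    # Stable argsort: rows with equal timestamps (e.g. different segments on the
    # same day) keep their original relative order.
    order = features[timestamp_col].to_numpy().argsort(kind="stable")
    sorted_features = features.iloc[order].reset_index(drop=True)
    sorted_target = target.iloc[order].reset_index(drop=True)
    return sorted_features, sorted_target


if __name__ == "__main__":
    # Long format: segments are stacked one after another, so timestamps are not monotonic.
    features = pd.DataFrame(
        {
            "timestamp": pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-01", "2020-01-02"]),
            "segment": ["a", "a", "b", "b"],
        }
    )
    target = pd.Series([1.0, 2.0, 10.0, 20.0])
    f, t = sort_by_timestamp(features, target)
    print(f.assign(target=t))
```

After sorting, the segments are interleaved day by day. The sorted frames would then be passed to CatBoost's fit (for example with CatBoostRegressor(has_time=True)) once the timestamp column has been dropped or encoded.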

Closing issues

Closes #792.

Mr-Geekman commented 1 year ago

I checked that sorting doesn't affect the output of predict, so it is applied only in fit.
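Why predict is unaffected: a fitted tree ensemble scores each row from that row's features alone, so permuting the input rows only permutes the outputs. Below is a minimal sketch of a mechanical check, assuming a predict callable that maps a list of rows to a list of predictions. The helper and both toy models are hypothetical stand-ins, not etna or CatBoost code.

```python
from typing import Callable, List, Sequence


def predictions_order_invariant(predict: Callable[[List], List], rows: Sequence) -> bool:
    """Return True if reversing the input rows only reverses the predictions."""
    rows = list(rows)
    baseline = predict(rows)
    reversed_preds = predict(rows[::-1])
    # Undo the reversal before comparing with the predictions on the original order.
    return reversed_preds[::-1] == baseline


def stump(rows: List[float]) -> List[float]:
    # Row-wise toy model: each prediction depends only on its own row.
    return [1.0 if x > 0.5 else 0.0 for x in rows]


def running_mean(rows: List[float]) -> List[float]:
    # Order-dependent toy model: each prediction depends on all preceding rows.
    out, total = [], 0.0
    for i, x in enumerate(rows, start=1):
        total += x
        out.append(total / i)
    return out


if __name__ == "__main__":
    data = [0.1, 0.9, 0.4, 0.7]
    print(predictions_order_invariant(stump, data))         # True
    print(predictions_order_invariant(running_mean, data))  # False
```

Reversal is just one permutation. A stricter check would repeat the comparison over several random shuffles.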

github-actions[bot] commented 1 year ago

🚀 Deployed on https://deploy-preview-1337--etna-docs.netlify.app

Mr-Geekman commented 1 year ago

Results of running with and without sorting in fit (result images not preserved), compared for four configurations:

default
default + has_time=True
sort(train)
sort(train) + has_time=True

Script:

import pandas as pd

from etna.models import CatBoostMultiSegmentModel
from etna.datasets import TSDataset
from etna.transforms import LagTransform, SegmentEncoderTransform, DateFlagsTransform
from etna.pipeline import Pipeline
from etna.metrics import SMAPE, MAE

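HORIZON = 14


def main():
    # Path is relative to the root of the etna repository.
    df = pd.read_csv("examples/data/example_dataset.csv")
    df_wide = TSDataset.to_dataset(df)
    ts = TSDataset(df=df_wide, freq="D")

    # has_time=True makes CatBoost use the order of training rows as time order;
    # drop it to reproduce the "default" configurations.
    model = CatBoostMultiSegmentModel(has_time=True)
    transforms = [
        # Lags start at HORIZON so every lag is known when forecasting HORIZON steps ahead.
        LagTransform(in_column="target", lags=list(range(HORIZON, 50)), out_column="lags"),
        SegmentEncoderTransform(),
        DateFlagsTransform(),
    ]
    pipeline = Pipeline(model=model, transforms=transforms, horizon=HORIZON)

    # backtest returns (metrics, forecasts, fold info); only the metrics are used here.
    metrics, _, _ = pipeline.backtest(ts=ts, metrics=[SMAPE(), MAE()], n_folds=5)

    # Average over segments and folds; numeric_only skips the string "segment" column.
    print(metrics.mean(numeric_only=True))


if __name__ == "__main__":
    main()
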
codecov-commenter commented 1 year ago

Codecov Report

Merging #1337 (b7af7d3) into master (aac0fe1) will increase coverage by 0.13%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master    #1337      +/-   ##
==========================================
+ Coverage   88.95%   89.09%   +0.13%     
==========================================
  Files         204      204              
  Lines       12641    12636       -5     
==========================================
+ Hits        11245    11258      +13     
+ Misses       1396     1378      -18     
Files Changed              Coverage Δ
etna/models/catboost.py    100.00% <100.00%> (ø)

... and 8 files with indirect coverage changes
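The reported percentages follow from hits divided by lines (in both columns, hits plus misses equals lines). The sketch below reproduces them under one assumption: Codecov truncates to two decimals rather than rounding. That rule is inferred from these numbers (11245 / 12641 is 88.956..., shown as 88.95), not taken from Codecov's documentation.

```python
import math


def coverage_pct(hits: int, lines: int) -> float:
    """Line coverage in percent."""
    return 100 * hits / lines


def truncate_2dp(x: float) -> float:
    # Assumed display rule: cut (not round) to two decimals; valid for non-negative x.
    return math.floor(x * 100) / 100


if __name__ == "__main__":
    before = coverage_pct(11245, 12641)  # master (aac0fe1)
    after = coverage_pct(11258, 12636)   # #1337 (b7af7d3)
    print(truncate_2dp(before), truncate_2dp(after), truncate_2dp(after - before))
    # prints: 88.95 89.09 0.13
```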
