Nixtla / neuralforecast

Scalable and user-friendly neural 🧠 forecasting algorithms.
https://nixtlaverse.nixtla.io/neuralforecast
Apache License 2.0

[core] Can't use DBFS as a filesystem in distributed #1045

Open piUek opened 1 week ago

piUek commented 1 week ago

What happened + What you expected to happen

I'm trying to run the https://nixtlaverse.nixtla.io/neuralforecast/examples/distributed/distributed_neuralforecast.html example on Databricks, using DBFS as the storage for the partitions. My first issue was that I couldn't pass the additional arguments DBFS needs in order to work (instance and token), which I worked around with:

from fsspec.implementations.dbfs import DatabricksFileSystem
from fsspec.registry import register_implementation, known_implementations

# Workaround: subclass the DBFS filesystem and hardcode the connection arguments
# (workspace instance and token) that NeuralForecast doesn't expose a way to pass.
class CustomDatabricksFileSystem(DatabricksFileSystem):
    def __init__(self, *args, **kwargs):
        kwargs['instance'] = ''
        kwargs['token'] = ''
        super().__init__(*args, **kwargs)

register_implementation('dbfs', CustomDatabricksFileSystem)

The second issue comes from the fsspec ls call, which in the case of DBFS returns a list of dicts:

fs.ls(dist_cfg.partitions_path)
# [{'name': '/sop.tmp_partitions/_committed_168021060626075119',
#   'type': 'file',
#   'size': 421},
#  {'name': '/sop.tmp_partitions/_committed_6601671783478124078',
#   'type': 'file',
#   'size': 224}...
# ]

And so I get the error:

AttributeError: 'dict' object has no attribute 'endswith'
File <command-4284095506577098>, line 9
      1 nf = NeuralForecast(
      2     models=[
      3         NHITS(h=24, input_size=48, max_steps=2_000, **exogs, **distributed_kwargs),
   (...)
      7     freq=1,
      8 )
----> 9 nf.fit(spark_train, static_df=spark_static, distributed_config=dist_cfg, val_size=24)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-37c23ccc-4284-4774-8441-26032f968357/lib/python3.10/site-packages/neuralforecast/core.py:432, in <listcomp>(.0)
    427     if isinstance(protocol, tuple):
    428         protocol = protocol[0]
    429     files = [
    430         f"{protocol}://{file}"
    431         for file in fs.ls(distributed_config.partitions_path)
--> 432         if file.endswith("parquet")
    433     ]
    434     self.dataset = _FilesDataset(
    435         files=files,
    436         temporal_cols=temporal_cols,
   (...)
    441         min_size=df.groupBy(id_col).count().agg({"count": "min"}).first()[0],
    442     )
    443 elif df is None:
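For reference, a sketch of how the listing could be normalized before filtering (assuming the DBFS implementation honors fsspec's `detail` flag and that dict entries carry the path under 'name', as in the output above; this is not the library's code):

entries = fs.ls(dist_cfg.partitions_path, detail=False)  # ask fsspec for plain path strings

# or, defensively, accept both shapes and pull the path out of dict entries
names = [e['name'] if isinstance(e, dict) else e for e in fs.ls(dist_cfg.partitions_path)]
parquet_files = [name for name in names if name.endswith('parquet')]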

Versions / Dependencies

Click ``` Python - 3.10.12 system='Linux', release='5.15.0-1061-azure', version='#70~20.04.1-Ubuntu SMP Mon Apr 8 15:38:58 UTC 2024', machine='x86_64' databricks runtime = '14.3 LTS (includes Apache Spark 3.5.0, Scala 2.12)' adagio==0.2.4 aiohttp==3.9.5 aiosignal==1.3.1 alembic==1.13.2 anyio==3.5.0 appdirs==1.4.4 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 asttokens==2.0.5 async-timeout==4.0.3 attrs==22.1.0 backcall==0.2.0 beautifulsoup4==4.11.1 black==22.6.0 bleach==4.1.0 blinker==1.4 boto3==1.24.28 botocore==1.27.96 certifi==2022.12.7 cffi==1.15.1 chardet==4.0.0 charset-normalizer==2.0.4 click==8.0.4 cloudpickle==3.0.0 colorlog==6.8.2 comm==0.1.2 contourpy==1.0.5 coreforecast==0.0.10 cryptography==39.0.1 cycler==0.11.0 Cython==0.29.32 databricks-sdk==0.1.6 dbus-python==1.2.18 debugpy==1.6.7 decorator==5.1.1 defusedxml==0.7.1 distlib==0.3.7 distro==1.7.0 distro-info==1.1+ubuntu0.2 docstring-to-markdown==0.11 entrypoints==0.4 executing==0.8.3 facets-overview==1.1.1 fastjsonschema==2.19.1 filelock==3.13.1 fonttools==4.25.0 frozenlist==1.4.1 fs==2.4.16 fsspec==2024.6.0 fugue==0.9.1 googleapis-common-protos==1.62.0 greenlet==3.0.3 grpcio==1.48.2 grpcio-status==1.48.1 httplib2==0.20.2 idna==3.4 importlib-metadata==4.6.4 ipykernel==6.25.0 ipython==8.14.0 ipython-genutils==0.2.0 ipywidgets==7.7.2 jedi==0.18.1 jeepney==0.7.1 Jinja2==3.1.2 jmespath==0.10.0 joblib==1.2.0 jsonschema==4.17.3 jupyter-client==7.3.4 jupyter-server==1.23.4 jupyter_core==5.2.0 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.0.0 keyring==23.5.0 kiwisolver==1.4.4 launchpadlib==1.10.16 lazr.restfulclient==0.14.4 lazr.uri==1.0.6 lightning-utilities==0.11.3.post0 llvmlite==0.43.0 lxml==4.9.1 Mako==1.3.5 MarkupSafe==2.1.1 matplotlib==3.7.0 matplotlib-inline==0.1.6 mccabe==0.7.0 mistune==0.8.4 more-itertools==8.10.0 mpmath==1.3.0 msgpack==1.0.8 multidict==6.0.5 mypy-extensions==0.4.3 nbclassic==0.5.2 nbclient==0.5.13 nbconvert==6.5.4 nbformat==5.7.0 nest-asyncio==1.5.6 networkx==3.3 neuralforecast==1.7.2 nodeenv==1.8.0 notebook==6.5.2 notebook_shim==0.2.2 numba==0.60.0 numpy==1.23.5 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.5.40 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.0 optuna==3.6.1 packaging==23.2 pandas==1.5.3 pandocfilters==1.5.0 parso==0.8.3 pathspec==0.10.3 patsy==0.5.3 pexpect==4.8.0 pickleshare==0.7.5 Pillow==9.4.0 platformdirs==2.5.2 plotly==5.9.0 pluggy==1.0.0 prometheus-client==0.14.1 prompt-toolkit==3.0.36 protobuf==4.24.0 psutil==5.9.0 psycopg2==2.9.3 ptyprocess==0.7.0 pure-eval==0.2.2 pyarrow==8.0.0 pyarrow-hotfix==0.5 pycparser==2.21 pydantic==1.10.6 pyflakes==3.1.0 Pygments==2.11.2 PyGObject==3.42.1 PyJWT==2.3.0 pyodbc==4.0.32 pyparsing==3.0.9 pyright==1.1.294 pyrsistent==0.18.0 python-apt==2.4.0+ubuntu3 python-dateutil==2.8.2 python-lsp-jsonrpc==1.1.1 python-lsp-server==1.8.0 pytoolconfig==1.2.5 pytorch-lightning==2.3.0 pytz==2022.7 PyYAML==6.0.1 pyzmq==23.2.0 ray==2.30.0 requests==2.28.1 rope==1.7.0 s3transfer==0.6.2 scikit-learn==1.1.1 scipy==1.10.0 seaborn==0.12.2 SecretStorage==3.3.1 Send2Trash==1.8.0 six==1.16.0 sniffio==1.2.0 soupsieve==2.3.2.post1 SQLAlchemy==2.0.31 ssh-import-id==5.11 stack-data==0.2.0 statsforecast==1.7.5 statsmodels==0.13.5 sympy==1.12.1 tenacity==8.1.0 tensorboardX==2.6.2.2 
terminado==0.17.1 threadpoolctl==2.2.0 tinycss2==1.2.1 tokenize-rt==4.2.1 tomli==2.0.1 torch==2.3.1 torchmetrics==1.4.0.post0 tornado==6.1 tqdm==4.66.4 traitlets==5.7.1 triad==0.9.7 triton==2.3.1 typing_extensions==4.12.2 ujson==5.4.0 unattended-upgrades==0.1 urllib3==1.26.14 utilsforecast==0.1.11 virtualenv==20.16.7 wadllib==1.3.6 wcwidth==0.2.5 webencodings==0.5.1 websocket-client==0.58.0 whatthepatch==1.0.2 widgetsnbextension==3.6.1 yapf==0.33.0 yarl==1.9.4 zipp==1.0.0 ```

Reproduction script

import logging
import os

import numpy as np
import pandas as pd

from neuralforecast import NeuralForecast, DistributedConfig
from neuralforecast.auto import AutoNHITS
from neuralforecast.models import NHITS, LSTM
from utilsforecast.evaluation import evaluate
from utilsforecast.losses import mae, rmse, smape
from utilsforecast.plotting import plot_series

from fsspec.implementations.dbfs import DatabricksFileSystem
from fsspec.registry import register_implementation, known_implementations

class CustomDatabricksFileSystem(DatabricksFileSystem):
    def __init__(self, *args, **kwargs):
        kwargs['instance'] = ''
        kwargs['token'] = ''
        super().__init__(*args, **kwargs)

register_implementation('dbfs', CustomDatabricksFileSystem)

# Configuration required for distributed training
dist_cfg = DistributedConfig(
    partitions_path='dbfs:///tmp_partitions',  # path where the partitions will be saved
    num_nodes=2,  # number of nodes to use during training (machines)
    devices=1,   # number of GPUs in each machine
)

# pytorch lightning configuration
# the executors don't have permission to write on the filesystem, so we disable saving artifacts
distributed_kwargs = dict(
    accelerator='gpu',
    enable_progress_bar=False,
    logger=False,
    enable_checkpointing=False,
)

# exogenous features
exogs = {
    'futr_exog_list': ['exog_0'],
    'stat_exog_list': ['stat_0'],
}

# for the AutoNHITS
def config(trial):
    return dict(
        input_size=48,
        max_steps=2_000,
        learning_rate=trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True),
        **exogs,
        **distributed_kwargs
    )

nf = NeuralForecast(
    models=[
        NHITS(h=24, input_size=48, max_steps=2_000, **exogs, **distributed_kwargs),
        AutoNHITS(h=24, config=config, backend='optuna', num_samples=2, alias='tuned_nhits'),
        LSTM(h=24, input_size=48, max_steps=2_000, **exogs, **distributed_kwargs),
    ],
    freq=1,
)
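# Note (assumption, not part of the original script): spark_train and spark_static are
# Spark DataFrames with the training series and the static features, prepared earlier
# in the notebook and not shown here.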
nf.fit(spark_train, static_df=spark_static, distributed_config=dist_cfg, val_size=24)

Issue Severity

High: It blocks me from completing my task.

jmoralez commented 1 week ago

Hey. Is it possible for you to use a remote storage like S3? DBFS is weird in the sense that you access it differently from Spark than from pandas, and in the training stage we read the partitions with pandas, so even if you manage to get it working it'll break there.
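Roughly speaking, the training stage ends up doing something like the following on each partition (the path and credentials below are placeholders, not the library's actual code), which is why the path has to be one that pandas/fsspec can resolve directly:

import pandas as pd

# Placeholder illustration: with an S3 (or ADLS) path, pandas + fsspec can read the
# partition files directly; a dbfs:/ path written by Spark doesn't resolve this way.
partition = pd.read_parquet(
    's3://some-bucket/tmp_partitions/part-00000.parquet',  # placeholder path
    storage_options={'key': '...', 'secret': '...'},       # placeholder credentials
)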

piUek commented 1 week ago

Ok, I understand. I'm on Azure, so I'll try ADLS instead.
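For anyone following along, an ADLS-backed config might look roughly like this (container, account and protocol prefix are placeholders; it assumes the adlfs package is installed so fsspec can resolve Azure paths, and the exact prefix may depend on your adlfs/fsspec versions):

dist_cfg = DistributedConfig(
    partitions_path='abfss://<container>@<account>.dfs.core.windows.net/tmp_partitions',  # placeholder
    num_nodes=2,
    devices=1,
)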

jmoralez commented 1 week ago

If you have experience with DBFS we could also give it a shot. I got stuck trying to define a path that could be written by Spark and then retrieved through fsspec in a form that pandas would understand.
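As a rough illustration of the mismatch (placeholder paths, not a working config):

spark_path = 'dbfs:/tmp_partitions'   # how Spark addresses the write location
fuse_path  = '/dbfs/tmp_partitions'   # the same files via the local FUSE mount, where enabled
# fsspec's DatabricksFileSystem instead goes through the REST API under the dbfs protocol
# and needs the workspace instance and a token, as in the subclass workaround above.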

jmoralez commented 1 week ago

Oh, by the way, we recently fixed a bug in the distributed implementation that hasn't been released yet: with the current release you'll see only one executor training while the others sit idle. We'll make a release soon.

piUek commented 6 days ago

> If you have experience with DBFS we could also give it a shot. I got stuck trying to define a path that could be written by Spark and then retrieved through fsspec in a form that pandas would understand.

I think I will try - it could be very convenient for Databricks users to keep everything on DBFS. Thank you for improving the distributed implementation; I'll let you know about the results after the release.

jmoralez commented 3 days ago

We just released 1.7.3 with the distributed fix.

piUek commented 3 days ago

Ok, it works with ADLS! Is it possible to use AutoLSTM with the distributed config?

jmoralez commented 2 days ago

It should be. Note that what gets distributed is the training of each model, so the search itself will be sequential. If you want to distribute the search instead, you can try setting up Ray on Databricks; once that's done, Ray should be able to distribute the trials across the cluster using the regular interface (no Spark dataframes).
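For example, reusing the config callable and the kwargs defined in the reproduction script above, a distributed AutoLSTM run might look roughly like this (sketch only; the number of trials is illustrative):

from neuralforecast.auto import AutoLSTM

# Sketch: each trial's training is distributed over the Spark cluster,
# but the Optuna trials themselves run sequentially.
nf = NeuralForecast(
    models=[
        AutoLSTM(h=24, config=config, backend='optuna', num_samples=2, alias='tuned_lstm'),
    ],
    freq=1,
)
nf.fit(spark_train, static_df=spark_static, distributed_config=dist_cfg, val_size=24)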