sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

[BUG] `test_none_reduction` fails on main #1614

Open fkiraly opened 2 months ago

fkiraly commented 2 months ago

The test_none_reduction test fails on main with the following error:

>                   reduced = torch.cat([global_state, local_state])
E                   RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

This seems to be caused by a change in upcasting. In the non-distributed case, it is resolved by setting dist_reduce_fx="mean" instead of "cat", which also seems to produce the correct result (though not conclusively, since the distributed case still needs to be handled).
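A minimal reproduction of the error, assuming (as the traceback suggests) that both states are zero-dimensional tensors, for example a state whose default is `torch.tensor(0.0)`. The values are illustrative only:

```python
import torch

# Zero-dimensional (scalar) tensors, like a metric state created with
# default=torch.tensor(0.0). Values are illustrative only.
global_state = torch.tensor(1.0)
local_state = torch.tensor(2.0)

# The line from the traceback: torch.cat needs inputs with at least one dimension.
try:
    torch.cat([global_state, local_state])
    cat_failed = False
except RuntimeError as err:
    cat_failed = "zero-dimensional" in str(err)

# Giving each scalar a leading dimension makes concatenation well-defined ...
concatenated = torch.cat([global_state.unsqueeze(0), local_state.unsqueeze(0)])

# ... while an element-wise combination such as a mean never concatenates at all.
averaged = (global_state + local_state) / 2
```

The "mean" workaround avoids the error because it combines the two scalars element-wise instead of concatenating them; whether it gives the right answer across processes is a separate question.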

fkiraly commented 2 months ago

This is the pip list:

Package                                 Version
--------------------------------------- -----------
absl-py                                 2.1.0
accessible-pygments                     0.0.5
aiohappyeyeballs                        2.4.0
aiohttp                                 3.10.5
aiosignal                               1.3.1
alabaster                               1.0.0
alembic                                 1.13.2
appnope                                 0.1.4
astroid                                 3.2.4
asttokens                               2.4.1
async-timeout                           4.0.3
attrs                                   24.2.0
babel                                   2.16.0
backports.functools-lru-cache           2.0.0
beautifulsoup4                          4.12.3
black                                   24.8.0
bleach                                  6.1.0
certifi                                 2024.7.4
cfgv                                    3.4.0
charset-normalizer                      3.3.2
click                                   8.1.7
cmaes                                   0.11.1
colorlog                                6.8.2
comm                                    0.2.2
commonmark                              0.9.1
contourpy                               1.2.1
coverage                                7.6.1
cpflows                                 0.1.2
cycler                                  0.12.1
debugpy                                 1.8.5
decorator                               5.1.1
defusedxml                              0.7.1
dill                                    0.3.8
distlib                                 0.3.8
docutils                                0.21.2
exceptiongroup                          1.2.2
execnet                                 2.1.1
executing                               2.0.1
fastjsonschema                          2.20.0
filelock                                3.15.4
flake8                                  7.1.1
fonttools                               4.53.1
frozenlist                              1.4.1
fsspec                                  2024.6.1
future                                  1.0.0
greenlet                                3.0.3
grpcio                                  1.66.0
h5py                                    3.11.0
identify                                2.6.0
idna                                    3.8
imagesize                               1.4.1
iniconfig                               2.0.0
invoke                                  2.2.0
ipykernel                               6.29.5
ipython                                 8.26.0
ipywidgets                              8.1.5
isort                                   5.13.2
jedi                                    0.19.1
Jinja2                                  3.1.4
joblib                                  1.4.2
jsonschema                              4.23.0
jsonschema-specifications               2023.12.1
jupyter_client                          8.6.2
jupyter_core                            5.7.2
jupyterlab_pygments                     0.3.0
jupyterlab_widgets                      3.0.13
kiwisolver                              1.4.5
lightning                               2.4.0
lightning-utilities                     0.11.6
Mako                                    1.3.5
Markdown                                3.7
MarkupSafe                              2.1.5
matplotlib                              3.9.2
matplotlib-inline                       0.1.7
mccabe                                  0.7.0
mistune                                 3.0.2
mpmath                                  1.3.0
multidict                               6.0.5
mypy                                    1.11.2
mypy-extensions                         1.0.0
nbclient                                0.10.0
nbconvert                               7.16.4
nbformat                                5.10.4
nbsphinx                                0.9.5
nest-asyncio                            1.6.0
networkx                                3.3
nodeenv                                 1.9.1
numpy                                   1.26.4
optuna                                  3.2.0
packaging                               24.1
pandas                                  2.2.2
pandoc                                  2.4
pandocfilters                           1.5.1
parso                                   0.8.4
pathspec                                0.12.1
patsy                                   0.5.6
pexpect                                 4.9.0
pillow                                  10.4.0
pip                                     24.2
platformdirs                            4.2.2
pluggy                                  1.5.0
plumbum                                 1.8.3
ply                                     3.11
pre-commit                              3.8.0
prompt_toolkit                          3.0.47
protobuf                                5.27.3
psutil                                  6.0.0
ptyprocess                              0.7.0
pure_eval                               0.2.3
pyarrow                                 17.0.0
pycodestyle                             2.12.1
pydata-sphinx-theme                     0.15.4
pydocstyle                              6.3.0
pyflakes                                3.2.0
Pygments                                2.18.0
pylint                                  3.2.6
pyparsing                               3.1.3
pytest                                  8.3.2
pytest-cov                              5.0.0
pytest-dotenv                           0.5.2
pytest-github-actions-annotate-failures 0.2.0
pytest-sugar                            1.0.0
pytest-xdist                            3.6.1
python-dateutil                         2.9.0.post0
python-dotenv                           1.0.1
pytorch_forecasting                     1.0.0
pytorch-lightning                       2.4.0
pytorch_optimizer                       2.12.0
pytz                                    2024.1
PyYAML                                  6.0.2
pyzmq                                   26.2.0
recommonmark                            0.7.1
referencing                             0.35.1
requests                                2.32.3
rpds-py                                 0.20.0
scikit-learn                            1.5.1
scipy                                   1.14.1
seaborn                                 0.13.2
setuptools                              65.5.0
six                                     1.16.0
snowballstemmer                         2.2.0
soupsieve                               2.6
Sphinx                                  8.0.2
sphinxcontrib-applehelp                 2.0.0
sphinxcontrib-devhelp                   2.0.0
sphinxcontrib-htmlhelp                  2.1.0
sphinxcontrib-jsmath                    1.0.1
sphinxcontrib-qthelp                    2.0.0
sphinxcontrib-serializinghtml           2.0.0
SQLAlchemy                              2.0.32
stack-data                              0.6.3
statsmodels                             0.14.2
subprocess32                            3.5.4
sympy                                   1.13.2
tensorboard                             2.17.1
tensorboard-data-server                 0.7.2
termcolor                               2.4.0
threadpoolctl                           3.5.0
tinycss2                                1.3.0
tomli                                   2.0.1
tomlkit                                 0.13.2
torch                                   2.2.2
torchmetrics                            1.4.1
torchvision                             0.17.2
tornado                                 6.4.1
tqdm                                    4.66.5
traitlets                               5.14.3
typing_extensions                       4.12.2
tzdata                                  2024.1
urllib3                                 2.2.2
virtualenv                              20.26.3
wcwidth                                 0.2.13
webencodings                            0.5.1
Werkzeug                                3.0.4
widgetsnbextension                      4.0.13
yarl                                    1.9.4

fkiraly commented 2 months ago

It did not fail on this run: https://github.com/jdb78/pytorch-forecasting/actions/runs/10545588797

The list of dependencies in that run was:

  - Installing attrs (23.1.0)
  - Installing rpds-py (0.10.2)
  - Installing referencing (0.30.2)
  - Installing frozenlist (1.4.0)
  - Installing idna (3.7)
  - Installing jsonschema-specifications (2023.7.1)
  - Installing markupsafe (2.1.3)
  - Installing mpmath (1.3.0)
  - Installing multidict (6.0.4)
  - Installing platformdirs (3.10.0)
  - Installing six (1.16.0)
  - Installing traitlets (5.9.0)
  - Installing typing-extensions (4.8.0)
  - Installing aiosignal (1.3.1)
  - Installing async-timeout (4.0.3)
  - Installing certifi (2024.7.4)
  - Installing charset-normalizer (3.2.0)
  - Installing fastjsonschema (2.18.0)
  - Installing filelock (3.12.3)
  - Installing fsspec (2023.9.0)
  - Installing greenlet (2.0.2)
  - Installing jinja2 (3.1.4)
  - Installing jsonschema (4.19.0)
  - Installing jupyter-core (5.3.1)
  - Installing networkx (3.1)
  - Installing numpy (1.24.4)
  - Installing packaging (23.1)
  - Installing pyasn1 (0.5.0)
  - Installing python-dateutil (2.8.2)
  - Installing pyzmq (25.1.1)
  - Downgrading setuptools (70.1.0 -> 70.0.0)
  - Installing sympy (1.12)
  - Installing tornado (6.4.1)
  - Installing urllib3 (1.26.19)
  - Installing yarl (1.9.2)
  - Installing aiohttp (3.9.4)
  - Installing asttokens (2.4.0)
  - Installing cachetools (5.3.1)
  - Installing contourpy (1.1.0)
  - Installing cycler (0.11.0)
  - Installing pathspec (0.11.2)
  - Installing patsy (0.5.4)
  - Installing plumbum (1.8.2)
  - Installing ply (3.11)
  - Installing protobuf (4.24.3)
  - Installing psutil (5.9.5)
  - Installing pycodestyle (2.9.1)
  - Installing pydantic (2.4.0)
  - Installing pyflakes (2.5.0)
  - Installing pytest (8.1.1)
  - Installing python-dotenv (1.0.0)
  - Installing pytorch-lightning (2.0.8)
  - Installing scikit-learn (1.3.2)
  - Installing seaborn (0.12.2)
  - Installing sphinx (7.1.2)
  - Installing starlette (0.36.3)
  - Installing subprocess32 (3.5.4)
  - Installing tensorboard-data-server (0.7.1)
  - Installing termcolor (2.3.0)
  - Installing tokenize-rt (5.2.0)
  - Installing tomlkit (0.12.1)
  - Installing torchvision (0.17.1)
  - Installing virtualenv (20.24.5)
  - Installing werkzeug (2.3.7)
  - Installing wheel (0.41.2)
  - Installing widgetsnbextension (4.0.10)
  - Installing black (24.3.0)
  - Installing cpflows (0.1.2)
  - Installing fastapi (0.110.0)
  - Installing flake8 (5.0.4)
  - Installing invoke (2.2.0)
  - Installing ipykernel (6.29.3)
  - Installing ipywidgets (8.1.2)
  - Installing lightning (2.3.2)
  - Installing mypy (1.9.0)
  - Installing nbsphinx (0.9.3)
  - Installing optuna-integration (3.6.0)
  - Installing pandoc (2.3)
  - Installing pre-commit (3.5.0)
  - Installing pyarrow (16.1.0)
  - Installing pydata-sphinx-theme (0.14.4)
  - Installing pydocstyle (6.3.0)
  - Installing pylint (3.1.0)
  - Installing pytest-cov (4.1.0)
  - Installing pytest-dotenv (0.5.2)
  - Installing pytest-github-actions-annotate-failures (0.2.0)
  - Installing pytest-sugar (1.0.0)
  - Installing pytest-xdist (3.5.0)
  - Installing pytorch-optimizer (2.12.0)
  - Installing recommonmark (0.7.1)
  - Installing statsmodels (0.14.1)
  - Installing tensorboard (2.14.0)

benHeid commented 2 months ago

I think I have found the bug: the default values of the metric's states seem to be initialized incorrectly. However, changing the default to an empty list causes further issues, since it then tries to calculate the mean of the lengths. Is this the intended behavior, or should that reduction be a concatenation too?

fkiraly commented 2 months ago

I am also unsure, since there are no tests covering the distributed case. The metric should always compute a single number, and it must do so in the distributed case too.

I did not know how to test that, and as far as I know it is currently not tested; that is where I got stuck.
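One way to reason about the distributed case without spawning processes is to imitate what torchmetrics does on sync: gather each state from every process, then apply `dist_reduce_fx` along a new leading dimension. The sketch below is a hand-written approximation under that assumption, not torchmetrics' actual sync code; `simulate_sync` and the per-rank data are hypothetical. It also illustrates why `"mean"` is not automatically safe: averaging per-rank averages weights every rank equally, regardless of how many samples it saw.

```python
import torch


def simulate_sync(per_rank_states, reduce_fx):
    """Approximate a distributed reduction of one metric state held by several ranks."""
    if isinstance(per_rank_states[0], list):
        # List states: flatten all ranks' lists, then concatenate along dim 0.
        flat = [t for state in per_rank_states for t in state]
        return torch.cat(flat, dim=0)
    stacked = torch.stack(per_rank_states)  # scalar states -> shape (world_size,)
    if reduce_fx == "sum":
        return stacked.sum(dim=0)
    if reduce_fx == "mean":
        return stacked.mean(dim=0)
    raise ValueError(f"unsupported reduce_fx: {reduce_fx}")


# Absolute errors seen by two hypothetical ranks with unequal batch sizes.
rank_errors = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([10.0])]
all_errors = torch.cat(rank_errors)

# reduction="none": list states reduced with "cat" recover every per-sample error.
none_result = simulate_sync([[e] for e in rank_errors], "cat")

# Summing both losses and counts, then dividing, gives the global mean ...
total = simulate_sync([e.sum() for e in rank_errors], "sum")
count = simulate_sync([torch.tensor(float(e.numel())) for e in rank_errors], "sum")
mean_of_sums = total / count

# ... whereas averaging per-rank means does not, when rank sizes differ.
mean_of_means = simulate_sync([e.mean() for e in rank_errors], "mean")
```

Here the global mean is 4.0, which the sum-based reduction reproduces, while the mean of per-rank means is 6.0.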

benHeid commented 2 months ago

Are you sure that it should always return a single number? Looking at the failing test, I interpret it as requiring the output to have the same shape as the original input, i.e., not a single number.

fkiraly commented 2 months ago

Ah, apologies, I left out part of the reasoning. What I meant is that two conditions have to be satisfied for MultiHorizonMetric subclasses such as MAE:

This must hold for all possible settings with respect to distributed computing, i.e., a single job, or a distributed job where the dist_reduce_fx settings matter.

So there are four combinations that need to work, but only two are tested, because we only test the non-distributed case.

I have described a hacky fix for the reduction="none" case in the non-distributed setting, but I do not know how to test or fix it in the distributed setting.
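A sketch of how the four cases could be laid out as one parametrized test. The distributed cases are simulated by splitting the batch across hypothetical ranks and merging their states by hand (sum for scalar states, concatenation for per-sample states); this is an assumption standing in for a real multi-process run, for example with the `gloo` backend of `torch.distributed`. `mae_states`, `merge_states` and `finalize` are hypothetical helpers, not pytorch-forecasting API:

```python
import pytest
import torch


def mae_states(y_pred, target, reduction):
    """Per-rank states for a toy MAE (hypothetical helper)."""
    errors = (y_pred - target).abs()
    if reduction == "none":
        return {"losses": [errors]}
    return {"losses": errors.sum(), "n": torch.tensor(errors.numel())}


def merge_states(states, reduction):
    """Stand-in for the distributed reduction: "cat" for lists, "sum" for scalars."""
    if reduction == "none":
        return {"losses": [t for s in states for t in s["losses"]]}
    return {k: torch.stack([s[k] for s in states]).sum(dim=0) for k in ("losses", "n")}


def finalize(state, reduction):
    """Turn merged states into the metric value."""
    if reduction == "none":
        return torch.cat(state["losses"], dim=0)
    return state["losses"] / state["n"]


@pytest.mark.parametrize("reduction", ["none", "mean"])
@pytest.mark.parametrize("world_size", [1, 2])
def test_reduction_matches_across_settings(reduction, world_size):
    torch.manual_seed(0)
    y_pred, target = torch.rand(20, 10), torch.rand(20, 10)

    # Split the batch dimension across the simulated ranks.
    per_rank = [
        mae_states(p, t, reduction)
        for p, t in zip(y_pred.chunk(world_size), target.chunk(world_size))
    ]
    result = finalize(merge_states(per_rank, reduction), reduction)

    expected = (y_pred - target).abs()
    if reduction == "none":
        # Same shape as the input, regardless of the number of ranks.
        assert result.shape == y_pred.shape
        assert torch.allclose(result, expected)
    else:
        # A single number, equal to the single-process mean.
        assert result.dim() == 0
        assert torch.allclose(result, expected.mean())
```

The real distributed cases would still need a multi-process run; this only pins down what the merged result is expected to be in each of the four combinations.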

fnhirwa commented 5 days ago

As indicated in the torchmetrics docs, "cat" reduction only makes sense when the states are lists: https://lightning.ai/docs/torchmetrics/stable/references/metric.html#torchmetrics.Metric.add_state

I think we should reconsider where "cat" reduction is used.