SalesforceAIResearch / uni2ts

[ICML2024] Unified Training of Universal Time Series Forecasting Transformers
Apache License 2.0
797 stars 81 forks

Getting an error from the Get Started code #3

Closed: flhred closed this issue 5 months ago

flhred commented 6 months ago

I ran the code from the Get Started section:

import torch
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from huggingface_hub import hf_hub_download

from uni2ts.eval_util.plot import plot_single
from uni2ts.model.moirai import MoiraiForecast

SIZE = "small"  # model size: choose from {'small', 'base', 'large'}
PDT = 20  # prediction length: any positive integer
CTX = 200  # context length: any positive integer
PSZ = "auto"  # patch size: choose from {"auto", 8, 16, 32, 64, 128}
BSZ = 32  # batch size: any positive integer
TEST = 100  # test set length: any positive integer

# Read data into pandas DataFrame
url = (
    "https://gist.githubusercontent.com/rsnirwan/c8c8654a98350fadd229b00167174ec4"
    "/raw/a42101c7786d4bc7695228a0f2c8cea41340e18f/ts_wide.csv"
)
df = pd.read_csv(url, index_col=0, parse_dates=True)

# Convert into GluonTS dataset
ds = PandasDataset(dict(df))

# Split into train/test set
train, test_template = split(
    ds, offset=-TEST
)  # assign last TEST time steps as test set

# Construct rolling window evaluation
test_data = test_template.generate_instances(
    prediction_length=PDT,  # number of time steps for each prediction
    windows=TEST // PDT,  # number of windows in rolling window evaluation
    distance=PDT,  # number of time steps between each window - distance=PDT for non-overlapping windows
)

# Prepare pre-trained model by downloading model weights from huggingface hub
model = MoiraiForecast.load_from_checkpoint(
    checkpoint_path=hf_hub_download(
        repo_id=f"Salesforce/moirai-1.0-R-{SIZE}", filename="model.ckpt"
    ),
    prediction_length=PDT,
    context_length=CTX,
    patch_size=PSZ,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
    map_location="cuda:0" if torch.cuda.is_available() else "cpu",
)

predictor = model.create_predictor(batch_size=BSZ)
forecasts = predictor.predict(test_data.input)

input_it = iter(test_data.input)
label_it = iter(test_data.label)
forecast_it = iter(forecasts)

inp = next(input_it)
label = next(label_it)
forecast = next(forecast_it)

plot_single(
    inp, 
    label, 
    forecast, 
    context_length=200,
    name="pred",
    show_label=True,
)

and got this error:

Exception ignored in: <generator object PyTorchPredictor.predict at 0x7160c265e440>
Traceback (most recent call last):
  File "/home/allen/miniconda3/lib/python3.12/site-packages/gluonts/torch/model/predictor.py", line 89, in predict
  File "/home/allen/miniconda3/lib/python3.12/site-packages/torch/autograd/grad_mode.py", line 84, in __exit__
TypeError: 'NoneType' object is not callable

gorold commented 6 months ago

Hi! Could you provide the full stack trace and your environment?

flhred commented 6 months ago

Hello, thanks for the reply.

(base) allen@allen-machine:~/Workspace/uni2ts$ python run.py
Exception ignored in: <generator object PyTorchPredictor.predict at 0x7bdb5d9ca440>
Traceback (most recent call last):
  File "/home/allen/miniconda3/lib/python3.12/site-packages/gluonts/torch/model/predictor.py", line 89, in predict
  File "/home/allen/miniconda3/lib/python3.12/site-packages/torch/autograd/grad_mode.py", line 84, in __exit__
TypeError: 'NoneType' object is not callable

This is the full stack trace.

I set up my environment following the README:

pip install -e '.[notebook]'

The output:

Successfully built salesforce-uni2ts
Installing collected packages: salesforce-uni2ts
  Attempting uninstall: salesforce-uni2ts
    Found existing installation: salesforce-uni2ts 0.0.0a0
    Uninstalling salesforce-uni2ts-0.0.0a0:
      Successfully uninstalled salesforce-uni2ts-0.0.0a0
Successfully installed salesforce-uni2ts-0.0.0a0

gorold commented 6 months ago

Could you try to run it in a Jupyter notebook? I'm not too sure what's going on here.

Could you show the output of pip freeze?

flhred commented 6 months ago

(base) allen@allen-machine:~/Workspace/uni2ts$ pip freeze
absl-py==2.1.0
aiohttp==3.9.3
aiosignal==1.3.1
annotated-types==0.6.0
antlr4-python3-runtime==4.9.3
anyio==4.3.0
appdirs==1.4.4
archspec @ file:///croot/archspec_1709217642129/work
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
astunparse==1.6.3
async-lru==2.0.4
attrs==23.2.0
Babel==2.14.0
beautifulsoup4==4.12.3
bleach==6.1.0
boltons @ file:///work/perseverance-python-buildout/croot/boltons_1698851177130/work
Bottleneck @ file:///croot/bottleneck_1709069899917/work
Brotli @ file:///work/perseverance-python-buildout/croot/brotli-split_1698805593785/work
bs4==0.0.1
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
cffi @ file:///croot/cffi_1700254295673/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click==8.1.7
cloudpickle==3.0.0
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
conda @ file:///croot/conda_1710772050586/work
conda-content-trust @ file:///work/perseverance-python-buildout/croot/conda-content-trust_1698882886606/work
conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1702997573971/work/src
conda-package-handling @ file:///work/perseverance-python-buildout/croot/conda-package-handling_1698851267218/work
conda_package_streaming @ file:///work/perseverance-python-buildout/croot/conda-package-streaming_1698847176583/work
contourpy @ file:///work/perseverance-python-buildout/croot/contourpy_1701756524386/work
cryptography @ file:///croot/cryptography_1710350347627/work
cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1696677705766/work
datasets==2.17.1
debugpy @ file:///work/perseverance-python-buildout/croot/debugpy_1698884710808/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
defusedxml==0.7.1
dill==0.3.8
distro @ file:///work/perseverance-python-buildout/croot/distro_1701732366176/work
einops==0.7.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
Farama-Notifications==0.0.4
fastjsonschema==2.19.1
filelock==3.13.1
flatbuffers==24.3.7
fonttools==4.25.0
fqdn==1.5.1
frozendict==2.4.0
frozenlist==1.4.1
fsspec==2023.10.0
gast==0.5.4
gluonts==0.14.4
google-pasta==0.2.0
grpcio==1.62.1
gym-anytrading==2.0.0
gymnasium==0.29.1
h11==0.14.0
h5py==3.10.0
html5lib==1.1
httpcore==1.0.4
httpx==0.25.2
huggingface-hub==0.22.0
hydra-core==1.3.0
idna @ file:///work/perseverance-python-buildout/croot/idna_1698845632828/work
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1710971335535/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1708996548741/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1709559745751/work
ipywidgets==8.1.2
isoduration==20.11.0
jax==0.4.25
jaxlib==0.4.25
jaxtyping==0.2.28
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
Jinja2==3.1.3
json5==0.9.24
jsonpatch @ file:///croot/jsonpatch_1710807507480/work
jsonpointer==2.1
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.4
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1710255804825/work
jupyter_core @ file:///work/perseverance-python-buildout/croot/jupyter_core_1701731747496/work
jupyter_server==2.13.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.4
jupyterlab_widgets==3.0.10
keras==3.1.1
kiwisolver @ file:///work/perseverance-python-buildout/croot/kiwisolver_1698847502605/work
libclang==18.1.1
libmambapy @ file:///work/perseverance-python-buildout/croot/mamba-split_1701744133524/work/libmambapy
lightning==2.2.1
lightning-utilities==0.11.1
lxml==5.1.0
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib @ file:///work/perseverance-python-buildout/croot/matplotlib-suite_1698863180732/work
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
mdurl==0.1.2
menuinst @ file:///croot/menuinst_1706732933928/work
mistune==3.0.2
ml-dtypes==0.3.2
mootdx==0.11.4
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
multitasking==0.0.11
munkres==1.1.4
namex==0.0.7
nbclient==0.10.0
nbconvert==7.16.3
nbformat==5.10.3
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
networkx==3.2.1
notebook==7.1.2
notebook_shim==0.2.4
numexpr @ file:///work/perseverance-python-buildout/croot/numexpr_1698871031164/work
numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp312-cp312-linux_x86_64.whl#sha256=b8e59fcb10ab071cbeded47aed247ddfa1c902d595848c27f82021340e226134
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
opt-einsum==3.3.0
optree==0.10.0
orjson==3.9.15
overrides==7.7.0
packaging @ file:///croot/packaging_1710807400464/work
pandas==2.1.4
pandas_ta==0.3.14b0
pandocfilters==1.5.1
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
patsy @ file:///home/conda/feedstock_root/build_artifacts/patsy_1704469236901/work
peewee==3.17.1
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
pillow @ file:///croot/pillow_1707233021655/work
platformdirs @ file:///work/perseverance-python-buildout/croot/platformdirs_1701732573265/work
pluggy @ file:///work/perseverance-python-buildout/croot/pluggy_1698805497733/work
prettytable==3.10.0
prometheus_client==0.20.0
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work
protobuf==4.25.3
psutil @ file:///work/perseverance-python-buildout/croot/psutil_1698863411559/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
pyarrow==15.0.2
pyarrow-hotfix==0.6
pycosat @ file:///work/perseverance-python-buildout/croot/pycosat_1698863456259/work
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pycryptodome==3.20.0
pydantic==2.6.4
pydantic_core==2.16.3
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1652235407899/work
PySocks @ file:///work/perseverance-python-buildout/croot/pysocks_1698845478203/work
pytdx==1.72
python-dateutil==2.9.0.post0
python-dotenv==1.0.0
python-json-logger==2.0.7
pytorch-lightning==2.2.1
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1706886791323/work
PyYAML==6.0.1
pyzmq @ file:///croot/pyzmq_1705605076900/work
qtconsole==5.5.1
QtPy==2.4.1
QuantStats==0.0.62
referencing==0.34.0
requests @ file:///croot/requests_1707355572290/work
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.0
ruamel.yaml @ file:///work/perseverance-python-buildout/croot/ruamel.yaml_1698863605521/work
-e git+https://github.com/SalesforceAIResearch/uni2ts.git@8e07e899716c970787e9f2224e847c66c59d3eaf#egg=salesforce_uni2ts
scipy==1.11.4
seaborn @ file:///home/conda/feedstock_root/build_artifacts/seaborn-split_1706340836595/work
Send2Trash==1.8.2
setuptools==69.2.0
simplejson==3.19.2
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
sniffio==1.3.1
socksio==1.0.0
soupsieve==2.5
stable-baselines3==2.2.1
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
statsmodels @ file:///work/perseverance-python-buildout/croot/statsmodels_1701748323053/work
sympy==1.12
TA-Lib==0.4.28
tabulate==0.9.0
tdxpy==0.2.7
tenacity==8.2.3
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
termcolor==2.4.0
terminado==0.18.1
tinycss2==1.2.1
toolz==0.12.1
torch==2.2.1
torchaudio==2.2.1
torchmetrics==1.3.2
torchvision==0.17.1
tornado @ file:///work/perseverance-python-buildout/croot/tornado_1698866362018/work
tqdm @ file:///work/perseverance-python-buildout/croot/tqdm_1701735729845/work
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1710254411456/work
truststore @ file:///work/perseverance-python-buildout/croot/truststore_1701735771625/work
tushare==1.4.5
typeguard==2.13.3
types-python-dateutil==2.9.0.20240316
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1708904622550/work
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1707747584337/work
uri-template==1.3.0
urllib3 @ file:///croot/urllib3_1707770551213/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
webcolors==1.13
webencodings==0.5.1
websocket-client==0.57.0
Werkzeug==3.0.1
wheel==0.43.0
widgetsnbextension==4.0.10
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
yfinance==0.2.37
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work
zstandard @ file:///work/perseverance-python-buildout/croot/zstandard_1698847073368/work

When I run moirai_forecast.ipynb in a Jupyter notebook, it says:

ValueError: Failed to load .env file

I found that the .env file is empty.

[Screenshots from 2024-03-26 16-01-03 and 2024-03-26 16-01-25]

liu-jc commented 6 months ago

Hi @flhred, please create a .env file by running touch .env in the main folder of the repository.
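For anyone without a touch command (for example on Windows), a minimal Python equivalent is sketched here. ensure_env_file is a hypothetical helper name, not part of uni2ts; it only creates an empty file and does not add any variables to it.

```python
from pathlib import Path


def ensure_env_file(repo_root: str = ".") -> Path:
    """Create an empty .env file in repo_root if it does not exist yet.

    Roughly equivalent to running `touch .env` in that directory: an existing
    file keeps its contents (only its modification time is updated).
    """
    env_path = Path(repo_root) / ".env"
    env_path.touch(exist_ok=True)
    return env_path
```

Call ensure_env_file() from the repository root, i.e. the folder where pip install -e '.[notebook]' was run.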

flhred commented 6 months ago

Hello, thanks for the reply, @liu-jc.

I had already run the command from the README, and it created an empty .env file:

(base) allen@allen-machine:~/Workspace/uni2ts$ touch .env
(base) allen@allen-machine:~/Workspace/uni2ts$ 

liu-jc commented 6 months ago

Hi @flhred, I presume the problem has been solved by adding the .env file, so I'm closing the issue. Feel free to reopen it if you have further questions.

gorold commented 6 months ago

Let's keep this issue open for now. The problem was not caused by the .env file, but rather by running the Get Started code from a Python script instead of a notebook. See the discussion in #2.

flhred commented 6 months ago

Hi, running the Get Started code in Jupyter works fine. Thanks!

[Screenshot from 2024-03-27 16-10-54]

rainbownmm commented 6 months ago

+1. Thanks for the great work. I'm getting the same error: the code runs fine in Jupyter, but prints this error when run as a plain Python script.

abdulfatir commented 6 months ago

FYI, this looks like a message about an "ignored exception", not an error in itself. This may be relevant: https://stackoverflow.com/questions/59756501/how-to-stop-python-from-printing-out-ignored-exceptions-from-other-libraries

gorold commented 5 months ago

The code was running successfully, but as Fatir mentioned, Python was printing out ignored exceptions. I added plt.show() so the plot is displayed for users running it as a script.

fritol commented 5 months ago

It's working now. I didn't change anything, just restarted, but the chart has an odd palette.

I modified plot_single() to support different colormaps through a new parameter, cmap: str = 'viridis':

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from gluonts import maybe
from gluonts.model import Forecast


def plot_single(
    inp: dict,
    label: dict,
    forecast: Forecast,
    context_length: int,
    intervals: tuple[float, ...] = (0.5, 0.9),
    ax: Optional[plt.Axes] = None,
    dim: Optional[int] = None,
    name: Optional[str] = None,
    show_label: bool = False,
    cmap: str = "viridis",  # new parameter: name of the Matplotlib colormap
):
    ax = maybe.unwrap_or_else(ax, plt.gca)
    colormap = plt.get_cmap(cmap)  # look up the colormap by name

    # Context window followed by the ground-truth horizon
    target = np.concatenate([inp["target"], label["target"]], axis=-1)
    start = inp["start"]
    if dim is not None:
        target = target[dim]
        forecast = forecast.copy_dim(dim)

    index = pd.period_range(start, periods=len(target), freq=start.freq)
    ax.plot(
        index.to_timestamp()[-context_length - forecast.prediction_length :],
        target[-context_length - forecast.prediction_length :],
        label="target",
        color=colormap(0.5),  # color of the observed series
    )
    forecast.plot(
        intervals=intervals,
        ax=ax,
        color=colormap(0.75),  # color of the forecast and its intervals
        name=name,
        show_label=show_label,
    )
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.legend(loc="lower left")

Honghe commented 1 month ago

It seems the "'NoneType' object is not callable" problem is caused by the nested generator. This patch works:

@@ -54,3 +54,3 @@
 label_it = iter(test_data.label)
-forecast_it = iter(forecasts)
+# forecast_it = iter(forecasts)

@@ -58,3 +58,5 @@
 label = next(label_it)
-forecast = next(forecast_it)
+forecast = next(forecasts)
+
+del forecasts
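A torch-free sketch of the same idea, assuming the root cause is a prediction generator that is only finalized at interpreter shutdown: close the generator explicitly once you have the forecasts you need, so its cleanup runs while everything is still loaded. no_grad_like and predict_like are hypothetical stand-ins for torch.no_grad() and the predictor's predict(). Note that iter() on a generator returns the same generator object, so forecast_it and forecasts in the original script are two names for one generator; del forecasts in the patch only finalizes it immediately (in CPython) because the forecast_it alias is removed as well. contextlib.closing makes that explicit.

```python
import contextlib

events = []


@contextlib.contextmanager
def no_grad_like():
    # Hypothetical stand-in for torch.no_grad(): records enter and exit
    events.append("enter")
    try:
        yield
    finally:
        events.append("exit")


def predict_like(batches):
    # Hypothetical stand-in for a predictor's predict() generator
    with no_grad_like():
        for b in batches:
            yield b * 2


with contextlib.closing(predict_like([1, 2, 3])) as forecasts:
    first = next(forecasts)  # take only the first forecast

# closing() called forecasts.close(), so the context manager exited here,
# not at interpreter shutdown
print(first, events)  # 2 ['enter', 'exit']
```

Consuming everything up front, e.g. forecasts = list(predictor.predict(test_data.input)), would exhaust the generator instead and should have a similar effect, at the cost of computing all windows before plotting.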