AlexMV12 closed this issue 2 months ago.
Hi @AlexMV12. Can you share your code or the error message, please?
Hello @cchallu, absolutely.
Actually, after a bit of work, I managed to make it work with all the other datasets used in the paper. Here is the code, so you can decide whether to integrate it into the repo and make it available to anyone trying to reproduce NHiTS's results.
Note that the following code works with the LongHorizonInfo class instead of the LongHorizon2Info class. With this in mind, I modified the experiments/long_horizon/run_nhits.py script as follows (the first 60 lines):
```python
import os
# Let PyTorch fall back to the CPU for ops the Apple MPS backend does not support
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import argparse
import logging

import pandas as pd
from ray import tune

from neuralforecast.auto import AutoNHITS
from neuralforecast.core import NeuralForecast
from neuralforecast.losses.pytorch import MAE, HuberLoss
from neuralforecast.losses.numpy import mae, mse

from datasetsforecast.long_horizon import LongHorizon, LongHorizonInfo
from datasetsforecast.long_horizon2 import LongHorizon2, LongHorizon2Info

logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

if __name__ == '__main__':

    # Parse execution parameters
    verbose = True
    parser = argparse.ArgumentParser()
    parser.add_argument("-horizon", "--horizon", type=int)
    parser.add_argument("-dataset", "--dataset", type=str)
    parser.add_argument("-num_samples", "--num_samples", default=5, type=int)

    args = parser.parse_args()
    horizon = args.horizon
    dataset = args.dataset
    num_samples = args.num_samples

    # Load dataset: ETT datasets come from LongHorizon2, the others from LongHorizon
    if dataset in ['ETTm1', 'ETTm2', 'ETTh1', 'ETTh2']:
        Y_df = LongHorizon2.load(directory='./data/', group=dataset)
    else:
        # LongHorizon.load returns three DataFrames; only the first (targets) is used
        Y_df, _, _ = LongHorizon.load(directory='./data/', group=dataset)
    Y_df['ds'] = pd.to_datetime(Y_df['ds'])

    # ILI uses shorter forecast horizons than the other datasets
    if dataset != 'ILI':
        assert horizon in [96, 192, 336, 720]
    else:
        assert horizon in [24, 36, 48, 60]

    freq = LongHorizonInfo[dataset].freq
    n_time = len(Y_df.ds.unique())

    # ETT: fixed split sizes from LongHorizon2Info; others: 20% validation, 20% test
    if dataset in ['ETTm1', 'ETTm2', 'ETTh1', 'ETTh2']:
        val_size = LongHorizon2Info[dataset].val_size
        test_size = LongHorizon2Info[dataset].test_size
    else:
        val_size = int(.2 * n_time)
        test_size = int(.2 * n_time)
    ...
```
This works on the ILI dataset, and the results look good. I will try running the remaining datasets, but everything should work now.
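The dataset-dependent branching in the modified script can be isolated into a few pure functions. The following is a minimal, stdlib-only sketch, and all names in it are hypothetical; they are not part of neuralforecast or datasetsforecast. The `ett_sizes` dict stands in for `LongHorizon2Info[dataset].val_size` / `.test_size`. The chronological train/validation/test layout is an assumption for illustration, not a description of how NeuralForecast consumes `val_size` and `test_size` internally.

```python
# Hypothetical helpers (not a library API) mirroring the branches of the
# modified run_nhits.py, so the split logic can be checked without any data.
ETT_DATASETS = {"ETTm1", "ETTm2", "ETTh1", "ETTh2"}


def allowed_horizons(dataset: str) -> list[int]:
    """Forecast horizons the script accepts for a given dataset."""
    # ILI uses shorter horizons than every other dataset.
    return [24, 36, 48, 60] if dataset == "ILI" else [96, 192, 336, 720]


def split_sizes(dataset: str, n_time: int, ett_sizes: dict) -> tuple[int, int]:
    """Return (val_size, test_size) the way the modified script picks them.

    ett_sizes maps ETT names to (val_size, test_size); it is a stand-in for
    LongHorizon2Info[dataset].val_size / .test_size.
    """
    if dataset in ETT_DATASETS:
        return ett_sizes[dataset]
    # Non-ETT datasets: 20% of the unique timestamps each for val and test.
    return int(0.2 * n_time), int(0.2 * n_time)


def chronological_split(n_time: int, val_size: int, test_size: int):
    """Assumed layout: train, then validation, then test, in time order."""
    n_train = n_time - val_size - test_size
    if n_train <= 0:
        raise ValueError("val_size + test_size must be smaller than n_time")
    train = range(0, n_train)
    val = range(n_train, n_train + val_size)
    test = range(n_train + val_size, n_time)
    return train, val, test


if __name__ == "__main__":
    # n_time=1000 is an arbitrary example value, not a real dataset length.
    val_size, test_size = split_sizes("ILI", 1000, {})
    train, val, test = chronological_split(1000, val_size, test_size)
    print(len(train), len(val), len(test))
```

With 1000 unique timestamps, the non-ETT branch reserves 200 for validation and 200 for test, which leaves 600 for training. Keeping the ETT sizes as an injected dict means the same function covers both branches without importing datasetsforecast.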
Awesome, thanks for sharing the code!
Description
The experiments/long_horizon folder contains the code to reproduce the N-HiTS experiments reported in the original paper, but only for the ETT datasets. I tried to adapt the code slightly to another dataset (ILI), but it does not work, and I am unsure how to fix it, given that the repo also contains two separate files for loading the datasets. It would be very valuable to be able to reproduce the N-HiTS experiments on all the datasets presented in the paper. Thanks for the nice work!