Nixtla / neuralforecast

Scalable and user friendly neural :brain: forecasting algorithms.
https://nixtlaverse.nixtla.io/neuralforecast
Apache License 2.0

TypeError: Module.load_state_dict() got an unexpected keyword argument 'assign' #986

Closed: LeonTing1010 closed this 6 days ago

LeonTing1010 commented 2 months ago

What happened + What you expected to happen

    Traceback (most recent call last):
      /Users/leo/web3/LLM/langchain/mlts/nf_iTransformer.py:47
          44 # model_index=None,
          45 # overwrite=True,
          46 # save_dataset=True)
        ❱ 47 nf = NeuralForecast.load(path='./checkpoints/test_run/')
          48 Y_hat_df = nf.predict().reset_index()
          49 Y_hat_df = Y_hat_df[Y_hat_df['unique_id'] == '300543.SZ']
          50 Y_train_df = Y_train_df[Y_train_df['unique_id'] == '300543.SZ']

      /Users/leo/web3/LLM/langchain/neuralforecast/neuralforecast/core.py:1333 in load
          1330 for model in models_ckpt:
          1331     model_name = "_".join(model.split("_")[:-1])
          1332     model_class_name = alias_to_model.get(model_name, model_name)
        ❱ 1333     loaded_model = MODEL_FILENAME_DICT[model_class_name].load(
          1334         f"{path}/{model}", **kwargs
          1335     )
          1336     loaded_model.alias = model_name

      /Users/leo/web3/LLM/langchain/neuralforecast/neuralforecast/common/_base_model.py:351 in load
          348     content = torch.load(f, **kwargs)
          349 with _disable_torch_init():
          350     model = cls(**content["hyper_parameters"])
        ❱ 351 model.load_state_dict(content["state_dict"], strict=True, assign=True)
          352 return model

    TypeError: Module.load_state_dict() got an unexpected keyword argument 'assign'

Versions / Dependencies

    Name: neuralforecast
    Version: 1.7.1
    Summary: Time series forecasting suite using deep learning models
    Home-page: https://github.com/Nixtla/neuralforecast/
    Author: Nixtla
    Author-email: business@nixtla.io
    License: Apache Software License 2.0

Reproduction script

nf = NeuralForecast.load(path='./checkpoints/test_run/')

Issue Severity

None

elephaint commented 2 months ago

Hi, do you have a piece of standalone code that I can run to reproduce this error? That would help me debug.

From the limited information, it seems the checkpoint you are loading may be of the wrong datatype, or it could be a version issue with your PyTorch installation (i.e. the checkpoint was saved with a different version than the one Nixtla is using). But this is a bit of a guess :)

LeonTing1010 commented 2 months ago


    from neuralforecast.auto import AutoTSMixer, AutoTSMixerx
    from ray.tune.search.hyperopt import HyperOptSearch
    from ray import tune
    from neuralforecast.losses.numpy import mse, mae
    import matplotlib.pyplot as plt
    import pandas as pd

    from datasetsforecast.long_horizon import LongHorizon
    from neuralforecast.core import NeuralForecast
    from neuralforecast.models import TSMixer, TSMixerx, NHITS, MLPMultivariate, iTransformer
    from neuralforecast.losses.pytorch import MSE, MAE

    # Change this to your own data to try the model
    Y_df, X_df, _ = LongHorizon.load(directory='./', group='ETTm2')
    Y_df['ds'] = pd.to_datetime(Y_df['ds'])

    # X_df contains the exogenous features, which we add to Y_df
    X_df['ds'] = pd.to_datetime(X_df['ds'])
    Y_df = Y_df.merge(X_df, on=['unique_id', 'ds'], how='left')

    # We make validation and test splits
    n_time = len(Y_df.ds.unique())
    val_size = int(.2 * n_time)
    test_size = int(.2 * n_time)
    horizon = 96
    input_size = 512
    models = [
        TSMixerx(h=horizon,
                 input_size=input_size,
                 n_series=7,
                 max_steps=10,
                 val_check_steps=10,
                 early_stop_patience_steps=5,
                 scaler_type='identity',
                 dropout=0.7,
                 valid_loss=MAE(),
                 random_seed=12345678,
                 futr_exog_list=['ex_1', 'ex_2', 'ex_3', 'ex_4'],
                 ),
    ]
    nf = NeuralForecast(models=models, freq='15min')

    Y_hat_df = nf.cross_validation(df=Y_df, val_size=val_size, test_size=test_size, n_windows=None)
    nf.save(path='./checkpoints/test_run/', model_index=None, overwrite=True, save_dataset=True)
    nf = NeuralForecast.load(path='./checkpoints/test_run/')

    Y_hat_df = Y_hat_df.reset_index()

    for model in models:
        mae_model = mae(Y_hat_df['y'], Y_hat_df[f'{model}'])
        mse_model = mse(Y_hat_df['y'], Y_hat_df[f'{model}'])
        print(f'{model} horizon {horizon} - MAE: {mae_model:.3f}')
        print(f'{model} horizon {horizon} - MSE: {mse_model:.3f}')

elephaint commented 2 months ago

Thanks - I have zero issues executing that code. So my response is similar to #987, i.e.

Can you give more details about the machine config (OS, Python) you are using? How are you running this script?
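A minimal sketch for collecting the machine details requested here, using only the Python standard library. The function name environment_report is hypothetical, and the package list is an assumption; the installed PyTorch version is likely the most relevant entry for this traceback.

```python
import platform
import sys
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("torch", "neuralforecast", "pytorch-lightning")):
    """Hypothetical helper: gather OS, Python and package versions for a bug report."""
    report = {
        "os": platform.platform(),
        "python": sys.version.split()[0],
    }
    for name in packages:
        try:
            # Version of the installed distribution, as recorded in its metadata.
            report[name] = version(name)
        except PackageNotFoundError:
            # The package is not installed in the active environment.
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Running it inside the same environment that executes the failing script makes it easy to spot whether an older PyTorch is being picked up.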

If I had to guess, it's a package conflict issue, so I would create a new virtual environment, install neuralforecast in that environment, and try rerunning the script.

H4Njx commented 1 month ago

@LeonTing1010 Hi, environment-cpu.yml only requires pytorch>=2.0.0, but in PyTorch 2.0.0 and 2.0.1 the method at https://github.com/pytorch/pytorch/blame/v2.0.0/torch/nn/modules/module.py#L1969 is defined as def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): and has no keyword argument 'assign'. So the requirement should be changed from pytorch>=2.0.0 to pytorch>=2.1.0.
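A minimal sketch of the version check this comment implies. It assumes, as the comment states, that 'assign' first appears in PyTorch 2.1, and that version strings begin with major.minor (local suffixes such as +cpu are ignored). The helper names parse_major_minor and supports_assign_kwarg are hypothetical, not part of PyTorch or neuralforecast.

```python
import re


def parse_major_minor(version_str):
    """Return (major, minor) from strings such as '2.0.1', '2.1.0+cpu' or '2.2.0.dev20231010'."""
    match = re.match(r"^(\d+)\.(\d+)", version_str)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version_str!r}")
    return int(match.group(1)), int(match.group(2))


def supports_assign_kwarg(torch_version):
    """True if Module.load_state_dict should accept 'assign' (assumed: added in PyTorch 2.1)."""
    return parse_major_minor(torch_version) >= (2, 1)


if __name__ == "__main__":
    try:
        import torch
        print(torch.__version__, supports_assign_kwarg(torch.__version__))
    except ImportError:
        print("PyTorch is not installed")
```

Comparing (major, minor) tuples avoids the classic string-comparison pitfall where "10.0" would sort before "2.1".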

github-actions[bot] commented 1 month ago

This issue has been automatically closed because it has been awaiting a response for too long. When you have time to work with the maintainers to resolve this issue, please post a new comment and it will be re-opened. If the issue has been locked for editing by the time you return to it, please open a new issue and reference this one.

jmoralez commented 1 month ago

Reopening to remove the argument in versions that don't support it.
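One way to drop the argument on versions that don't support it is to inspect the method's signature at call time rather than comparing version numbers. The sketch below is hypothetical: load_state_dict_compat is not a neuralforecast or PyTorch function, and this is not necessarily how the maintainers implemented the fix.

```python
import inspect


def load_state_dict_compat(module, state_dict, strict=True, assign=True):
    """Hypothetical helper: call module.load_state_dict, passing 'assign' only if accepted."""
    kwargs = {"strict": strict}
    # For a bound method, the signature excludes 'self', so only real keyword params remain.
    params = inspect.signature(module.load_state_dict).parameters
    if "assign" in params:
        kwargs["assign"] = assign
    return module.load_state_dict(state_dict, **kwargs)
```

With this helper, the failing line in _base_model.py would read load_state_dict_compat(model, content["state_dict"], strict=True, assign=True). When 'assign' is dropped, PyTorch copies the checkpoint values into the module's existing parameters instead of swapping in the loaded tensors; this assumes the freshly constructed model has real (not meta-device) parameters.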