philipperemy / n-beats

Keras/Pytorch implementation of N-BEATS: Neural basis expansion analysis for interpretable time series forecasting.
MIT License

Keras backend does not support input_dim > 1 #58

Closed: martin-studna closed this issue 2 years ago.

martin-studna commented 2 years ago

The README says that the Keras backend supports input_dim > 1. I tried setting input_dim to a value greater than one, and the model throws an error during the first training epoch.

philipperemy commented 2 years ago

@martin-studna It seems to work for me. You also have to pass input_dim to the NBeatsKeras constructor: the model needs to know it before it is created.
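Because input_dim is fixed when the model is built, the arrays passed to fit must agree with the constructor arguments. Below is a minimal sketch of a pre-fit check. `check_nbeats_shapes` is a hypothetical helper, not part of n-beats. The expected layouts, x as (num_samples, backcast_length, input_dim) and y as (num_samples, forecast_length, input_dim), are inferred from the shapes printed in the log further down, not from the library's documentation.

```python
import numpy as np


def check_nbeats_shapes(x, y, backcast_length, forecast_length, input_dim):
    """Hypothetical pre-fit check (not part of n-beats).

    Assumed layouts, inferred from the shapes printed in this thread's log:
      x: (num_samples, backcast_length, input_dim)
      y: (num_samples, forecast_length, input_dim)
    Returns a list of problems; an empty list means the shapes agree.
    """
    problems = []
    if x.ndim != 3:
        problems.append(f"x should be 3-D, got shape {x.shape}")
    elif x.shape[1:] != (backcast_length, input_dim):
        problems.append(f"x has shape {x.shape}, expected (N, {backcast_length}, {input_dim})")
    if y.ndim != 3:
        problems.append(f"y should be 3-D, got shape {y.shape}")
    elif y.shape[1:] != (forecast_length, input_dim):
        problems.append(f"y has shape {y.shape}, expected (N, {forecast_length}, {input_dim})")
    if len(x) != len(y):
        problems.append(f"x and y disagree on num_samples: {len(x)} vs {len(y)}")
    return problems


# Usage: 2-channel data checked against a model configured with input_dim=1.
x = np.zeros((8, 10, 2))
y = np.zeros((8, 1, 2))
print(check_nbeats_shapes(x, y, backcast_length=10, forecast_length=1, input_dim=2))
print(check_nbeats_shapes(x, y, backcast_length=10, forecast_length=1, input_dim=1))
```

The first call returns an empty list. The second reports two mismatches, which is the situation where input_dim was left at its default while the data has two channels.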

Try running this:

import numpy as np

from nbeats_keras.model import NBeatsNet as NBeatsKeras

def main():
    # Inputs follow the Keras recurrent-layer convention: (num_samples, time_steps, input_dim).
    # https://keras.io/layers/recurrent/
    num_samples, time_steps, input_dim, output_dim = 50_000, 10, 2, 1  # <--------------- I set input_dim = 2

    # Definition of the model.
    model_keras = NBeatsKeras(input_dim=input_dim,   # <--------------- I also pass input_dim = 2 here
                              backcast_length=time_steps, forecast_length=output_dim,
                              stack_types=(NBeatsKeras.GENERIC_BLOCK, NBeatsKeras.GENERIC_BLOCK),
                              nb_blocks_per_stack=2, thetas_dim=(4, 4), share_weights_in_stack=True,
                              hidden_layer_units=64)

    model_keras.compile(loss='mae', optimizer='adam')
    x = np.random.uniform(size=(num_samples, time_steps, input_dim))
    # Target: mean over the time axis for each channel, shape (num_samples, 1, input_dim).
    y = np.mean(x, axis=1, keepdims=True)

    # Split data into training and testing datasets.
    c = num_samples // 10
    x_train, y_train, x_test, y_test = x[c:], y[c:], x[:c], y[:c]

    # Train the model.
    print('Keras training...')
    print(x_train.shape, y_train.shape)
    model_keras.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=20, batch_size=128)
    predictions_keras_forecast = model_keras.predict(x_test)
    print(predictions_keras_forecast.shape)

if __name__ == '__main__':
    main()
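For real data, as opposed to the random arrays in the script above, a multivariate series of shape (T, input_dim) has to be cut into windows with those same layouts. Below is a minimal sketch, assuming one window per time step (stride 1). `make_windows` is a hypothetical helper and is not part of n-beats.

```python
import numpy as np


def make_windows(series, backcast_length, forecast_length):
    """Hypothetical helper (not part of n-beats).

    Slices a series of shape (T, input_dim) into overlapping pairs:
      x: (num_samples, backcast_length, input_dim)
      y: (num_samples, forecast_length, input_dim)
    """
    series = np.asarray(series, dtype=float)
    if series.ndim == 1:
        series = series[:, None]  # treat a univariate series as input_dim = 1
    total = backcast_length + forecast_length
    num_samples = series.shape[0] - total + 1
    if num_samples <= 0:
        raise ValueError("series is shorter than backcast_length + forecast_length")
    # Window i uses steps [i, i + backcast_length) as input
    # and the following forecast_length steps as target.
    x = np.stack([series[i:i + backcast_length] for i in range(num_samples)])
    y = np.stack([series[i + backcast_length:i + total] for i in range(num_samples)])
    return x, y


# Usage: 12 time steps of a 2-channel series, backcast_length=10, forecast_length=1.
series = np.arange(24, dtype=float).reshape(12, 2)
x, y = make_windows(series, backcast_length=10, forecast_length=1)
print(x.shape, y.shape)  # (2, 10, 2) (2, 1, 2)
```

A series of length T yields T - backcast_length - forecast_length + 1 windows. The model would then be constructed with input_dim equal to series.shape[1].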

Logs

2022-01-14 14:05:08.235545: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Keras training...
(45000, 10, 2) (45000, 1, 2)
Epoch 1/20
352/352 [==============================] - 3s 4ms/step - loss: 0.0626 - val_loss: 0.0541
Epoch 2/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0540 - val_loss: 0.0534
Epoch 3/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0537 - val_loss: 0.0528
Epoch 4/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0534 - val_loss: 0.0541
Epoch 5/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0533 - val_loss: 0.0537
Epoch 6/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0530 - val_loss: 0.0528
Epoch 7/20
352/352 [==============================] - 2s 5ms/step - loss: 0.0530 - val_loss: 0.0520
Epoch 8/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0531 - val_loss: 0.0524
Epoch 9/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0530 - val_loss: 0.0522
Epoch 10/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0529 - val_loss: 0.0526
Epoch 11/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0530 - val_loss: 0.0520
Epoch 12/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0529 - val_loss: 0.0553
Epoch 13/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0528 - val_loss: 0.0522
Epoch 14/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0529 - val_loss: 0.0523
Epoch 15/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0528 - val_loss: 0.0539
Epoch 16/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0528 - val_loss: 0.0519
Epoch 17/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0526 - val_loss: 0.0538
Epoch 18/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0526 - val_loss: 0.0522
Epoch 19/20
352/352 [==============================] - 1s 4ms/step - loss: 0.0526 - val_loss: 0.0522
Epoch 20/20
352/352 [==============================] - 1s 3ms/step - loss: 0.0527 - val_loss: 0.0531
(5000, 1, 2)
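The numbers in this log follow directly from the script's parameters, which is a quick way to confirm that input_dim = 2 actually reached the model. The sketch below redoes that arithmetic. Reading the prediction shape as (num_test, forecast_length, input_dim) is inferred from this log, not taken from the library's documentation.

```python
import math

num_samples, time_steps, input_dim, output_dim = 50_000, 10, 2, 1
batch_size = 128

c = num_samples // 10                                # test-set size from the script's split
num_train = num_samples - c                          # x[c:] keeps the rest for training
steps_per_epoch = math.ceil(num_train / batch_size)  # Keras counts the final partial batch

print(c, num_train, steps_per_epoch)       # 5000 45000 352
print((num_train, time_steps, input_dim))  # (45000, 10, 2): x_train
print((num_train, 1, input_dim))           # (45000, 1, 2): y_train, middle axis from keepdims=True
print((c, output_dim, input_dim))          # (5000, 1, 2): predictions
```

Note that y's middle axis is 1 because of keepdims=True, and it only matches forecast_length because output_dim = 1 here. With a longer forecast, y would need forecast_length time steps.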

I will close the issue. Let me know whether it works for you.