Nixtla / neuralforecast

Scalable and user-friendly neural 🧠 forecasting algorithms.
https://nixtlaverse.nixtla.io/neuralforecast
Apache License 2.0
3.12k stars, 359 forks

[FIX] Fix Tweedie loss #1164

Closed by elephaint 1 month ago

elephaint commented 1 month ago

Fixes #1159.

Test code that raises an error without this fix but gives reasonable predictions with it:

```python
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.losses.pytorch import DistributionLoss
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic
import pandas as pd
import matplotlib.pyplot as plt

# Hold out the last 12 months of each series for testing
Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test

nf = NeuralForecast(
    models=[
        NHITS(
            h=12,
            input_size=48,
            max_steps=200,
            scaler_type="minmax",
            # Tweedie likelihood with power parameter rho=1.5 (the loss fixed in this PR)
            loss=DistributionLoss('Tweedie', rho=1.5, validate_args=False),
            enable_model_summary=False,
            enable_checkpointing=False,
            logger=False
        )
    ],
    freq="M")  # monthly data; pandas >= 2.2 prefers the alias "ME"

nf.fit(df=Y_train_df, static_df=AirPassengersStatic)
forecasts = nf.predict(futr_df=Y_test_df)

# Plot quantile predictions: drop the id/date columns so the forecasts line up with Y_test_df
Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])
plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)
plot_df = pd.concat([Y_train_df, plot_df])

plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)
plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')
plt.plot(plot_df['ds'], plot_df['NHITS-median'], c='blue', label='median')
# Shade the 90% prediction interval over the 12-step forecast horizon (positional slicing)
plt.fill_between(x=plot_df['ds'].iloc[-12:],
                 y1=plot_df['NHITS-lo-90'].iloc[-12:].values,
                 y2=plot_df['NHITS-hi-90'].iloc[-12:].values,
                 alpha=0.4, label='level 90')
plt.legend()
plt.grid()
plt.show()
```
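For readers checking the math behind `DistributionLoss('Tweedie', rho=1.5)`: a Tweedie distribution with power parameter 1 < ρ < 2 is a compound Poisson-gamma distribution, with a point mass at zero, positive continuous values otherwise, and variance proportional to μ^ρ. The sketch below writes out the usual Tweedie loss in NumPy, assuming a log link (the network outputs log μ) and a dispersion of 1. It is a standalone illustration, not neuralforecast's implementation; `tweedie_nll`, `tweedie_deviance` and `_check_rho` are hypothetical names.

```python
import numpy as np


def _check_rho(rho):
    # This sketch only covers the compound Poisson-gamma range of the power parameter
    if not 1.0 < rho < 2.0:
        raise ValueError("this sketch only covers 1 < rho < 2")


def tweedie_nll(y, log_mu, rho=1.5):
    """Tweedie negative log-likelihood, dropping terms that do not depend on mu.

    Assumes a log link (the model outputs log(mu)) and a dispersion of 1.
    """
    _check_rho(rho)
    y = np.asarray(y, dtype=float)
    log_mu = np.asarray(log_mu, dtype=float)
    # mu ** (1 - rho) and mu ** (2 - rho), computed from log(mu) so mu stays positive
    mu_pow_1 = np.exp((1.0 - rho) * log_mu)
    mu_pow_2 = np.exp((2.0 - rho) * log_mu)
    return -y * mu_pow_1 / (1.0 - rho) + mu_pow_2 / (2.0 - rho)


def tweedie_deviance(y, mu, rho=1.5):
    """Unit deviance 2 * (nll(y, mu) - nll(y, y)): zero when mu == y, positive otherwise."""
    _check_rho(rho)
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    return 2.0 * (
        np.power(y, 2.0 - rho) / ((1.0 - rho) * (2.0 - rho))
        - y * np.power(mu, 1.0 - rho) / (1.0 - rho)
        + np.power(mu, 2.0 - rho) / (2.0 - rho)
    )


# Example: observations including an exact zero, scored against a constant forecast of 1
y = np.array([0.0, 1.0, 4.0])
mu = np.ones_like(y)
print(tweedie_deviance(y, mu))            # only the middle entry (mu == y) is zero
print(tweedie_nll(y, np.log(mu)).mean())  # averaged loss, as a training loop might minimize
```

The deviance is twice the gap between the loss at μ and at the best possible prediction μ = y, so it is zero for a perfect forecast; it stays finite at y = 0 thanks to the closed form, whereas evaluating `tweedie_nll(0, log(0))` directly would not.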

Antoine-Schwartz commented 1 month ago

It definitely fixed my problem, and the results seem consistent. However, I'm not in a position to compare it with other implementations to say whether it's better or worse than expected ;)
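One sanity check that does not need a second implementation: with a log link, the constant prediction that minimizes the average Tweedie loss over a sample is the sample mean, since the derivative with respect to log μ is e^{(1-ρ)·log μ}(μ - ȳ), which vanishes only at μ = ȳ. The sketch below verifies this numerically on made-up data with a plain grid search. `tweedie_nll` and `best_constant_mu` are hypothetical helpers written for illustration, not neuralforecast's API; any candidate loss with the same `(y, log_mu)` signature could be passed in instead.

```python
from functools import partial

import numpy as np


def tweedie_nll(y, log_mu, rho=1.5):
    # Hypothetical helper: Tweedie NLL for 1 < rho < 2 with a log link,
    # dropping terms that do not depend on mu (dispersion fixed to 1)
    return (-y * np.exp((1.0 - rho) * log_mu) / (1.0 - rho)
            + np.exp((2.0 - rho) * log_mu) / (2.0 - rho))


def best_constant_mu(y, loss, log_mu_grid):
    """Return the constant mu whose log lies on the grid and minimizes the mean loss."""
    # Broadcast: rows are observations, columns are candidate log(mu) values
    mean_loss = loss(y[:, None], log_mu_grid[None, :]).mean(axis=0)
    return float(np.exp(log_mu_grid[np.argmin(mean_loss)]))


# Made-up sample with exact zeros, the kind of data a Tweedie model targets
y = np.array([0.0, 0.0, 0.0, 1.2, 3.5, 0.4, 8.0, 2.1])
grid = np.linspace(-5.0, 5.0, 100_001)  # log(mu) from about 0.0067 to about 148

for rho in (1.2, 1.5, 1.8):
    mu_hat = best_constant_mu(y, partial(tweedie_nll, rho=rho), grid)
    print(f"rho={rho}: argmin mu={mu_hat:.4f}, sample mean={y.mean():.4f}")
```

The check only concerns the mean the loss pulls predictions toward; it says nothing about how well the predicted quantiles are calibrated.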