jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License
3.74k stars, 599 forks

[Code review request] N-Beats implemented with SELU() instead of RELU() #444

Open pnmartinez opened 3 years ago

pnmartinez commented 3 years ago

Hi all,

If everything checks out, we can open a pull request to include this as a variant of N-BEATS.

I am interested in the SELU activation function for Self-Normalizing Networks (SNNs; see, e.g., the PyTorch docs). I couldn't find any N-BEATS implementation adapted to SELU and its requirements (i.e. AlphaDropout and proper weight initialization), so I wrote one myself by patching the sub_modules.py file in pytorch-forecasting.
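As a quick sanity check of the self-normalizing property that motivates these changes, the sketch below stacks Linear + SELU layers with AlphaDropout placed before each linear layer (the same order the patched `linear()` helper uses) and LeCun-normal weights, then feeds in standardized noise. The width, depth, dropout rate and the `make_snn` helper are illustrative assumptions, not part of pytorch-forecasting. If the setup is right, the output should stay roughly at zero mean and unit standard deviation even after many layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def make_snn(width: int = 256, depth: int = 16, p: float = 0.1) -> nn.Sequential:
    """Hypothetical deep SELU MLP using SELU, AlphaDropout and LeCun-normal init."""
    layers = []
    for _ in range(depth):
        lin = nn.Linear(width, width)
        # LeCun normal: weights ~ N(0, 1/fan_in); bias set to zero
        nn.init.kaiming_normal_(lin.weight, nonlinearity="linear")
        nn.init.zeros_(lin.bias)
        # AlphaDropout before the linear layer, as in the patched linear() helper
        layers += [nn.AlphaDropout(p), lin, nn.SELU()]
    return nn.Sequential(*layers)


net = make_snn()
net.train()  # keep AlphaDropout active to check it does not break normalization
x = torch.randn(4096, 256)  # standardized input: mean 0, std 1
with torch.no_grad():
    y = net(x)
print(f"after 16 layers: mean={y.mean().item():.3f}, std={y.std().item():.3f}")
```

Swapping `nn.AlphaDropout` for `nn.Dropout` in this sketch is an easy way to see why the plain dropout had to be replaced.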

It would be great if any of you with experience in these topics (the N-BEATS architecture, pytorch-forecasting, or SELU) could review whether my implementation is correct.

The implementation is below and is also available as a Gist (small, commented modifications of /models/nbeats/sub_modules.py in the library): https://gist.github.com/pnmartinez/fef1f488497fa85a2cc1626af2a5b4bd

"""
Implementation of ``nn.Modules`` for N-Beats model,

* modified to use the SELU activation. That implies 3 changes:
- replacing Dropout with AlphaDropout,
- initializing weights with `lecun_normal`,
- and replacing ReLU with SELU.

Sources: 
- About SELU: https://mlfromscratch.com/activation-functions-explained/#selu
- About SELU in PyTorch: https://pytorch.org/docs/master/generated/torch.nn.SELU.html#torch.nn.SELU
"""
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def linear(input_size, output_size, bias=True, dropout: float = None):
    lin = nn.Linear(input_size, output_size, bias=bias)
    if dropout is not None:
        ############################BELOW IS NEW#################################################################
        # "Alpha Dropout is a type of Dropout that maintains the self-normalizing property. Alpha Dropout goes
        # hand-in-hand with SELU activation function, which ensures that the outputs have zero mean and unit standard deviation."
        # Source: https://pytorch.org/docs/stable/generated/torch.nn.AlphaDropout.html#alphadropout
        return nn.Sequential(nn.AlphaDropout(dropout), lin)
        # return nn.Sequential(nn.Dropout(dropout), lin)
        #############################ABOVE IS NEW##################################################################
    else:
        return lin

def linspace(backcast_length: int, forecast_length: int, centered: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    if centered:
        norm = max(backcast_length, forecast_length)
        start = -backcast_length
        stop = forecast_length - 1
    else:
        norm = backcast_length + forecast_length
        start = 0
        stop = backcast_length + forecast_length - 1
    lin_space = np.linspace(start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32)
    b_ls = lin_space[:backcast_length]
    f_ls = lin_space[backcast_length:]
    return b_ls, f_ls

class NBEATSBlock(nn.Module):
    def __init__(
        self,
        units,
        thetas_dim,
        num_block_layers=4,
        backcast_length=10,
        forecast_length=5,
        share_thetas=False,
        dropout=0.1,
    ):
        super().__init__()
        ##################BELOW IS NEW##########################
        # We add this tiny attribute to our SELU-ready,
        # properly-initialized block (set after super().__init__(),
        # which is the conventional place for nn.Module attributes)
        self.nbeats_ready_for_selu = True
        ###################ABOVE IS NEW#######################
        self.units = units
        self.thetas_dim = thetas_dim
        self.backcast_length = backcast_length
        self.forecast_length = forecast_length
        self.share_thetas = share_thetas

        fc_stack = [
            nn.Linear(backcast_length, units),
            nn.SELU(),  # was: nn.ReLU()
        ]
        for _ in range(num_block_layers - 1):
            fc_stack.extend([linear(units, units, dropout=dropout), nn.SELU()])  # was: nn.ReLU()
        self.fc = nn.Sequential(*fc_stack)

        if share_thetas:
            self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)
        else:
            self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)
            self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False)

    ############################BELOW IS NEW############################################
        self.init_weights()

    def init_weights(self):
        """
        Weight initialization to achieve Self-Normalizing Networks (SNNs),
        i. e. the main feature obtained by using SELU.

        'When using kaiming_normal or kaiming_normal_ for initialisation, 
        nonlinearity='linear' should be used instead of nonlinearity='selu' 
        in order to get Self-Normalizing Neural Networks.'

        Sources: 
        - https://pytorch.org/docs/master/generated/torch.nn.SELU.html#selu
        - https://pytorch.org/docs/master/nn.init.html#torch-nn-init
        - https://stackoverflow.com/a/49433937
        """
        def init_for_selu(m):
            if isinstance(m, nn.Linear):
                # LeCun normal: kaiming_normal_ with gain 1 (nonlinearity="linear")
                # and the default mode="fan_in" draws weights from N(0, 1/fan_in)
                nn.init.kaiming_normal_(m.weight, nonlinearity="linear")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        self.apply(init_for_selu)
    ##############################ABOVE IS NEW#####################################

    def forward(self, x):
        return self.fc(x)

class NBEATSSeasonalBlock(NBEATSBlock):
    def __init__(
        self,
        units,
        thetas_dim=None,
        num_block_layers=4,
        backcast_length=10,
        forecast_length=5,
        nb_harmonics=None,
        min_period=1,
        dropout=0.1,
    ):
        if nb_harmonics:
            thetas_dim = nb_harmonics
        else:
            thetas_dim = forecast_length
        self.min_period = min_period

        super().__init__(
            units=units,
            thetas_dim=thetas_dim,
            num_block_layers=num_block_layers,
            backcast_length=backcast_length,
            forecast_length=forecast_length,
            share_thetas=True,
            dropout=dropout,
        )

        backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=False)

        p1, p2 = (thetas_dim // 2, thetas_dim // 2) if thetas_dim % 2 == 0 else (thetas_dim // 2, thetas_dim // 2 + 1)
        s1_b = torch.tensor(
            [np.cos(2 * np.pi * i * backcast_linspace) for i in self.get_frequencies(p1)], dtype=torch.float32
        )  # H/2-1
        s2_b = torch.tensor(
            [np.sin(2 * np.pi * i * backcast_linspace) for i in self.get_frequencies(p2)], dtype=torch.float32
        )
        self.register_buffer("S_backcast", torch.cat([s1_b, s2_b]))

        s1_f = torch.tensor(
            [np.cos(2 * np.pi * i * forecast_linspace) for i in self.get_frequencies(p1)], dtype=torch.float32
        )  # H/2-1
        s2_f = torch.tensor(
            [np.sin(2 * np.pi * i * forecast_linspace) for i in self.get_frequencies(p2)], dtype=torch.float32
        )
        self.register_buffer("S_forecast", torch.cat([s1_f, s2_f]))

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        x = super().forward(x)
        amplitudes_backward = self.theta_b_fc(x)
        backcast = amplitudes_backward.mm(self.S_backcast)
        amplitudes_forward = self.theta_f_fc(x)
        forecast = amplitudes_forward.mm(self.S_forecast)

        return backcast, forecast

    def get_frequencies(self, n):
        return np.linspace(0, (self.backcast_length + self.forecast_length) / self.min_period, n)

class NBEATSTrendBlock(NBEATSBlock):
    def __init__(
        self,
        units,
        thetas_dim,
        num_block_layers=4,
        backcast_length=10,
        forecast_length=5,
        dropout=0.1,
    ):
        super().__init__(
            units=units,
            thetas_dim=thetas_dim,
            num_block_layers=num_block_layers,
            backcast_length=backcast_length,
            forecast_length=forecast_length,
            share_thetas=True,
            dropout=dropout,
        )

        backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=True)
        norm = np.sqrt(forecast_length / thetas_dim)  # ensure range of predictions is comparable to input

        coefficients = torch.tensor([backcast_linspace ** i for i in range(thetas_dim)], dtype=torch.float32)
        self.register_buffer("T_backcast", coefficients * norm)

        coefficients = torch.tensor([forecast_linspace ** i for i in range(thetas_dim)], dtype=torch.float32)
        self.register_buffer("T_forecast", coefficients * norm)

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        x = super().forward(x)
        backcast = self.theta_b_fc(x).mm(self.T_backcast)
        forecast = self.theta_f_fc(x).mm(self.T_forecast)
        return backcast, forecast

class NBEATSGenericBlock(NBEATSBlock):
    def __init__(
        self,
        units,
        thetas_dim,
        num_block_layers=4,
        backcast_length=10,
        forecast_length=5,
        dropout=0.1,
    ):
        super().__init__(
            units=units,
            thetas_dim=thetas_dim,
            num_block_layers=num_block_layers,
            backcast_length=backcast_length,
            forecast_length=forecast_length,
            dropout=dropout,
        )

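        self.backcast_fc = nn.Linear(thetas_dim, backcast_length)
        self.forecast_fc = nn.Linear(thetas_dim, forecast_length)
        # The parent constructor ran init_weights() before these two layers existed,
        # so re-run it to give them the same LeCun-normal initialization
        self.init_weights()

    def forward(self, x):
        x = super().forward(x)

        # Note: the theta activations are still ReLU here, unchanged from the upstream module
        theta_b = F.relu(self.theta_b_fc(x))
        theta_f = F.relu(self.theta_f_fc(x))

        return self.backcast_fc(theta_b), self.forecast_fc(theta_f)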
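Two points worth checking in review. First, `kaiming_normal_(..., nonlinearity="linear")` should reproduce LeCun-normal initialization, i.e. a standard deviation of 1/sqrt(fan_in). Second, `NBEATSBlock.__init__` initializes through `self.apply(...)`, which only visits submodules that exist at that moment, while `NBEATSGenericBlock` creates `backcast_fc` and `forecast_fc` after `super().__init__()` returns, so those two layers keep PyTorch's default initialization. The toy `Parent`/`Child` classes below are hypothetical and only reproduce that ordering.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# 1) kaiming_normal_ with nonlinearity="linear" (gain 1, default mode="fan_in")
#    should match LeCun normal, i.e. std = 1 / sqrt(fan_in)
fan_in = 512
w = torch.empty(256, fan_in)  # Linear weights are (out_features, in_features)
nn.init.kaiming_normal_(w, nonlinearity="linear")
expected_std = 1.0 / math.sqrt(fan_in)
print(f"empirical std={w.std().item():.5f}, LeCun std={expected_std:.5f}")


# 2) self.apply() only reaches submodules that already exist when it is called
class Parent(nn.Module):  # hypothetical stand-in for NBEATSBlock
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(8, 8)
        self.apply(self._zero_linear)  # plays the role of init_weights()

    @staticmethod
    def _zero_linear(m):
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)


class Child(Parent):  # hypothetical stand-in for NBEATSGenericBlock
    def __init__(self):
        super().__init__()
        self.b = nn.Linear(8, 8)  # created after apply(): keeps default init


c = Child()
print("a zeroed:", bool((c.a.weight == 0).all()), "| b zeroed:", bool((c.b.weight == 0).all()))
```

Zeroing is used here only because it makes the difference easy to see; the same ordering applies to the LeCun-normal init.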
pnmartinez commented 3 years ago

@jdb78 I have updated this so that the activation function can be chosen as a parameter, e.g.:

net = NBeats.from_dataset(
    training,
    activation_fn="lrelu",  # ReLU by default
)

This makes it easier to loop over different activation functions. In my case, Leaky ReLU achieves the same performance much faster (see image).

[Image]

I have this in a fork of my own (the __init__.py must also be modified). Let me know and I'll open a pull request.

New Gist.
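The fork with `activation_fn` is not shown in this thread, so the following is only a sketch, assuming a plain string-to-module lookup. The `ACTIVATIONS` dictionary (its keys other than `"relu"` and `"lrelu"`), `get_activation` and `make_fc_stack` are hypothetical names. The stack mirrors the shape of `NBEATSBlock.fc` without dropout and loops over the options; note that choosing `"selu"` here would not add the AlphaDropout and LeCun-normal initialization the SELU variant relies on.

```python
import torch
import torch.nn as nn

# Hypothetical name -> activation class mapping (keys are assumptions,
# except "lrelu", which appears in the example call in this thread)
ACTIVATIONS = {
    "relu": nn.ReLU,
    "lrelu": nn.LeakyReLU,
    "selu": nn.SELU,
    "elu": nn.ELU,
    "gelu": nn.GELU,
}


def get_activation(name: str = "relu") -> nn.Module:
    """Instantiate an activation module by name, failing loudly on typos."""
    if name not in ACTIVATIONS:
        raise ValueError(f"unknown activation {name!r}; choose from {sorted(ACTIVATIONS)}")
    return ACTIVATIONS[name]()


def make_fc_stack(in_features: int, units: int, num_layers: int, activation: str = "relu") -> nn.Sequential:
    """Linear + activation repeated num_layers times, like NBEATSBlock.fc (dropout omitted)."""
    layers = [nn.Linear(in_features, units), get_activation(activation)]
    for _ in range(num_layers - 1):
        layers += [nn.Linear(units, units), get_activation(activation)]
    return nn.Sequential(*layers)


# Loop over the candidate activations with a dummy batch of 32 windows of length 10
x = torch.randn(32, 10)
shapes = {}
for name in ACTIVATIONS:
    fc = make_fc_stack(in_features=10, units=64, num_layers=4, activation=name)
    shapes[name] = tuple(fc(x).shape)
    print(name, shapes[name])  # each stack maps (32, 10) -> (32, 64)
```

Raising on unknown names (rather than silently falling back to ReLU) keeps a mistyped option from quietly skewing an activation comparison.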

pnmartinez commented 3 years ago

Hi,

I am sharing some results on the performance of other activations.

In the problem I am working on, several ReLU variants seem to achieve the same performance, but Leaky ReLU is faster.

@jdb78 I would consider making it the default if this result proves to hold for other problems as well.

[Image: leaky_relu_fastest]

jdb78 commented 3 years ago

Very interesting analysis. Do you want to open a PR?

jdb78 commented 3 years ago

Maybe also relevant for #471

pnmartinez commented 3 years ago

Hi Jan,

Throughout the week I will try to open it.

Cheers!


From: Jan Beitner @.> Sent: Thursday, April 29, 2021 13:15 To: jdb78/pytorch-forecasting @.> Cc: Pablo @.>; Author @.> Subject: Re: [jdb78/pytorch-forecasting] [Code review request] N-Beats implemented with SELU() instead of RELU() (#444)

Very interesting analysis. Do you want to open a PR?
