muditbhargava66 / PyxLSTM

Efficient Python library for Extended LSTM with exponential gating, memory mixing, and matrix memory for superior sequence modeling.
https://pyxlstm.readthedocs.io/
MIT License

I am getting NaN values in the mLSTM output #27

Closed · Anb001 closed this issue 2 weeks ago

matiashaggman commented 4 months ago

I am also having this issue, with both the sLSTM and the mLSTM. After some quick debugging, I found the 'gates' tensor to be filled with NaNs in the forward pass of the mLSTM cell at: gates = F.linear(input, self.weight_ih, self.bias) + F.linear(h, self.weight_hh)

UYousafzai commented 4 months ago

There are several discrepancies between this repository and the research paper, but the largest one I can spot is:

Normalizer state:

The implementation is missing the normalizer state (n_t) described in the paper. The paper includes the update n_t = f_t * n_{t-1} + i_t.

I don't see this anywhere in the implementation. (I could be wrong, and maybe it's snugly hidden somewhere in the utils.)
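
In code terms, the missing update would look roughly like this (a minimal sketch with illustrative names, operating on the per-gate tensors; not the repo's actual code):

def slstm_state_update(i, f, g, o, c_prev, n_prev):
    """Sketch of the cell update with the normalizer state from the paper."""
    n = f * n_prev + i       # n_t = f_t * n_{t-1} + i_t
    c = f * c_prev + i * g   # c_t = f_t * c_{t-1} + i_t * z_t
    h = o * (c / n)          # h_t = o_t * (c_t / n_t); n keeps the exp-gated values bounded
    return h, c, n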

On first look I can spot a few other potential issues as well, but since I haven't had time to review the code properly, I wasn't going to comment further until I had rewritten/verified the implementation. However, seeing that some people might be waiting on this, I will try to make time to verify the implementation against the paper.

UYousafzai commented 4 months ago

By introducing the normalizer state, alongside a few changes in line with the paper, I get no NaNs up to a sequence length of 512 (that's the only length I tested).

Anb001 commented 4 months ago

Can you please share the changes you made in the code? Also, one more question: can we combine both sLSTM and mLSTM blocks in the model?

UYousafzai commented 4 months ago

We should be able to construct a combined architecture using both sLSTM and mLSTM blocks; in fact, that is exactly how the original research paper suggests xLSTM be used. A rough sketch of how the blocks could be stacked is below.
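
Something along these lines (a hypothetical sketch, not the library's API; it assumes an mLSTM class with the same (input_seq) -> (output_seq, hidden_state) interface as the sLSTM code posted further down):

import torch.nn as nn

class xLSTMStack(nn.Module):
    """Toy wrapper that runs a list of sLSTM/mLSTM blocks back to back."""
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x, _ = block(x)  # each block returns (output_seq, final_hidden_state)
        return x

# e.g. model = xLSTMStack([sLSTM(512, 256, num_layers=1), mLSTM(256, 256, num_layers=1)])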

additionally I have constructed really shitty code for experimentation pipelines: https://github.com/UYousafzai/llm_xlstm

Note: you can test the network on your own text lines using the existing experimentation pipeline, which should give you a sense of whether (and how) the loss decreases on your data. However, this is far from complete code, so you will need to write your own pipeline on top of the current code layout.

The changed code for the sLSTM is as follows:

Update the code and it should run with larger sequences. Note that this doesn't implement any of the stabilization techniques suggested in the research paper; I should be adding those in the near future, after I verify the mLSTM implementation (don't count on it being too quick, because I barely get free time to contribute to open source).

"""
sLSTM: Scalar Long Short-Term Memory

This module implements the sLSTM (scalar LSTM) cell and layer as described in the paper:
"xLSTM: Extended Long Short-Term Memory" by Beck et al. (2024).

The sLSTM extends the traditional LSTM by using exponential gating and a new memory mixing technique,
allowing for improved performance on various sequence modeling tasks.

Author: Mudit Bhargava
Date: June 2024
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class sLSTM(nn.Module):
    """
    sLSTM layer implementation.

    This layer applies multiple sLSTM cells in sequence, with optional dropout between layers.

    Args:
        input_size (int): Size of input features.
        hidden_size (int): Size of hidden state.
        num_layers (int): Number of sLSTM layers.
        dropout (float, optional): Dropout probability between layers. Default: 0.0.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.0):
        super(sLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

        self.layers = nn.ModuleList([sLSTMCell(input_size if i == 0 else hidden_size, hidden_size) 
                                     for i in range(num_layers)])
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, input_seq, hidden_state=None):
        """
        Forward pass of the sLSTM layer.

        Args:
            input_seq (Tensor): Input sequence of shape (batch_size, seq_length, input_size).
            hidden_state (list of tuples, optional): Initial (h, c, n) state for each layer. Default: None.

        Returns:
            tuple: Output sequence and final hidden state.
        """
        batch_size, seq_length, _ = input_seq.size()

        if hidden_state is None:
            hidden_state = self.init_hidden(batch_size)

        outputs = []
        for t in range(seq_length):
            x = input_seq[:, t, :]
            for layer_idx, layer in enumerate(self.layers):
                h, c, n = hidden_state[layer_idx]
                h, c, n = layer(x, (h, c, n))
                hidden_state[layer_idx] = (h, c, n)
                x = self.dropout_layer(h) if layer_idx < self.num_layers - 1 else h
            outputs.append(x)

        return torch.stack(outputs, dim=1), hidden_state

    def init_hidden(self, batch_size):
        """Initialize the (h, c, n) state for all layers; n starts at ones to avoid division by zero."""
        device = self.layers[0].weight_ih.device
        return [(torch.zeros(batch_size, self.hidden_size, device=device),
                 torch.zeros(batch_size, self.hidden_size, device=device),
                 torch.ones(batch_size, self.hidden_size, device=device))
                for _ in range(self.num_layers)]

class sLSTMCell(nn.Module):
    """
    sLSTM cell implementation.

    This cell uses exponential gating as described in the xLSTM paper.

    Args:
        input_size (int): Size of input features.
        hidden_size (int): Size of hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(sLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias = nn.Parameter(torch.randn(4 * hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize parameters using Xavier uniform initialization."""
        nn.init.xavier_uniform_(self.weight_ih)
        nn.init.xavier_uniform_(self.weight_hh)
        nn.init.zeros_(self.bias)

    def forward(self, input, hx):
        """
        Forward pass of the sLSTM cell.

        Args:
            input (Tensor): Input tensor of shape (batch_size, input_size).
            hx (tuple of Tensors): Previous hidden state, cell state, and normalizer state (h, c, n).

        Returns:
            tuple: New hidden state, cell state, and normalizer state (h, c, n).
        """
        h, c, n = hx
        gates = F.linear(input, self.weight_ih, self.bias) + F.linear(h, self.weight_hh)

        i, f, g, o = gates.chunk(4, 1)

        i = torch.exp(i)  # Exponential input gate
        f = torch.sigmoid(f)  # Sigmoid forget gate; the paper allows sigmoid or exponential here, and exponential for both gates causes overflow

        g = torch.tanh(g)
        o = torch.sigmoid(o)

        n = f * n + i  # Normalizer state update: n_t = f_t * n_{t-1} + i_t

        c = f * c + i * g
        h = o * (c / n)  # Normalize the cell state by n before applying the output gate

        return h, c, n
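
For quick sanity testing, here is a minimal usage sketch of the layer above (shapes are just an example):

import torch

x = torch.randn(8, 64, 32)  # (batch_size, seq_length, input_size)
layer = sLSTM(input_size=32, hidden_size=64, num_layers=2, dropout=0.1)
out, state = layer(x)
print(out.shape)               # torch.Size([8, 64, 64])
print(torch.isnan(out).any())  # quick NaN check
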
Anb001 commented 4 months ago

Thank you for the code. Can you now please check how we can use both sLSTM and mLSTM simultaneously in the model?

UYousafzai commented 3 months ago

You can create separate mLSTM and sLSTM blocks, combine them into a single architecture file, and then pass the data around as you see fit. But as I think I have mentioned before, even the mLSTM has a few differences from the paper:

For example, off the top of my head and just looking at the code directly:

Paper: h_t = o_t ⊙ h̃_t, where h̃_t = C_t q_t / max{|n_t^T q_t|, 1}
Code: h_t = o_t ⊙ h̃_t, where h̃_t = C_t q_t / max(n_t^T q_t, 1)

The code takes the max of the dot product directly, while the paper takes the max of its absolute value. A minor difference.
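
In tensor terms, the difference is just the abs() in the denominator; a minimal sketch assuming a matrix memory C of shape (batch, d, d) and vectors q, n, o of shape (batch, d) (illustrative names, not the repo's exact variables):

import torch

def mlstm_readout(C, q, n, o):
    dot = (n * q).sum(dim=-1, keepdim=True)        # n_t^T q_t, shape (batch, 1)
    denom_paper = torch.clamp(dot.abs(), min=1.0)  # paper: max{|n_t^T q_t|, 1}
    # denom_code = torch.clamp(dot, min=1.0)       # repo:  max(n_t^T q_t, 1)
    h_tilde = torch.bmm(C, q.unsqueeze(-1)).squeeze(-1) / denom_paper
    return o * h_tilde                             # h_t = o_t ⊙ h̃_t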

Now, I haven't looked into this in depth, but I will at some future point and correct the code if required. Otherwise, I feel the implementation is largely correct; if it works empirically, you can just rewrite some of the config and how it uses these basic blocks to build your own architecture that uses both mLSTM and sLSTM in combination.

Anb001 commented 3 months ago

But the issue is that using the mLSTM gives me NaN values at the output.

matiashaggman commented 3 months ago

[quotes UYousafzai's comment above in full, including the updated sLSTM code]

Well, this does seem to partly fix the NaN issue; however, I still get NaNs randomly after a number of training epochs, which is odd. The mLSTM version still needs a fix, though.
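
If it helps narrow this down, standard PyTorch tooling can localize where the NaNs first appear during training (generic debugging, not specific to this repo):

import torch

torch.autograd.set_detect_anomaly(True)  # raises on the backward op that produces NaN/Inf

def nan_forward_hook(module, inputs, output):
    out = output[0] if isinstance(output, tuple) else output
    if torch.is_tensor(out) and torch.isnan(out).any():
        raise RuntimeError(f"NaN in forward output of {module.__class__.__name__}")

# for m in model.modules():
#     m.register_forward_hook(nan_forward_hook)
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clipping often helps with exp gates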

UYousafzai commented 3 months ago

[quotes the preceding exchange in full, including the updated sLSTM code]

That was the sLSTM fix only; you can take the mLSTM fix from the following file. I am fairly sure I pushed the fix, but if it doesn't solve things for you, let me know and I can look into it.

https://github.com/UYousafzai/llm_xlstm/blob/main/src/xLSTM/mlstm.py

Marilyn0321 commented 3 months ago

Hi, the mLSTM output is still always NaN even with the code change.

UYousafzai commented 3 months ago

Hi, the mLSTM output is still always NaN even with the code change.

Did you try the fix in my repo?

If it's still there, I can look this over on the weekend and see what else could be causing it. Additionally, could you perhaps send a sample of your data along with the config files? I could try to reproduce the issue.

Marilyn0321 commented 3 months ago

Yes, I was trying the fix in your repo; the changes to the sLSTM in your repo worked! I got normal results, but with the mLSTM it is still the case that all results are NaN.

Marilyn0321 commented 3 months ago

Great, and thanks for your effort! I used this model for further processing of my image features. I first extracted an (8, 30, 512) tensor with a ResNet (8 is the batch size, 30 is the length of the image sequence, 512 is the feature for each image), and after that I got a classification result of shape (8, 1) from this code.
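
For context, a rough sketch of that pipeline (hypothetical names; the sLSTM class is the one posted earlier in this thread, and the classification head is my assumption):

import torch
import torch.nn as nn

feats = torch.randn(8, 30, 512)   # ResNet features: (batch, image sequence, per-image feature)
slstm = sLSTM(input_size=512, hidden_size=256, num_layers=2)
head = nn.Linear(256, 1)

seq_out, _ = slstm(feats)         # (8, 30, 256)
logits = head(seq_out[:, -1, :])  # use the last timestep -> (8, 1)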

Marilyn0321 commented 3 months ago

config.zip attached. Keep me posted if there's anything else you need from me. Thanks for your effort!

UYousafzai commented 2 months ago

So here is the catch:

Even though I am not getting any NaNs with the mLSTM on my small test datasets, or even on GermanWikiDataset, upon inspecting the research publication more carefully I realize that the paper allows either an exponential function or a sigmoid for the forget gate, whereas the input gate is exponential only. I suspect that may be the cause: depending on the dataset, it could demand either a sigmoid or an exponential forget gate. I have pushed new changes to the repo with a sigmoid-based forget gate; you can test with that, and if it fixes the issue we can dive further into why that is the case.
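
For what it's worth, the paper also describes a log-space stabilizer state m_t for exactly this overflow problem. A rough sketch of that idea as I read it (my paraphrase with illustrative names, worth double-checking against the paper; the extra m state would have to be carried through the recurrence alongside (h, c, n)):

import torch
import torch.nn.functional as F

def stabilized_gates(i_pre, f_pre, m_prev):
    """Stabilized exponential gating: i_pre/f_pre are raw pre-activations, m_prev starts at zeros."""
    log_f = F.logsigmoid(f_pre)                   # sigmoid forget gate, kept in log space
    m_new = torch.maximum(log_f + m_prev, i_pre)  # m_t = max(log f_t + m_{t-1}, log i_t)
    i = torch.exp(i_pre - m_new)                  # stabilized exponential input gate
    f = torch.exp(log_f + m_prev - m_new)         # stabilized forget gate
    return i, f, m_new

The cell and normalizer updates would then use these stabilized i and f unchanged, and h_t = o_t * (c_t / n_t) stays the same.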

github-actions[bot] commented 3 weeks ago

Stale issue message