NX-AI / xlstm

Official repository of the xLSTM.
GNU Affero General Public License v3.0
918 stars, 66 forks

Two code modifications are needed to run the sample program in version 1.0.4 #24

Open: miaozhixu opened this issue 2 weeks ago

miaozhixu commented 2 weeks ago

In source code xlstm/blocks/slstm/layer.py, line 134:

    if return_last_state:
        x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state)
    else:
        x_conv, conv_state = self.conv1d(
            x, conv_state, return_last_state=return_last_state
        )

These lines call the forward method of class CausalConv1d.

In source code xlstm/components/conv.py, line 126:

    if return_last_state:
        return y[:, :, : -self.pad].transpose(2, 1), x[:, -self.pad :]
    else:
        return y[:, :, : -self.pad].transpose(2, 1)

When "return_last_state" is False, the forward method in conv.py returns only one value and does not include the last state of x. But line 137 of layer.py expects the method to return two values.
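
To make the mismatch concrete, here is a minimal pure-Python sketch. It assumes nothing about the real xlstm API: toy_conv_forward is a hypothetical stand-in for CausalConv1d.forward, plain lists replace tensors, and the doubling is only a placeholder for the convolution. buggy_caller (also hypothetical) keeps the branch order quoted from layer.py.

```python
# Minimal sketch of the arity mismatch.
# toy_conv_forward is a hypothetical stand-in for CausalConv1d.forward;
# plain Python lists replace tensors and the "convolution" is a placeholder.


def toy_conv_forward(x, pad, return_last_state=False):
    y = [2 * v for v in x]  # placeholder for the convolution output
    if return_last_state:
        return y, x[-pad:]  # two values: output and the last `pad` inputs
    return y  # one value only


def buggy_caller(x, conv_state, pad, return_last_state):
    # Same branch order as the quoted layer.py lines 134-137.
    if return_last_state:
        # Receives a (y, state) tuple but binds it to a single name.
        x_conv = toy_conv_forward(x, pad, return_last_state=return_last_state)
    else:
        # Receives a single list but tries to unpack two values from it.
        x_conv, conv_state = toy_conv_forward(x, pad, return_last_state=return_last_state)
    return x_conv, conv_state


try:
    buggy_caller([1, 2, 3], None, pad=1, return_last_state=False)
except ValueError as err:
    print("False branch raised ValueError:", err)

x_conv, _ = buggy_caller([1, 2, 3], None, pad=1, return_last_state=True)
print("True branch returns a", type(x_conv).__name__)  # tuple, not the output
```

With real tensors the failure can be less obvious: unpacking a tensor into two names iterates over its first dimension, so the else branch would only raise when that dimension is not 2.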

Should I swap lines 135 and 137 of layer.py?

miaozhixu commented 2 weeks ago

According to the meaning of "return_last_state", I think the branches in xlstm/blocks/slstm/layer.py need to be swapped to match the return values of CausalConv1d.forward():

    if return_last_state:
        x_conv, conv_state = self.conv1d(x, conv_state, return_last_state=return_last_state)
    else:
        x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state)
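
The swapped branches can be sketched in a toy setting. toy_conv_forward and fixed_caller are hypothetical names, lists stand in for tensors, and the toy ignores the incoming conv_state that the real self.conv1d receives; the point is only that each branch now unpacks exactly as many values as the callee returns.

```python
# Sketch of the swapped branches, using a hypothetical toy_conv_forward in
# place of CausalConv1d.forward (lists replace tensors; the incoming
# conv_state is not used by the toy).


def toy_conv_forward(x, pad, return_last_state=False):
    y = [2 * v for v in x]  # placeholder for the convolution output
    if return_last_state:
        return y, x[-pad:]  # output plus the last `pad` inputs as new state
    return y


def fixed_caller(x, conv_state, pad, return_last_state):
    if return_last_state:
        # Two values come back: unpack both and overwrite the state.
        x_conv, conv_state = toy_conv_forward(x, pad, return_last_state=True)
    else:
        # One value comes back: keep the caller's conv_state unchanged.
        x_conv = toy_conv_forward(x, pad, return_last_state=False)
    return x_conv, conv_state


print(fixed_caller([1, 2, 3], None, pad=1, return_last_state=True))   # ([2, 4, 6], [3])
print(fixed_caller([1, 2, 3], None, pad=1, return_last_state=False))  # ([2, 4, 6], None)
```

One alternative design is for the convolution to always return the (output, state) pair and let callers discard the state, which removes the branching at every call site; whether that suits xlstm's API is an open question.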

miaozhixu commented 2 weeks ago

When I set return_last_state in the README sample:

    y = xlstm_stack(x, return_last_state=True)

and swap the lines in the sLSTM layer source, execution stops at line 76 of xlstm_block.py. So I modified that line as follows to check return_last_state:

    if kwargs.get("return_last_state", False):
        x = x + self.xlstm(self.xlstm_norm(x), **kwargs)[0]
    else:
        x = x + self.xlstm(self.xlstm_norm(x), **kwargs)

Now the tensor x is no longer added to a tuple.
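
The block-level failure comes from adding a tuple to a tensor in the residual connection. The sketch below reproduces it with floats in place of tensors; toy_layer, naive_block and patched_block are hypothetical names, toy_layer stands in for self.xlstm, and the normalization is omitted. It uses kwargs.get("return_last_state", False) so that calls made without the flag do not raise KeyError.

```python
# Sketch of the residual-add problem at xlstm_block.py line 76, using floats
# in place of tensors. toy_layer is a hypothetical stand-in for self.xlstm;
# self.xlstm_norm is omitted.


def toy_layer(x, return_last_state=False):
    y = 2.0 * x  # placeholder for the layer output
    if return_last_state:
        return y, {"last_x": x}  # hypothetical state object
    return y


def naive_block(x, **kwargs):
    # Adds whatever the layer returns, which is a tuple when state is requested.
    return x + toy_layer(x, **kwargs)


def patched_block(x, **kwargs):
    # .get() avoids a KeyError when the caller omits return_last_state.
    if kwargs.get("return_last_state", False):
        # Keep only the output; the state is discarded, as in the patch.
        return x + toy_layer(x, **kwargs)[0]
    return x + toy_layer(x, **kwargs)


print(patched_block(1.0))                          # 3.0
print(patched_block(1.0, return_last_state=True))  # 3.0
try:
    naive_block(1.0, return_last_state=True)
except TypeError as err:
    print("naive_block raised TypeError:", err)
```

As in the patch, the layer state is discarded at the block level here, so it is not passed back to the caller.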