NX-AI / xlstm

Official repository of the xLSTM.
https://www.nx-ai.com/
Apache License 2.0

Two code modifications are needed to run the sample program in version 1.0.4 #24

Closed: miaozhixu closed this issue 3 months ago

miaozhixu commented 5 months ago

In the source file xlstm/blocks/slstm/layer.py, Line 134:

```python
if return_last_state:
    x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state)
else:
    x_conv, conv_state = self.conv1d(
        x, conv_state, return_last_state=return_last_state
    )
```

These lines call the forward method of class CausalConv1d.

In the source file xlstm/components/conv.py, Line 126:

```python
if return_last_state:
    return y[:, :, : -self.pad].transpose(2, 1), x[:, -self.pad :]
else:
    return y[:, :, : -self.pad].transpose(2, 1)
```

When return_last_state is False, CausalConv1d.forward in conv.py returns only one value and does not include the last state of x. However, Line 137 of layer.py unpacks the result into two values. Conversely, when return_last_state is True, forward returns a tuple, but Line 135 assigns that whole tuple to x_conv.
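The following minimal pure-Python reproduction makes the mismatch concrete. It assumes a hypothetical toy_causal_conv1d that follows the return convention of the conv.py excerpt: one value when return_last_state is False, and a pair when it is True. Lists stand in for tensors. Neither toy_causal_conv1d nor caller_as_in_issue is part of the xlstm package.

```python
def toy_causal_conv1d(x, kernel, return_last_state=False):
    """Hypothetical stand-in for CausalConv1d.forward, operating on a list."""
    pad = len(kernel) - 1
    padded = [0.0] * pad + list(x)  # left zero-padding keeps the conv causal
    y = [
        sum(kernel[j] * padded[t + j] for j in range(len(kernel)))
        for t in range(len(x))
    ]
    if return_last_state:
        # Same convention as the conv.py excerpt: (output, last `pad` inputs).
        return y, list(x[max(0, len(x) - pad):])
    return y  # a single value, no state


def caller_as_in_issue(x, kernel, return_last_state):
    # Branch order copied from the layer.py excerpt quoted in the issue.
    if return_last_state:
        x_conv = toy_causal_conv1d(x, kernel, return_last_state=return_last_state)
    else:
        x_conv, _ = toy_causal_conv1d(x, kernel, return_last_state=return_last_state)
    return x_conv


x, kernel = [1.0, 2.0, 3.0], [0.5, 0.5]
try:
    caller_as_in_issue(x, kernel, return_last_state=False)
except ValueError as err:
    # A 3-element list cannot be unpacked into two names.
    print("False branch raised ValueError:", err)
print(type(caller_as_in_issue(x, kernel, return_last_state=True)))  # <class 'tuple'>
```

With real tensors, unpacking iterates over the first dimension. Assuming the conv output is batch-first, as the transpose(2, 1) in conv.py suggests, a batch size of exactly 2 would make the False branch silently split the batch instead of raising.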

Should I swap the two branches in layer.py (Lines 135 and 137)?

miaozhixu commented 5 months ago

Given what return_last_state means, I think the two assignments in xlstm/blocks/slstm/layer.py need to be swapped to match the return values of CausalConv1d.forward():

```python
if return_last_state:
    x_conv, conv_state = self.conv1d(x, conv_state, return_last_state=return_last_state)
else:
    x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state)
```
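The following sketch shows the proposed branch order. It again uses a hypothetical pure-Python toy_causal_conv1d, not the xlstm code. As an assumption, the incoming conv_state serves as left context in place of zero padding. This lets the example check that processing a sequence in two chunks, passing the returned state along, matches processing it in one pass. conv_step is likewise a hypothetical name.

```python
def toy_causal_conv1d(x, kernel, conv_state=None, return_last_state=False):
    """Hypothetical stand-in for CausalConv1d.forward (not the xlstm code).

    Assumption: conv_state holds the previous len(kernel) - 1 inputs and is
    used as left context; zeros are used when it is None.
    """
    pad = len(kernel) - 1
    context = list(conv_state) if conv_state is not None else [0.0] * pad
    padded = context + list(x)
    y = [
        sum(kernel[j] * padded[t + j] for j in range(len(kernel)))
        for t in range(len(x))
    ]
    if return_last_state:
        return y, padded[len(padded) - pad:]  # last `pad` inputs seen so far
    return y


def conv_step(x, kernel, conv_state=None, return_last_state=False):
    # Proposed order: unpack two values only when return_last_state is True.
    if return_last_state:
        x_conv, conv_state = toy_causal_conv1d(x, kernel, conv_state, return_last_state=True)
    else:
        x_conv = toy_causal_conv1d(x, kernel, conv_state, return_last_state=False)
    return x_conv, conv_state


kernel = [0.25, 0.25, 0.5]
full, _ = conv_step([1.0, 2.0, 3.0, 4.0], kernel)
first, state = conv_step([1.0, 2.0], kernel, return_last_state=True)
second, _ = conv_step([3.0, 4.0], kernel, conv_state=state)
print(first + second == full)  # True
```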

miaozhixu commented 5 months ago

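When I set return_last_state in the README sample, `y = xlstm_stack(x, return_last_state=True)`, and swap the lines in the sLSTM layer source, execution stops at xlstm_block.py, Line 76. So I modified that line as follows to check return_last_state:

```python
# .get() avoids a KeyError when the caller does not pass return_last_state.
if kwargs.get("return_last_state", False):
    # The layer returns (output, state); add only the output to the residual.
    x = x + self.xlstm(self.xlstm_norm(x), **kwargs)[0]
else:
    x = x + self.xlstm(self.xlstm_norm(x), **kwargs)
```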

Now the tensor x is no longer added to a tuple.
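The same idea can be sketched at the block level in pure Python. Here the hypothetical toy_xlstm_layer and toy_block_forward stand in for self.xlstm and the block's forward; neither is the xlstm code. The [0] fix discards the state. This variant instead returns the state alongside the output. Whether xlstm_stack should propagate the state is a design question this sketch does not settle.

```python
def toy_xlstm_layer(x, return_last_state=False):
    """Hypothetical inner layer: returns y, or (y, state) when asked."""
    y = [2.0 * v for v in x]
    if return_last_state:
        return y, {"last_input": x[-1]}
    return y


def toy_block_forward(x, **kwargs):
    # Hypothetical block: read the flag with .get() so a call that omits
    # return_last_state does not raise KeyError.
    return_last_state = kwargs.get("return_last_state", False)
    out = toy_xlstm_layer(x, **kwargs)
    state = None
    if return_last_state:
        out, state = out  # split the tuple before the residual addition
    x = [a + b for a, b in zip(x, out)]  # residual add on the output only
    return (x, state) if return_last_state else x


print(toy_block_forward([1.0, 2.0]))  # [3.0, 6.0]
print(toy_block_forward([1.0, 2.0], return_last_state=True))  # ([3.0, 6.0], {'last_input': 2.0})
```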