miaozhixu closed 3 months ago
Going by the meaning of "return_last_state", I think the two branches in xlstm/blocks/slstm/layer.py that receive the return value of CausalConv1d.forward() need to be swapped, like this:
if return_last_state:
    x_conv, conv_state = self.conv1d(x, conv_state, return_last_state=return_last_state)
else:
    x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state)
When I set return_last_state in the README sample:
y = xlstm_stack(x, return_last_state=True)
and swap the lines in the sLSTM layer source, execution stops at xlstm_block.py, line 76.
So I modified that line as follows to check return_last_state:
if kwargs['return_last_state']:
    x = x + self.xlstm(self.xlstm_norm(x), **kwargs)[0]
else:
    x = x + self.xlstm(self.xlstm_norm(x), **kwargs)
Now x (a tensor) is no longer added to a tuple.
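The workaround above can be sketched in isolation. This is a minimal illustration only: toy_layer and residual_block are hypothetical names, plain lists stand in for tensors, and nothing here is taken from the xlstm code. It assumes the caller may omit return_last_state, so it reads the flag with kwargs.get rather than indexing kwargs directly, and it passes the state on instead of discarding it with [0].

```python
def toy_layer(x, return_last_state=False):
    """Hypothetical stand-in for self.xlstm(...): doubles every value."""
    y = [2.0 * v for v in x]
    if return_last_state:
        # Two return values: the output and some state (here, the last input).
        return y, {"last": x[-1]}
    return y


def residual_block(x, **kwargs):
    """Residual add that works whether or not the layer returns a state."""
    out = toy_layer(x, **kwargs)
    want_state = kwargs.get("return_last_state", False)  # no KeyError if omitted
    state = None
    if want_state:
        # Split the tuple before adding, so x is never added to a tuple.
        out, state = out
    y = [a + b for a, b in zip(x, out)]
    return (y, state) if want_state else y


print(residual_block([1.0, 2.0]))
print(residual_block([1.0, 2.0], return_last_state=True))
```

Whether the real block should pass the state on to its caller, as this sketch does, or drop it, as the [0] indexing does, is a design decision for the library and is not settled here.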
In the source code, xlstm/blocks/slstm/layer.py, line 134:
if return_last_state:
    x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state)
else:
    x_conv, conv_state = self.conv1d(
        x, conv_state, return_last_state=return_last_state
    )
These lines call the forward method of the CausalConv1d class. In the source code, xlstm/components/conv.py, line 126:
if return_last_state:
    return y[:, :, : -self.pad].transpose(2, 1), x[:, -self.pad :]
else:
    return y[:, :, : -self.pad].transpose(2, 1)
When "return_last_state" is False, the forward method in conv.py returns only one value, without the last state of x, but layer.py, line 137, expects it to return two values. Conversely, when it is True, forward returns two values, but layer.py, line 135, assigns them to x_conv alone.
Should I swap the two lines in layer.py (lines 135 and 137)?
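The mismatch can be reproduced without the library. The sketch below is hypothetical: toy_conv only imitates the one-or-two return values of the quoted forward method, with plain lists in place of tensors, and call_as_in_layer and call_swapped are made-up names for the two branch orders.

```python
def toy_conv(x, return_last_state=False):
    """Hypothetical stand-in for CausalConv1d.forward (one or two return values)."""
    y = [v + 1 for v in x]
    if return_last_state:
        return y, x[-1:]  # output plus a "last state"
    return y


def call_as_in_layer(x, return_last_state):
    """Branch order as quoted from layer.py."""
    if return_last_state:
        # Two values come back but only one name receives them.
        x_conv = toy_conv(x, return_last_state=return_last_state)
        return x_conv
    # One value comes back but two names try to unpack it.
    x_conv, conv_state = toy_conv(x, return_last_state=return_last_state)
    return x_conv, conv_state


def call_swapped(x, return_last_state):
    """Branch order proposed in this issue."""
    if return_last_state:
        x_conv, conv_state = toy_conv(x, return_last_state=return_last_state)
        return x_conv, conv_state
    return toy_conv(x, return_last_state=return_last_state)


# With the quoted order, True leaves x_conv holding a tuple.
print(type(call_as_in_layer([1, 2, 3], True)))

# With the quoted order, False fails while unpacking.
try:
    call_as_in_layer([1, 2, 3], False)
except ValueError as err:
    print("unpack failed:", err)

# With the swapped order, both cases match what toy_conv returns.
print(call_swapped([1, 2, 3], False))
print(call_swapped([1, 2, 3], True))
```

With real tensors the False case may not always raise. Unpacking a tensor iterates over its first dimension, so a first dimension of size 2 would unpack silently into two wrong values.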