NX-AI / xlstm

Official repository of the xLSTM.

Need Forwarding with state. #22

Open lxianl455 opened 2 weeks ago

lxianl455 commented 2 weeks ago

When training, it runs without state:

```python
def forward(self, idx: torch.Tensor) -> torch.Tensor:
    x = self.token_embedding(idx)
    x = self.emb_dropout(x)
    x = self.xlstm_block_stack(x)
    logits = self.lm_head(x)
    return logits
```

Can you give a "forward with state" version?

```python
def forward(self, idx: torch.Tensor, state) -> torch.Tensor:
    x = self.token_embedding(idx)
    x = self.emb_dropout(x)
    x = self.xlstm_block_stack(x, state)
    logits = self.lm_head(x)
    return logits
```
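[Editor's sketch] For readers following along, a hedged sketch of the requested signature. Note that the official `xLSTMBlockStack.forward()` does not accept or return a state; the state-passing call below is the hypothetical behavior being asked for, not the existing API.

```python
import torch

def forward(self, idx: torch.Tensor, state=None):
    x = self.token_embedding(idx)
    x = self.emb_dropout(x)
    # Hypothetical: a stack whose forward accepts the previous state
    # and returns the new one (not the stock xlstm API).
    x, state = self.xlstm_block_stack(x, state)
    logits = self.lm_head(x)
    # Returning the state lets the caller carry it into the next chunk.
    return logits, state
```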

sieusaoml commented 2 weeks ago


Here is my custom xLSTM block implementation: https://github.com/sieusaoml/xLSTM-custom-block

lxianl455 commented 2 weeks ago

Yes, I want to do something similar. But in the code, is it only sLSTM that can be initialized with the previous hidden state? Can't mLSTM be initialized with the previous state?

hiimbach commented 2 weeks ago

The step() method and the forward() method of mLSTMLayer use different types of conv1d forward passes, so I think that if you want to carry a hidden state, you need to call step() token by token instead of forwarding all of the tokens at the same time.
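[Editor's sketch] For illustration, a minimal sketch of what the token-by-token loop could look like. It assumes `xLSTMBlockStack` exposes a `step(x, state)` method that consumes one token of shape `(batch, 1, embedding_dim)` and returns `(output, new_state)`, mirroring the layer-level step() mentioned above; verify against your installed version. The config values follow the repo README.

```python
import torch
from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
)

# Small mLSTM-only stack; adjust the config to your model.
cfg = xLSTMBlockStackConfig(
    mlstm_block=mLSTMBlockConfig(
        mlstm=mLSTMLayerConfig(
            conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
        )
    ),
    context_length=32,
    num_blocks=2,
    embedding_dim=128,
)
stack = xLSTMBlockStack(cfg)

x = torch.randn(1, 16, 128)  # (batch, seq_len, embedding_dim)
state = None
outputs = []
for t in range(x.shape[1]):
    # Assumed signature: step() takes one token plus the carried state.
    y_t, state = stack.step(x[:, t : t + 1, :], state)
    outputs.append(y_t)
y = torch.cat(outputs, dim=1)  # same shape as forwarding the full sequence
```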

lxianl455 commented 2 weeks ago

Yes, I am not asking to forward all of the tokens at the same time. In fact, my original model was an LSTM that processes each token in a loop; I just want to replace that LSTM with xLSTM. But it seems that step() is meant for inference, right? Can it backpropagate normally during training? Will the in-place operations lead to backpropagation errors?

sieusaoml commented 2 weeks ago

mLSTMLayer can be used with the previous hidden state, but in my test with context_length=1, backpropagating the gradient raises an error.
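[Editor's sketch] For reference, a hedged sketch of the kind of test described in this thread: run the stack token by token with carried state (reusing all assumptions from the earlier sketch, including the assumed step() signature), then backpropagate. This is illustrative only, not a confirmed reproduction of the reported error.

```python
# Continuing the earlier sketch: does the gradient flow through carried state?
x = torch.randn(1, 4, 128, requires_grad=True)
state = None
outputs = []
for t in range(x.shape[1]):
    y_t, state = stack.step(x[:, t : t + 1, :], state)
    outputs.append(y_t)

loss = torch.cat(outputs, dim=1).sum()
loss.backward()  # the thread reports errors here, possibly from in-place state updates
print(x.grad is not None)  # True only if gradients flowed through the stepped states
```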