harrisonvanderbyl / rwkvstic

Framework agnostic python runtime for RWKV models
https://hazzzardous-rwkv-instruct.hf.space
MIT License

getState and setState don't quite capture the entire state of the model #14

Open tanitna opened 1 year ago

tanitna commented 1 year ago

I don't know if this is more of a documentation thing or a functionality thing, but here's my issue. I think I see what getState is intended to do: getState() does capture all the tensors, but it doesn't store lastToken. Going by the documentation, I would expect output_1 and output_2 in this snippet to have the same distribution:

saved_state = model.getState()
output_1 = model.forward(number=1)
model.setState(saved_state)
output_2 = model.forward(number=1)

If that's the behaviour I really want, I would need to do

saved_state = model.getState()
saved_lastToken = model.lastToken
output_1 = model.forward(number=1)
model.setState(saved_state)
model.lastToken = saved_lastToken
output_2 = model.forward(number=1)

So I don't know which is preferable: to do something like

def setState(self, statePair):
    self.myState = statePair[0]
    self.lastToken = statePair[1]

def getState(self):
    return self.myState, self.lastToken

Or to just put in the documentation that there's an additional type of state, or something else entirely.
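
To make the first option concrete, here is a minimal sketch using a toy stand-in class. ToyModel, its update rule, and the deepcopy calls are assumptions made up for illustration and are not rwkvstic's actual implementation; only the names getState, setState, myState, lastToken, and forward(number=...) are taken from the snippets above.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for an RWKV-style runtime (not rwkvstic's real class)."""

    def __init__(self):
        self.myState = [0.0, 0.0]  # stands in for the recurrent state tensors
        self.lastToken = 1         # last token seen or produced by the model

    def forward(self, number=1):
        """Advance `number` steps and return the tokens produced (toy deterministic rule)."""
        out = []
        for _ in range(number):
            # The next token depends on BOTH the recurrent state and lastToken.
            self.myState[0] += self.lastToken
            self.myState[1] = self.myState[0] * 0.5
            self.lastToken = (int(self.myState[0]) * 7 + 3) % 11
            out.append(self.lastToken)
        return out

    def getState(self):
        # Copy (an assumption) so later in-place updates cannot alter the snapshot.
        return copy.deepcopy(self.myState), self.lastToken

    def setState(self, statePair):
        self.myState = copy.deepcopy(statePair[0])
        self.lastToken = statePair[1]


model = ToyModel()
model.forward(number=3)  # warm up so the state is non-trivial

saved_state = model.getState()
output_1 = model.forward(number=1)
model.setState(saved_state)
output_2 = model.forward(number=1)

print(output_1 == output_2)  # prints True
```

With the pair-based snapshot, both calls start from the same tensors and the same lastToken, so they produce the same result. Restoring only myState in this toy leaves lastToken at its advanced value and the second call diverges, which is the behaviour described in this issue.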

harrisonvanderbyl commented 1 year ago

Thank you, I will amend the functionality of getState and setState in a future update.