Hello!
The way my implementation is supposed to solve that is via statefulness (it is currently broken for controllers that are full models expecting (batch_size, sequence_length, input_dim), but works when the controller is only a single added layer).
The controller does indeed only see a sequence of length 1 every time it is called, after which the memory instructions are evaluated, i.e. memory is written to and read from. How does the controller achieve learning through time, then?
Via statefulness: in that case, the controller saves its state (state = its own input for the next round) at the end of every batch, thus "understanding" e.g. 10 batches of length 1 as 1 batch of length 10. It is a very cool feature of Keras that they made this corner case possible via a simple flag. Kinda the moment I fell in love with Keras :)
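A minimal sketch of that trick in plain Keras 2 (nothing NTM-specific; sizes are made up for illustration):

```python
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

batch_size, input_dim = 4, 3

# stateful=True makes the LSTM keep its hidden state between batches,
# so 10 batches of sequence length 1 behave like 1 batch of length 10
model = Sequential()
model.add(LSTM(8, stateful=True,
               batch_input_shape=(batch_size, 1, input_dim)))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')

sequence = np.random.rand(batch_size, 10, input_dim)
for t in range(10):
    model.predict(sequence[:, t:t+1, :])  # one timestep per call

model.reset_states()  # forget everything before the next sequence
```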
The only problem right now is: if the controller is a layer, everything works fine [1]. If the controller is a whole model with its own optimizer, everything breaks down, as the optimizer doesn't understand the merging of several batches into one batch of higher sequence_length (it tries to optimize after each batch). [2]
I'm currently not sure how to fix that elegantly. Maybe I need to treat a controller as just a list of layers and strip it of its optimizer. Maybe I'll understand Keras better at some point and find another way.
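To illustrate the "list of layers" idea (a hypothetical sketch, not what the repository currently does): re-apply the controller's layers functionally, so only the surrounding NTM model ever owns an optimizer:

```python
from keras.models import Sequential
from keras.layers import Dense

controller = Sequential()
controller.add(Dense(32, activation='relu', input_dim=10))
controller.add(Dense(16))
controller.compile(optimizer='adam', loss='mse')  # this optimizer gets ignored

def apply_controller(x):
    # re-use the controller's layers (and weights) on a Keras tensor x;
    # training then happens only through the enclosing NTM model
    for layer in controller.layers:
        x = layer(x)
    return x
```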
Anyway, thanks for asking! This is certainly one of the more non-trivial questions about the implementation. In previous versions I carried the state around myself, but that was rather painful and highly prone to breakage from internal API changes.
If you have other or follow-up questions, feel free to ask.
[1] You can test that via controller_model="lstm", a feature that will be available in the development branch in approximately half an hour or so. This adds a single LSTM layer as the controller.
[2] At least that's what's going wrong as far as I understand it. If anyone has another hypothesis, let me know! I've only been tinkering with Keras for a month or so.
Thanks for your answers!
Now I have a follow-up question: if you have two samples A of shape (2, 10, 1), then for the LSTM controller the input becomes B of shape (20, 1, 1). If we feed A to a regular LSTM, the 2nd sample's state is reset. That means if we feed B to the LSTM controller, the 11th sample's state should be reset, or else the 2nd sample will get an initial state which is the 1st sample's last state. Maybe I understand it wrong?
If I have a sample A of shape (7,5,3) and we feed it to the NTM, the controller will see, 5 times in a row, input of shape (7,3) (dense controller) or shape (7,1,3) (LSTM controller).
There is no mixing of state between batches; every batch is treated separately. State is carried across all 5 timesteps.
After one call of the NTM with sample A, all the state is completely lost (because I haven't implemented statefulness for the whole NTM layer yet) and the next sample, of e.g. shape (7,26,3), will start with freshly initialised state (I think just zeros).
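In code, that walkthrough looks roughly like this (pure numpy, just to show the shapes):

```python
import numpy as np

A = np.random.rand(7, 5, 3)  # (batch_size, sequence_length, input_dim)

for t in range(A.shape[1]):          # the 5 timesteps of this batch
    dense_in = A[:, t, :]            # shape (7, 3): dense controller
    lstm_in = dense_in[:, None, :]   # shape (7, 1, 3): LSTM controller
# controller state is carried across these 5 steps, then thrown away
# before the next batch (no statefulness of the whole NTM layer yet)
```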
Did that clear things up? If not, I probably didn't get what you were asking.
Thanks again. I understand that the NTM layer will start with freshly initialized state, but the LSTM controller may not, because the LSTM controller is stateful, so its state will pass through samples.
For example, with 7 samples of shape (7,5,3): the state of the LSTM controller should be passed through timesteps, but because the batch_size is 1 the state will also be passed through samples, so the state at (1st sample, 5th timestep) will be used at (2nd sample, 1st timestep).
> there are 7 samples of shape (7,5,3)
Do you mean 7 samples, each of batch_size=7 (in which case I would suggest using another prime for the number of samples)? In that case the different states simply do not interfere, as every batch is computed in parallel (both logically and, I hope, in hardware).
Otherwise: yes, at the beginning of each NTM sample we have to reset the controller. Admittedly, that is currently not in the code I think (all of the controller-with-state stuff is currently untested due to a massive failure in the testing routine), but it should be achievable via something like self.controller.reset_states().
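For what it's worth, the actual Keras 2 method is reset_states() (plural); a minimal sketch, assuming the controller is a stateful Keras model:

```python
from keras.models import Sequential
from keras.layers import LSTM

controller = Sequential()
controller.add(LSTM(8, stateful=True, batch_input_shape=(1, 1, 3)))

# at the start of every new NTM sample:
controller.reset_states()
```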
I'm currently rewriting some stuff there, which will be in the development branch soon, I hope. Maybe it's clearer after that.
Damn, getting those stateful controllers to run again is quite frustrating. Especially as I know that it once worked!
There is a fallback solution of just grabbing the state of the controller, carrying it through the step function in ntm.py and reinserting it in _run_controller. This will work OK for a single LSTM layer, but may be quite a hassle for big controller models.
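Structurally, the fallback would look something like this (pure-numpy placeholder math and hypothetical helper names; the real versions would live in ntm.py):

```python
import numpy as np

def lstm_step(x, h, c):
    # placeholder dynamics standing in for the real LSTM equations
    c_new = 0.5 * c + 0.5 * x.mean(axis=-1, keepdims=True)
    h_new = np.tanh(c_new)
    return h_new, h_new, c_new

def memory_ops(controller_out, M):
    # placeholder for addressing, writing to and reading from the memory
    return M, controller_out

def ntm_step(x, states):
    # the controller state (h, c) rides along in the NTM's own state list,
    # gets fed back into the controller and is returned updated
    M, read, h, c = states
    out, h, c = lstm_step(np.concatenate([x, read], axis=-1), h, c)
    M, read = memory_ops(out, M)
    return out, [M, read, h, c]
```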
Anyway, I'm still on it!
Maybe you should write the LSTM code directly in the NTM step function 🤕; I found that approach in another NTM Keras implementation.
I think this is one drawback of Keras: it is not very flexible. At line 334 in recurrent.py, the step function and initial_state are passed to the backend, which then runs over the whole sequence, so we can't run one step at a time and still learn the error through all timesteps.
```python
last_output, outputs, states = K.rnn(self.step,
                                     preprocessed_input,
                                     initial_state,
                                     go_backwards=self.go_backwards,
                                     mask=mask,
                                     constants=constants,
                                     unroll=self.unroll,
                                     input_length=input_shape[1])
```
If the LSTM and the NTM all live in one step function, there should be no problem.
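A sketch of that idea against the K.rnn interface (trivial placeholder math and hypothetical helper names): since the controller's state then lives inside `states`, K.rnn re-initializes it for every new sample automatically.

```python
from keras import backend as K

# trivial placeholders (hypothetical names) standing in for the real math:
def lstm_update(x, read, h, c):
    return K.tanh(h + K.mean(x, axis=-1, keepdims=True)), c

def memory_update(h, M):
    return M, h

def combined_step(x, states):
    # all state (the controller's h, c and the NTM's M, read) lives in
    # `states`, so it is freshly initialized for every new sample
    h, c, M, read = states
    h, c = lstm_update(x, read, h, c)
    M, read = memory_update(h, M)
    return h, [h, c, M, read]

# inside the NTM layer's call():
# last_output, outputs, states = K.rnn(combined_step, inputs, initial_states)
```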
My code is very, very roughly based on the seya repository [1]: I translated all its code from Keras 0.3 to Keras 2.0, then realised that I finally understood how an NTM works and rewrote it from scratch. He originally hardcoded an LSTM layer into the NTM and then called the K.rnn function directly, which, if I remember correctly, worked somehow.
Did you mean his solution?
In some playground wrapper layer (see gist: https://gist.github.com/flomlo/c2429eea243f082a6ea49d493e687256) I tried all the possibilities, and all fail with the same error! So I'm convinced we're looking at some actual Keras bug here. My current hypothesis is that it may have something to do with different sessions; Keras does some very unhealthy and crazy shit with tf.Session.
I will investigate the session theory a bit further and will separately try to call the LSTM step function directly (there is actually no reason to have the complete control loop built by K.rnn if the sequence length is only 1).
If both of those approaches fail, I'll seriously consider giving up and just filing a Keras bug (the wrapper layer really should work imho, it's a very simple thing. And if the wrapper layer works, so will the NTM layer).
It would also be very interesting to know whether this behaviour affects the Theano backend as well. If it doesn't, I strongly suspect it is actually a Keras bug.
[1] https://github.com/EderSantana/seya/blob/master/seya/layers/ntm.py
Ok, calling the LSTM step function directly works. But damn, is that ugly! Really not satisfying. And it makes allowing arbitrary controllers very awkward.
That would mean I'd have to restrict NTM controllers to arbitrary but stateless models, and could allow stateful models only by hardcoding them into ntm.py.
Next weekend I will look into the session theory mentioned above. I do not accept the current behaviour as non-buggy or unfixable.
One (low-quality) Keras fix could be that an RNN called with sequences of length 1 gets special treatment in K.rnn; that could work. But it is a low-quality patch, and I doubt fchollet will allow it into Keras.
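The special case itself would be tiny; a sketch of the idea (not actual Keras code):

```python
def rnn_single_step(step_function, inputs, initial_states):
    # inputs has shape (batch_size, 1, input_dim): no scan machinery
    # needed, just call the step function once on the only timestep
    x = inputs[:, 0, :]
    last_output, new_states = step_function(x, initial_states)
    outputs = last_output[:, None, :]  # re-add the time axis
    return last_output, outputs, new_states
```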
Here, https://github.com/SigmaQuan/NTM-Keras/blob/master/ntm.py, you can see they don't use an LSTM layer, but write it themselves.
Using a stateful LSTM with sequence length 1 is not a good idea (there is a Keras issue that talks about it): you can't refresh the state easily, and the state will be carried from the first sample to the last, which means the LSTM cannot learn the error through time.
Because the step function is run by the backend, the controller's step function should be written inside the NTM's step function, so that once the NTM receives a new sample it can refresh the LSTM and NTM hidden states.
Oh man, if I had been aware of SigmaQuan's NTM-Keras solution I wouldn't have built this one :D
I agree with you regarding the stateful + single-step-execution warning: testing has shown that stateful=True doesn't force things to be executed sequentially; instead they are executed in parallel. I don't know if there is a variable lock in tensorflow, and even if there were, relying on it would be highly tensorflow-specific.
If I cannot find a better solution, I might try to convince fchollet to accept a K.rnn patch which does special-case treatment when it is known at compile time that the sequence length of the RNN is 1 (that one should be easy to write). One would still have to save and restore all the state manually for the whole controller model, which is a hassle, but manageable.
By the way, here is the implementation which does manual state saving and calls the step function directly: https://gist.github.com/flomlo/6d81717dae4ed6d0466c391436c552ee I still have to test it, but I'm quite sure it behaves the same as an LSTM layer, just slower.
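From memory, the gist does something along these lines (a rough sketch, not the gist itself; it relies on Keras 2 internals like LSTM.step and is therefore fragile by design):

```python
from keras.layers import LSTM

lstm = LSTM(8)
lstm.build((None, 1, 3))  # create the weights for input_dim=3

# Pseudocode of the manual loop; we carry h and c around ourselves:
#
#   output, [h, c, ...] = lstm.step(x, [h, c] + constants)
#
# Caveats: in Keras 2.0 the `states` list also carries dropout masks as
# constants, and some implementation modes expect the input to have gone
# through preprocess_input first. Hence: ugly and fragile.
```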
Update: still working on a Keras patch which fixes the behaviour for calling rnn with sequence length 1.
I'm confused about _run_controller: if I use the LSTM as controller, controller_input has to be 3-dimensional, hence controller_input = controller_input[:, None, :].
The LSTM's call function expects input of shape [batch_size, time_length, input_dim], but here we give it shape [batch_size, 1, input_dim]. Can it still learn the error through time?
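For reference, the reshape in question just inserts a length-1 time axis:

```python
import numpy as np

controller_input = np.random.rand(7, 3)    # (batch_size, input_dim)
lstm_input = controller_input[:, None, :]  # (7, 1, 3): a sequence of length 1
assert lstm_input.shape == (7, 1, 3)
```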