DeNeutoy closed this 7 years ago.
@matt-gardner This last commit has everything you need to run the adaptive example that crashes. You'll have to change the paths to the data files, but apart from that it should recreate the problem I was having earlier.
This is ready for review. It looks kinda big, but that's inflated by the fact that I added `super()` calls everywhere.
The main changes are:
- The part of the `_build_model()` method that runs the memory network loop is now a separate class, which gets called from `_build_model()`. This allows different methods for controlling the memory network loop (either a fixed number of steps or adaptive computation).
- `MemoryNetworkSolver` now has a new method that performs a single memory network step; the recurrence class uses it to actually perform the loop.
- In the `AdaptiveRecurrence` class, we simply apply an `AdaptiveStep` over the inputs, which is a Keras `Layer`. When this `AdaptiveStep` layer is called, we first set up some variables we need for the while loop. Most of the actual function definition is within `adaptive_memory_hop`, which is run by the TensorFlow while loop. Within `adaptive_memory_hop`, we first run the memory network step as defined by the memory network, and then compute halting probabilities based on the newly generated memory representation. Most of the rest of this method deals with the fact that the computation is batched, but different elements of the batch will want to do different numbers of memory steps. This layer is agnostic to input dimensions, so it works for all of the solvers, regardless of the size of the background knowledge, etc.

I've also added a test which checks that TF and Keras are optimising the same variables. Let me know if you think there are more tests I should add; happy to do so.
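To make the refactor concrete, here is a minimal pure-Python sketch of the pattern described in the list of changes: the solver exposes a single-step method, and a pluggable recurrence object owns the loop. This is not the PR's Keras/TensorFlow code. `ToySolver`, `memory_step`, `run_loop`, `FixedRecurrence` and `ThresholdRecurrence` are hypothetical names, memories are plain floats, and `ThresholdRecurrence` is a deliberately simplified stand-in for adaptive computation (one example, hard threshold, no batched halting probabilities).

```python
from typing import Callable

# Hypothetical signature for a single memory network step:
# current memory representation in, next memory representation out.
MemoryStep = Callable[[float], float]


class FixedRecurrence:
    """Runs the memory step a fixed number of times."""

    def __init__(self, num_steps: int):
        self.num_steps = num_steps

    def __call__(self, memory_step: MemoryStep, memory: float) -> float:
        for _ in range(self.num_steps):
            memory = memory_step(memory)
        return memory


class ThresholdRecurrence:
    """Hypothetical simplified adaptive loop: stop once halt_fn(memory)
    reaches the threshold, or after max_steps steps."""

    def __init__(self, halt_fn: Callable[[float], float],
                 threshold: float = 0.5, max_steps: int = 10):
        self.halt_fn = halt_fn
        self.threshold = threshold
        self.max_steps = max_steps

    def __call__(self, memory_step: MemoryStep, memory: float) -> float:
        for _ in range(self.max_steps):
            memory = memory_step(memory)
            if self.halt_fn(memory) >= self.threshold:
                break
        return memory


class ToySolver:
    """Hypothetical solver whose loop is delegated to a recurrence object."""

    def __init__(self, recurrence):
        self.recurrence = recurrence

    def memory_step(self, memory: float) -> float:
        # Placeholder for one memory network hop.
        return memory + 1.0

    def run_loop(self, initial_memory: float) -> float:
        # Plays the role of the loop portion of _build_model().
        return self.recurrence(self.memory_step, initial_memory)
```

Swapping the loop strategy then becomes a constructor argument, e.g. `ToySolver(FixedRecurrence(3))` versus `ToySolver(ThresholdRecurrence(lambda m: m / 10))`, rather than an edit to the model-building code.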
Also, the halting logic is complicated; it might be best for me to explain it on a whiteboard, as there are various parts to it which look very similar but are actually time-delayed by one iteration, etc.
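The halting bookkeeping is easier to see outside TensorFlow. Below is a minimal pure-Python sketch, assuming the halting scheme follows Graves-style adaptive computation time (ACT): each batch element accumulates halting probabilities, and on the hop where its running total would reach `1 - epsilon` it halts and takes the remainder as its final weight. This is an assumption about the approach, not a transcription of `adaptive_memory_hop`; `adaptive_hops`, `step_fn`, `halt_fn`, `epsilon` and `max_steps` are hypothetical, and memories are plain floats. It shows one source of the one-iteration lag: the halting check adds the current probability to the total from earlier hops, but the remainder is computed from that earlier total alone.

```python
from typing import Callable, List, Tuple


def adaptive_hops(
    memories: List[float],
    step_fn: Callable[[float], float],
    halt_fn: Callable[[float], float],
    epsilon: float = 0.01,
    max_steps: int = 10,
) -> Tuple[List[float], List[int], List[float]]:
    """Batched ACT-style loop; returns (outputs, steps_taken, remainders).

    Each element keeps stepping until its cumulative halting probability
    would reach 1 - epsilon, or until max_steps hops have been run.
    """
    batch = len(memories)
    state = list(memories)
    cumulative = [0.0] * batch   # sum of halting probs from earlier hops
    outputs = [0.0] * batch      # probability-weighted sum of memories
    steps = [0] * batch
    remainders = [0.0] * batch
    running = [True] * batch     # mask: which elements still take hops

    for hop in range(max_steps):
        if not any(running):
            break
        for i in range(batch):
            if not running[i]:
                continue  # halted elements are frozen
            state[i] = step_fn(state[i])
            p = halt_fn(state[i])
            steps[i] += 1
            last_hop = hop == max_steps - 1
            if cumulative[i] + p >= 1.0 - epsilon or last_hop:
                # Halting weight comes from the total of *previous* hops.
                remainders[i] = 1.0 - cumulative[i]
                outputs[i] += remainders[i] * state[i]
                running[i] = False
            else:
                cumulative[i] += p
                outputs[i] += p * state[i]
    return outputs, steps, remainders
```

With `step_fn=lambda m: m + 1.0` and `halt_fn=lambda m: 0.6 if m > 5 else 0.3`, the starting memories `[0.0, 10.0]` take 4 and 2 hops respectively, which is exactly the per-element variation that a batched while loop has to handle with a mask.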
New PR for this after a rebase, so Semaphore can run tests.