zhenruiliao / tension

A Python package for FORCE training chaotic neural networks
MIT License

Possibility to predict the model from a saved checkpoint, avoiding having to fit the model multiple times #34

Closed · dmnburrows closed 1 year ago

dmnburrows commented 1 year ago

Hey,

I'm playing around with the example datasets of the ConstrainedNoFeedbackESN to understand how the toolbox works. One issue I'm having is that I have to re-fit the model (model.fit) each time I want to make predictions with model.predict. In order to simulate with the learned weights, I always have to fit the model first - this is not ideal if I want to learn a large network with many parameters and then do simulations on this network. I want to avoid having to fit the model every time I want to use model.predict. Is there some way to do this?

I have tried saving the model.fit object, but it's not serialisable. I have also tried using tf.keras.callbacks.ModelCheckpoint to save the learned weights (see code below). When I do this with a brand new model, I am not able to load in the weights without first training the model, and if I train the new model for just 1 time step and then load the learned weights, the predict output looks very different from the original model's, even though the weights and parameters should be the same.

Thanks so much for the help!

Code for how I am using callbacks to save the weights (imports added for completeness):

```python
import os
import numpy as np
import tensorflow as tf
# plus the tension imports for ConstrainedNoFeedbackESN and BioFORCEModel

# convert to shape (timestep, number of neurons)
target_transposed = np.transpose(target).astype(np.float32)

u = 1                           # number of inputs; by default the forward pass does not use inputs, so this is a stand-in
n = target_transposed.shape[1]  # number of neurons
tau = 1.5                       # neuron time constant
dt = 0.25                       # time step
alpha = 1                       # gain on P matrix at initialization
m = n                           # output dim equals the number of recurrent neurons
g = 1.25                        # gain parameter controlling network chaos
p_recurr = 0.1                  # (1 - p_recurr) of recurrent weights are randomly set to 0 and not trained
max_epoch = 10
structural_connectivity = np.ones((n, n))  # region connectivity matrix; all 1's since only looking at one subsection of the brain
noise_param = (0, 0.001)                   # mean and std of white noise injected into the forward pass
x_t = np.zeros((target_transposed.shape[0], u)).astype(np.float32)  # stand-in input

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

tf.random.set_seed(123)
esn_layer = ConstrainedNoFeedbackESN(units=n,
                                     activation='tanh',
                                     dtdivtau=dt/tau,
                                     p_recurr=p_recurr,
                                     structural_connectivity=structural_connectivity,
                                     noise_param=noise_param,
                                     seed=123)

model = BioFORCEModel(force_layer=esn_layer, alpha_P=alpha)
model.compile(metrics=["mae"])

history = model.fit(x=x_t, y=target_transposed,
                    epochs=max_epoch,
                    validation_data=(x_t, target_transposed),
                    callbacks=[cp_callback])
_1 = model.predict(x_t)
```

lubin-liu commented 1 year ago

If I understood the issue correctly, I think the cause may be the Gaussian white noise that is injected into the forward pass of ConstrainedNoFeedbackESN, as specified by noise_param. In `Training FORCE with Zebrafish neural data.ipynb`, if you re-run the two prediction cells shown in the image below multiple times, the model sometimes looks like it hasn't learned (very high error), and the output is not reproducible despite the inputs and weights having not changed. You can set noise_param to (0.0, 0.0) in the layer definition, which would zero out the noise during the forward pass (see the sketch below). Edit: I think setting the noise to 0.0 doesn't work for this layer with the default initial states (neuron firing rates all zero), so you would need to figure out an initial state that allows learning to occur, which from past experimentation was tricky.

*[Screenshot of the two prediction cells from `Training FORCE with Zebrafish neural data.ipynb`]*
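For concreteness, a minimal sketch of the layer definition with the noise zeroed out (reusing the variable names from the snippet above):

```python
# A minimal sketch, assuming the same variables as in the earlier snippet.
# Setting noise_param to (0.0, 0.0) removes the white noise injected during
# the forward pass; note the edit above, though: with the default all-zero
# initial firing rates this may prevent learning.
esn_layer = ConstrainedNoFeedbackESN(units=n,
                                     activation='tanh',
                                     dtdivtau=dt/tau,
                                     p_recurr=p_recurr,
                                     structural_connectivity=structural_connectivity,
                                     noise_param=(0.0, 0.0),  # zero mean, zero std: no noise
                                     seed=123)
```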

dmnburrows commented 1 year ago

Thanks for the prompt response! I have looked into the noise as a potential source of the problem. It is definitely true that, because of the noise, I cannot reproduce exactly the same trajectories when I re-run model.predict on the same model object after learning (see image below; each trajectory is a different run of model.predict on the same esn_layer after one instance of training, i.e. 50 epochs of model.fit).

*[Screenshot 2023-03-23 at 09 57 14]*

However, I don't think this is the main source of my problem. If I run model.fit and save the weights, then create a brand new model object (with the same input parameters as my first model), I want to be able to capture approximately the same dynamics using model.predict, without having to re-fit the model.

As far as I am aware, I need to run model.fit to initialise the new model. My logic was to compile the new model (using the same parameters as the old trained model), then run model.fit() for a short time period (e.g. 1 epoch) just so that the model gets initialised. I then loaded the recurrent and input weights from the original model into my new model and ran model.predict (see the sketch below). Since I am using the weights from the learned model, I expected model.predict to recreate roughly the same dynamics as in the original model after learning; however, I found that the dynamics look vastly different.
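In code, my procedure is roughly the following (a simplified sketch, reusing the parameters from my first snippet):

```python
# Sketch of the workaround attempt: a brand-new layer/model built with the
# same hyperparameters as the trained one.
new_layer = ConstrainedNoFeedbackESN(units=n, activation='tanh',
                                     dtdivtau=dt/tau, p_recurr=p_recurr,
                                     structural_connectivity=structural_connectivity,
                                     noise_param=noise_param, seed=123)
new_model = BioFORCEModel(force_layer=new_layer, alpha_P=alpha)
new_model.compile(metrics=["mae"])

# Fit briefly just so the model's variables get created...
new_model.fit(x=x_t, y=target_transposed, epochs=1)

# ...then overwrite with the learned weights and predict.
new_model.load_weights(checkpoint_path)
pred = new_model.predict(x_t)  # expected ~same dynamics, but they differ
```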

I'm starting to think that I am misunderstanding model.fit and model.predict. I thought all the weight learning happens during model.fit(), and model.predict() should then just generate the dynamics using the learnt weights. However, when I compile my new model, run model.fit() for a few epochs (just to initialise the model so I can use model.predict), and then load in the learnt weights from the original model, I find that the model.predict output changes according to the number of epochs: as the number of epochs increases, the dynamics also change (see below).

*[Screenshot 2023-03-23 at 10 37 29]*

Why should model.predict() in the new model be affected by the number of epochs in model.fit() if I am re-loading the weights from the original model? Is there learning occurring during model.predict? Or perhaps it is not correctly using the newly assigned weights (I have used model.load_weights(checkpoint_path), then checked esn_layer.recurrent_kernel and esn_layer.input_kernel, and they seem to be correct)?

Thanks for the help! Let me know if seeing some of my code would be helpful.

lubin-liu commented 1 year ago

From your description, it seems like when you created the new model after loading in the weights, the new RNN layer passed into this new model did not use the final state from the previous model's layer (the neuron pre-activation firing rates)? If so, then when you called model.fit the second time, the neuron firing rates start from all zeroes (or randomly initialized, depending on the pre-defined get_initial_state method), leading to divergence in results. The RNN layer's states are technically not weights, so I don't think they are saved during model.save_weights; they would have to be saved separately. Below is a simplified version of the Zebrafish example, with a dummy initial_a (random Gaussian), noise removed, and run for 5 epochs (also sketched in code after the capture):

*[capture1]*
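For reference, the setup in the capture is roughly this (a sketch; the initial_a values and shape are assumptions, and the variable names reuse the earlier snippet):

```python
# Sketch of the simplified setup: dummy Gaussian initial state, no noise.
rng = np.random.default_rng(123)
# shape assumed (units,); adjust to whatever state shape the layer expects
initial_a = rng.normal(0, 1, size=(n,)).astype(np.float32)

esn_layer = ConstrainedNoFeedbackESN(units=n, activation='tanh',
                                     dtdivtau=dt/tau, p_recurr=p_recurr,
                                     structural_connectivity=structural_connectivity,
                                     noise_param=(0.0, 0.0),  # noise removed
                                     initial_a=initial_a,
                                     seed=123)
model = BioFORCEModel(force_layer=esn_layer, alpha_P=alpha)
model.compile(metrics=["mae"])
history = model.fit(x=x_t, y=target_transposed, epochs=5,
                    callbacks=[cp_callback])
```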

I re-ran the same cell for only 2 epochs and saved the weights (the error in the first two epochs matches the above):

*[capture2]*

The states of the RNN layer can be accessed via model.force_layer.states, the states being a tuple of elements shown here for ConstrainedNoFeedbackESN.
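For example, the final states could be saved alongside the checkpoint (a sketch; the file name is just an example, and it assumes the state tensors convert to NumPy arrays):

```python
import numpy as np

# The layer states are not weights, so ModelCheckpoint does not save them;
# stash them separately after training. For ConstrainedNoFeedbackESN the
# first element of the tuple holds the neuron pre-activation firing rates.
final_states = [np.asarray(s) for s in model.force_layer.states]
np.save("training_1/final_a.npy", final_states[0])
```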

To test defining a new model with a new layer using the saved weights from the previous one, the first state of the previous layer should be passed into the initial_a parameter of the new layer definition, as below. You can actually build the model manually by calling the build method with the correct input shape (as opposed to fitting it again for 1 epoch), then load in the weights and run either fit or predict. When I ran fit again for 3 epochs, the error matched the error in the last 3 epochs of the first figure. The changes are the 3 lines indicated below.

*[capture3]*
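Spelled out in code, the three changes would look roughly like this (a sketch: the .npy file name follows the saving example above, and the exact shape expected by build is an assumption):

```python
# Change 1: seed the new layer with the previous layer's first state.
initial_a = np.load("training_1/final_a.npy")
new_layer = ConstrainedNoFeedbackESN(units=n, activation='tanh',
                                     dtdivtau=dt/tau, p_recurr=p_recurr,
                                     structural_connectivity=structural_connectivity,
                                     noise_param=(0.0, 0.0),
                                     initial_a=initial_a,
                                     seed=123)
new_model = BioFORCEModel(force_layer=new_layer, alpha_P=alpha)
new_model.compile(metrics=["mae"])

# Change 2: build the model directly instead of fitting it for 1 epoch.
new_model.build(input_shape=x_t.shape)  # assumes (timesteps, inputs); adjust if a batch dim is needed

# Change 3: load the saved weights, then fit or predict as usual.
new_model.load_weights(checkpoint_path)
pred = new_model.predict(x_t)
```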

dmnburrows commented 1 year ago

This fixed it! Thank you so much :D