lfads / lfads-run-manager

Matlab interface for Latent Factor Analysis via Dynamical Systems (LFADS)
https://lfads.github.io/lfads-run-manager
Apache License 2.0

Trained RNNs Parameters Extraction #33

omaganatellez closed this issue 2 years ago

omaganatellez commented 2 years ago

Hello, I have trained an LFADS run on a particular dataset. I am trying to examine the parameters and weights of the encoder and generator RNNs. Is there a way to obtain them from the trained model? Are they saved in the run folder?

djoshea commented 2 years ago

Hi there, yes, you should be able to access them as described here: https://lfads.github.io/lfads-run-manager/trained-params/

You can use the Matlab run manager to access them easily as properties of the LFADS.ModelTrainedParameters class, or you can load them yourself from the lfadsOutput/model_params HDF5 file. Let me know if this works for you, in which case we can close the issue.
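If you go the "load them yourself" route from Python, a minimal sketch is to walk the HDF5 file and pull every dataset into memory, then filter by name. This assumes the file at lfadsOutput/model_params is a plain HDF5 file readable with the third-party h5py package; the exact filename, extension, and dataset naming scheme are not confirmed here, so the code discovers names rather than hard-coding them.

```python
from typing import Any, Dict, Iterable, List


def load_params(path: str) -> Dict[str, Any]:
    """Read every dataset in an HDF5 file into {dataset_path: numpy array}."""
    import h5py  # third-party; imported here so select() has no dependencies

    params: Dict[str, Any] = {}

    def collect(name: str, obj: Any) -> None:
        # visititems walks groups recursively; keep only datasets (arrays).
        if isinstance(obj, h5py.Dataset):
            params[name] = obj[()]  # read the full array into memory

    with h5py.File(path, "r") as f:
        f.visititems(collect)
    return params


def select(names: Iterable[str], keyword: str) -> List[str]:
    """Return the names containing `keyword` (case-insensitive), sorted."""
    key = keyword.lower()
    return sorted(n for n in names if key in n.lower())
```

For example, `select(load_params("lfadsOutput/model_params"), "gen")` would list generator-related entries only if the stored names actually contain "gen"; printing all the keys first is the safer way to learn the real naming scheme.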

omaganatellez commented 2 years ago

Yes, it works for me. Thank you very much for the quick response. May I ask a related question? Is there a way to run new data through a trained LFADS network?

djoshea commented 2 years ago

Great! Yes, this is possible. It isn't implemented in the Matlab code, but you can run the commands or edit the shell scripts directly to accomplish it. There are two approaches, depending on your situation.

If you have new data from the exact same neurons, all you need to do is sample and average the posterior means again, this time with the new input data. See https://github.com/tensorflow/models/blob/master/research/lfads/run_lfads.py#L312 for how this is done. Essentially, you swap in the new input data after the model is trained and re-evaluate it.
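To make "sample and average" concrete, here is a toy sketch of the idea, not the TensorFlow implementation: the trained weights stay fixed, and for each new trial we draw many stochastic outputs and average them as a Monte Carlo estimate of the posterior mean. `sample_fn` is a hypothetical stand-in for one stochastic forward pass of the trained model (sample from the encoder's posterior, run the generator, read out rates).

```python
import random
from typing import Callable, List, Optional, Sequence

Trial = Sequence[float]
Rates = List[float]
# Hypothetical: one stochastic pass of a trained model on a single trial.
SampleFn = Callable[[Trial, random.Random], Rates]


def sample_and_average(sample_fn: SampleFn, trial: Trial,
                       n_samples: int, rng: random.Random) -> Rates:
    """Monte Carlo estimate of the posterior-mean output for one trial."""
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    total: Optional[List[float]] = None
    for _ in range(n_samples):
        out = sample_fn(trial, rng)
        # Accumulate element-wise sums of the sampled outputs.
        total = list(out) if total is None else [t + o for t, o in zip(total, out)]
    return [t / n_samples for t in total]


def posterior_means(sample_fn: SampleFn, trials: Sequence[Trial],
                    n_samples: int = 50, seed: int = 0) -> List[Rates]:
    """Same fixed model, new inputs: average posterior samples per new trial."""
    rng = random.Random(seed)
    return [sample_and_average(sample_fn, t, n_samples, rng) for t in trials]
```

The more samples you average, the less the estimate depends on any single draw; the number used in practice is a setting of the real code, not of this sketch.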

If you have new data from different neurons in a stitching context, you'll also need to retrain the input/output alignment (the readin and readout matrices), which can be configured with the --do-train-io-only flag, and then evaluate the model to get the posterior means. This flag is documented in the code here: https://github.com/tensorflow/models/blob/master/research/lfads/run_lfads.py#L312. Unfortunately, I have not implemented it in the Matlab run manager.
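As a toy illustration of what training only the IO layers means: the shared core weights (encoder, generator) stay frozen, and only the readin/readout parameters for the new neurons are updated. Scalar parameters, matching on hypothetical "readin"/"readout" name substrings, and plain SGD are simplifications chosen for this sketch; the real behaviour of --do-train-io-only lives in the TensorFlow code linked in the comment.

```python
from typing import Dict, Iterable, List

IO_KEYWORDS = ("readin", "readout")  # hypothetical parameter-name convention


def trainable_names(names: Iterable[str], train_io_only: bool) -> List[str]:
    """Choose which parameters are updated; with train_io_only, only IO layers."""
    ordered = sorted(names)
    if not train_io_only:
        return ordered
    return [n for n in ordered if any(k in n for k in IO_KEYWORDS)]


def sgd_step(params: Dict[str, float], grads: Dict[str, float],
             lr: float, train_io_only: bool) -> Dict[str, float]:
    """One gradient step that leaves frozen (non-IO) parameters untouched."""
    updated = dict(params)  # copy so the caller's parameters are not mutated
    for name in trainable_names(params, train_io_only):
        updated[name] = params[name] - lr * grads[name]
    return updated
```

Freezing the core is what lets the new dataset reuse the dynamics already learned; only the mapping between the new neurons and that shared latent space is fit.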