openworm / neuronal-analysis

Tools to produce, analyse and compare both simulated and recorded neuronal datasets
MIT License

Recurrent neural networks #10

Closed: BenjiJack closed this issue 8 years ago

BenjiJack commented 8 years ago

Hi all,

So I have been thinking for a while that artificial neural networks could be a useful tool for analyzing neuronal time series data like that presented by Kato et al. It seems like a reasonable approach to simulate a real neural net with an artificial one.

I am sure you have seen examples of text (e.g. fake newspaper articles, fake works of Shakespeare) generated by recurrent neural networks. Inspired by a blog post on this topic that I read over the weekend, I trained a recurrent neural net on Kato's data and then used the trained model to generate simulated neural time-series data.

In theory, if we could train the neurons of a simulated neural net to generate output like that in the real neural net, would the model not be "thinking?"

I trained the RNN separately on each of the 5 Kato trials, and then ran PCA on the output of the RNN. I plotted the PCA of the real data next to the PCA of the simulated data for each of the 5 trials.
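The PCA code itself isn't shown in this post. Below is a minimal NumPy sketch, assuming each trial is an array of shape (time steps, neurons). It makes one design choice that may differ from the notebook: the principal axes are fitted on the real data, and the generated data is projected onto those same axes, so both plots share one coordinate system. The names `fit_pca` and `project` are hypothetical, and `real`/`generated` are random stand-ins, not Kato's data.

```python
import numpy as np

def fit_pca(X, n_components=3):
    """Fit PCA on X (time steps x neurons) via SVD; return the mean and top principal axes."""
    mean = X.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by decreasing singular value.
    _, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, Vt[:n_components]

def project(X, mean, components):
    """Project data onto principal axes that were fitted elsewhere."""
    return (X - mean) @ components.T

# Hypothetical stand-ins for one trial: real recording and RNN output (time steps x neurons).
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 10))
generated = rng.normal(size=(1000, 10))

mean, comps = fit_pca(real, n_components=3)
real_pcs = project(real, mean, comps)        # shape (500, 3)
gen_pcs = project(generated, mean, comps)    # shape (1000, 3), same axes as real_pcs
```

Running a separate PCA on each dataset, as described above, is also valid. However, the two sets of axes can then differ in sign or order, which makes side-by-side plots harder to compare.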

Here's an example of 2 trials. 1000 generated, simulated data points are shown for each trial.
[Screenshot (2016-04-17): PCA of real vs. simulated data for 2 trials]

Frankly, the results at this stage are unimpressive! It looks like the simulation may be getting stuck in local minima. I have not yet tried tinkering with the model's hyperparameters, and I have not yet tried normalizing the data either. I think there is still low-hanging fruit for improving the results.
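Normalization is one of the untried steps mentioned above. One common choice is a per-neuron z-score, which gives each neuron's trace zero mean and unit variance before training. The inverse transform maps the generated output back to the original scale. This is a minimal sketch; the function names are hypothetical and the data is a random stand-in.

```python
import numpy as np

def zscore(X, eps=1e-8):
    """Standardize each neuron (column) of X to zero mean and unit variance."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    return (X - mean) / (std + eps), mean, std

def unzscore(Z, mean, std, eps=1e-8):
    """Map normalized data (e.g. RNN output) back to the original scale."""
    return Z * (std + eps) + mean

# Hypothetical data: 200 time steps of 3 neurons on very different scales.
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=[1.0, 10.0, 0.1], size=(200, 3))
Z, mean, std = zscore(X)
X_back = unzscore(Z, mean, std)
```

The statistics must be computed on the training data only and reused unchanged when generating. Otherwise the generated series ends up on a different scale from the real one.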

Even though there is not yet much to show, the code runs and seems to do what I intend for it to do, so I wanted to share it with you. (Trying to share early and often!) I am new to RNNs, so maybe one of you has more experience and can suggest some next steps to tune the model. The code is relatively simple and uses the Python ML library Keras, so it should be easy for you to iterate further.

I ultimately hope to run the neural net further "upstream" in the Kato pipeline, as we should really be simulating the firing of the neurons themselves rather than the derivative of the time series computed in post-processing, which is what I have done thus far.

jrieke commented 8 years ago

@BenjiJack Just happened to see this. I'm not working actively on OpenWorm any more, but thought I'd quickly share some thoughts, as I've been working on a project in the same area lately (in my case, I generated time series that contain the positions of biological cells; see this repo).

Your observation that the RNN gets stuck in an equilibrium is correct. This is totally normal if you predict the values at the next time step directly. Instead, the best solution is to let your network predict the parameters of a probability distribution (namely a Gaussian mixture distribution), from which you can then sample the values at the next time step.
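Here is the sampling step described above as a minimal NumPy sketch. It assumes diagonal covariances and a network that outputs, at each time step, K unnormalized mixture weights (`logits`) plus K means and K standard deviations for a D-dimensional output. Those shapes and the name `sample_gmm` are assumptions for illustration, not taken from the linked code.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

def sample_gmm(logits, mu, sigma, rng):
    """Draw one sample from a diagonal Gaussian mixture.

    logits: (K,) unnormalized mixture weights
    mu, sigma: (K, D) component means and standard deviations
    """
    weights = softmax(logits)
    k = rng.choice(len(weights), p=weights)   # 1) pick a mixture component
    return rng.normal(mu[k], sigma[k])        # 2) sample all D values from that component

# Hypothetical network output for one time step: K = 3 components, D = 4 neurons.
rng = np.random.default_rng(0)
logits = np.array([2.0, 0.5, -1.0])
mu = rng.normal(size=(3, 4))
sigma = np.full((3, 4), 0.1)
x_next = sample_gmm(logits, mu, sigma, rng)   # shape (4,); fed back in as the next input
```

Because the next input is a random draw rather than the single most likely value, the generated trajectory is much less likely to collapse onto a fixed point.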

I don't have much time right now, so I'll just quickly give you some references: to start, have a look at section 5 of my project report for a quick introduction, and at section 4 of Graves 2013 for more details. Bishop 1994 is also quite helpful; it's the paper that introduced this technique of predicting a probability distribution. In the end, you need an additional layer on your network plus a custom loss function, so that the network can predict the Gaussian mixture distribution. I've implemented this for my project in Keras here. The code is a bit unstructured; look for # Layer implementing a Gaussian mixture model. and # Set up the network.
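The custom loss mentioned above is the negative log-likelihood of the observed next time step under the predicted mixture. This is a NumPy sketch of that loss for a single time step, assuming diagonal Gaussians. The extra layer would then output K logits, K×D means and K×D log standard deviations (K(1 + 2D) values in total). The name `gmm_nll` and the log-sigma parameterization are assumptions. To train in Keras, the same formula would have to be written with Keras/TensorFlow tensor ops instead of NumPy.

```python
import numpy as np

def log_softmax(x):
    """Stable log of softmax(x) for a 1-D array."""
    m = np.max(x)
    return x - (m + np.log(np.sum(np.exp(x - m))))

def gmm_nll(logits, mu, log_sigma, y):
    """Negative log-likelihood of target y under a diagonal Gaussian mixture.

    logits: (K,), mu and log_sigma: (K, D), y: (D,)
    """
    log_w = log_softmax(logits)                       # log mixture weights
    sigma = np.exp(log_sigma)
    # Per-component log density of y, summed over the D independent dimensions:
    # log N(y; mu, sigma^2) = -0.5 * [((y - mu) / sigma)^2 + 2 log sigma + log 2pi]
    log_prob = -0.5 * np.sum(((y - mu) / sigma) ** 2 + 2 * log_sigma + np.log(2 * np.pi), axis=1)
    # log sum_k w_k N_k(y), computed with log-sum-exp for stability
    a = log_w + log_prob
    m = np.max(a)
    return -(m + np.log(np.sum(np.exp(a - m))))
```

Predicting log sigma instead of sigma keeps the standard deviations positive without any constraint on the layer's raw output.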

If you have questions, let me know; I may also add more comments over the next few days. I also want to refactor my code at some point to allow arbitrary time series. I'll let you know if I make any progress on that (but don't expect anything too soon...).

jrieke commented 8 years ago

Two more points:

1) I looked through your code, and it seems like you train the network on (overlapping) chunks of length maxlen (like in the default Keras example). A better way is to set stateful=True in the LSTM layers and then feed the time steps into the network one after another (because the layers are stateful, they keep their state in between). This is both more efficient and enables the network to learn longer dependencies. For an implementation, look again at my code, or search the Keras docs for instructions.

2) In case you want to run my code: Unfortunately I could not include the data in the repo, but I can send you the data privately if you want it.
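Point 1 above can be illustrated without Keras. This is a toy NumPy RNN cell (hypothetical, not an LSTM and not the Keras API) that keeps its hidden state between calls. Feeding the series one step at a time lets information from the whole history reach every output. Re-running each length-`maxlen` window from a zero state cuts that history off and repeats most of the computation. In Keras itself this corresponds to `stateful=True` on the LSTM layers. Stateful layers need a fixed batch size, and how that is declared differs across Keras versions, so check the docs for the version you use.

```python
import numpy as np

class ToyRNN:
    """Tiny Elman-style RNN cell in NumPy (not Keras) that keeps its hidden state between calls."""

    def __init__(self, n_in, n_hidden, seed=0):
        rng = np.random.default_rng(seed)
        self.W_x = rng.normal(scale=0.5, size=(n_hidden, n_in))
        self.W_h = rng.normal(scale=0.5, size=(n_hidden, n_hidden))
        self.h = np.zeros(n_hidden)

    def reset_states(self):
        """Forget everything seen so far."""
        self.h = np.zeros_like(self.h)

    def step(self, x):
        """Consume one time step; the updated hidden state carries over to the next call."""
        self.h = np.tanh(self.W_x @ x + self.W_h @ self.h)
        return self.h

rng = np.random.default_rng(1)
seq = rng.normal(size=(100, 3))      # hypothetical series: 100 time steps, 3 neurons
rnn = ToyRNN(n_in=3, n_hidden=8)

# Stateful: one pass over the series, state carried throughout (100 cell updates).
stateful_out = np.array([rnn.step(x) for x in seq])

# Chunked: each window starts from a zero state, so steps before the window are invisible.
maxlen = 10
rnn.reset_states()
chunk_out = np.array([rnn.step(x) for x in seq[50:50 + maxlen]])

# Overlapping windows with stride 1 need (100 - 10 + 1) * 10 = 910 cell updates.
n_chunked_steps = (len(seq) - maxlen + 1) * maxlen
```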

lukeczapla commented 8 years ago

Benji, could you share some more details about your approach? I am curious how you obtain a simulated value of the fluorescence derivative. We'd like to pursue more of what @jrieke just mentioned about his work. I am looking over the rnn.py example, but I'm more interested in the details of the simulation right now (working with jNeuroML). @theideasmith: do you want to go over the scope diagram and propose some additions?

BenjiJack commented 8 years ago

@jrieke @lukeczapla @theideasmith Thank you for your comments and for taking a look at my code.

@jrieke

@lukeczapla Basically, what I've done so far is write the function rnn.generate, which takes time-series data, trains an RNN on that data, and attempts to generate new, similar data. The function is agnostic as to which particular time series it is fed.
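The body of rnn.generate isn't shown in this thread. The generation half of such a function usually runs an autoregressive loop. A window of real data seeds the model, the model predicts the next time step, and the prediction is appended to the window, which then slides forward. This is a sketch under that assumption. `generate` and `predict_next` are hypothetical names, and `predict_next` stands in for a trained model's one-step prediction.

```python
import numpy as np

def generate(predict_next, seed, n_steps):
    """Autoregressively extend a multichannel time series.

    predict_next: callable mapping a window of shape (maxlen, n_neurons)
                  to the next time step, shape (n_neurons,)
    seed: initial window taken from the real data, shape (maxlen, n_neurons)
    """
    window = np.array(seed, dtype=float)
    out = []
    for _ in range(n_steps):
        nxt = np.asarray(predict_next(window), dtype=float)
        out.append(nxt)
        # Slide the window: drop the oldest step, append the model's own prediction.
        window = np.vstack([window[1:], nxt])
    return np.array(out)

# Toy deterministic "model": next step = half of the last step (a stand-in, not an RNN).
seed = np.array([[0.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
simulated = generate(lambda w: 0.5 * w[-1], seed, n_steps=20)   # shape (20, 2)
```

With a deterministic predictor like this toy one, the output decays to a fixed point. That is the kind of equilibrium this thread discusses, and sampling each next step from a predicted distribution is the usual remedy.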

In the Jupyter notebook, I feed Kato's fluorescence derivative data to the rnn.generate function, obtain output from the RNN, and run PCA on the output. I do that 5 times, once for each trial.

At this stage, I am using an off-the-shelf RNN: its configuration is identical to the one in the LSTM tutorial in the Keras documentation. I am sure much more can be done to tune the RNN to our purposes, and @jrieke has given us some great places to start.

theideasmith commented 8 years ago

I'll respond at length after spring break begins on Wednesday, but this is extremely valuable and novel. If NNs truly produce results that are, in theory, biologically reproducible, we'll now have a whole lot more data to work with.

I'm just wondering whether you think these NN-generated datasets will say anything new about dynamics, or will really just be a mashup of existing motifs in the Kato data. Because the NNs are trained only on a particular subset of C. elegans neuronal activity, what kind of new information would they give us that is not in the datasets we already have? See this interesting post, which inspired this question.

These questions are only relevant to one particular use case of NNs here and are not at all intended to challenge the idea of using them. There are many other ways this work can be useful.

jrieke commented 8 years ago

@BenjiJack All RNNs that generate new data need some form of sampling from a probability distribution. If you look at Karpathy's blog post that you linked above, you can see that he mentions a "temperature" parameter. This simply means that he samples from a probability distribution (in this case, a discrete distribution over the possible characters). The temperature controls how "random" the sampling is (e.g. with temperature = 0 you always pick the most likely character). In fact, you can run Karpathy's code with temperature = 0 and observe that the RNN always outputs the same character (i.e. it settles into an equilibrium).
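The temperature mechanism described above can be sketched in a few lines. The logits are divided by the temperature before the softmax, so low temperatures sharpen the distribution toward the most likely item and high temperatures flatten it toward uniform. Temperature 0 is treated as the greedy limit (argmax). This is a generic illustration, not Karpathy's code. The function name is hypothetical.

```python
import numpy as np

def sample_with_temperature(logits, temperature, rng):
    """Sample an index from softmax(logits / temperature); temperature 0 means argmax."""
    logits = np.asarray(logits, dtype=float)
    if temperature == 0:
        return int(np.argmax(logits))      # greedy: always the single most likely item
    scaled = logits / temperature
    p = np.exp(scaled - scaled.max())      # subtract the max for numerical stability
    p /= p.sum()
    return int(rng.choice(len(p), p=p))

rng = np.random.default_rng(0)
logits = [1.0, 3.0, 2.0]   # hypothetical scores for 3 characters
greedy = sample_with_temperature(logits, 0, rng)
hot = [sample_with_temperature(logits, 100.0, rng) for _ in range(1000)]
```

For continuous neural traces, the Gaussian mixture approach suggested earlier in the thread plays the same role as this discrete sampling step.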

@theideasmith In general, the generated data shouldn't contain anything "new": the network only learns from the original data, so its output can't really contain more features than that. Some people use RNNs to generate more data for analysis; I don't really know about the results, but I'd be rather skeptical of it. See also the next two minutes of Alex Graves' talk here. (As a non-representative side note: in my own project on generating biological cell-movement data, I found that the generated data reproduced the main properties of the original data but had less variance, i.e. it was actually "worse" than the original data.)

theideasmith commented 8 years ago

@jrieke what technique did you use to determine less variance in generated data? PCA?

jrieke commented 8 years ago

@theideasmith I used a specific statistical analysis (in short: modelling the data as random walks and extracting some properties of this random walk). You can read more about it on p. 12 of my report here - especially look at fig. 10 for a comparison of the real and the generated data.
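The report defines the exact random-walk statistics it uses. As a generic example of this kind of analysis, here is one standard random-walk property, the mean squared displacement (MSD) as a function of time lag. Whether the report uses MSD specifically is an assumption. Computing the curve for real and generated trajectories and comparing them is one way to see whether the generated data spreads out less. The function name is hypothetical.

```python
import numpy as np

def mean_squared_displacement(positions, max_lag):
    """MSD for lags 1..max_lag of a trajectory with shape (time steps, dims)."""
    positions = np.asarray(positions, dtype=float)
    msd = []
    for lag in range(1, max_lag + 1):
        disp = positions[lag:] - positions[:-lag]          # all displacements over this lag
        msd.append(np.mean(np.sum(disp ** 2, axis=1)))     # squared length, averaged
    return np.array(msd)

# Hypothetical 2-D random walk: for a diffusive walk, MSD grows roughly linearly with lag.
rng = np.random.default_rng(0)
walk = np.cumsum(rng.normal(size=(1000, 2)), axis=0)
msd_walk = mean_squared_displacement(walk, max_lag=20)
```

A generated dataset whose MSD curve lies consistently below the real one would match the "less variance" observation above.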