lfads / lfads-run-manager

Matlab interface for Latent Factor Analysis via Dynamical Systems (LFADS)
https://lfads.github.io/lfads-run-manager
Apache License 2.0

Not learning condition specific dynamics. Any suggestions? #32

Open didch1789 opened 2 years ago

didch1789 commented 2 years ago

Hi, I've been using your LFADS tools for about a month now, trying several parameter settings. The results look quite interesting, but the main issue is that every configuration I've tried learns trial-to-trial differences instead of condition-specific dynamics. All trials produce similar dynamics and simply appear in trial order (i.e., if I have 700 trials and color them by order, the colors grade smoothly from 1 to 700 instead of reflecting any condition information).

I've checked this in a few ways, as in your original paper: 1) making a t-SNE plot from the initial conditions; 2) a neuron-level comparison (condition-averaged firing rates of the real spike data vs. condition-averaged firing rates from the LFADS output); and 3) plotting the factors projected onto condition-averaged PC axes, as in your LFADS tutorial for the Lorenz attractor. None of these results shows any condition-specific information.
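For reference, check (1) can be sketched like this. All data and shapes here are synthetic stand-ins; the `ics` array takes the place of the per-trial initial conditions inferred by LFADS and is not the run-manager's actual output format:

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
n_trials, ic_dim, n_conditions = 120, 64, 4
conditions = rng.integers(0, n_conditions, size=n_trials)

# Simulated initial conditions: condition-specific cluster centers plus noise.
# With real data, ics would be the (n_trials x ic_dim) matrix of inferred g0s.
centers = rng.normal(size=(n_conditions, ic_dim))
ics = centers[conditions] + 0.1 * rng.normal(size=(n_trials, ic_dim))

# Embed into 2D; scatter-plot emb colored by `conditions` to look for clusters.
emb = TSNE(n_components=2, perplexity=20, random_state=0).fit_transform(ics)
print(emb.shape)  # (120, 2)
```

If the embedding clusters by trial order rather than by condition, that is the symptom described above.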

So I've tried several different parameter settings to fix this, but so far no luck. I've changed:

  1. Use of the alignment matrix (for multi-session stitching; I've tried both single-session and multi-session runs)
  2. Batch size
  3. Dimensions of the encoder and generator units
  4. Dimensions of the initial condition and factors
  5. Spike bin size

Well, it might be the case that my data simply have no condition-specific dynamics to learn, but other methods I've tried seem to say otherwise.

FYI: each condition has 50-200 trials, I used 1000 ms of data per trial (does trial length matter?), and the alignment matrix from principal component regression does seem to capture some condition-specific dynamics. (However, the factors from each session don't follow the global signal well, so I've set the model to learn the alignment matrices in all cases.)

Please let me know any suggestions you come up with. I'd be really grateful :) Thank you!

cpandar commented 2 years ago

Thanks for your note. I'd probably start with some basic debugging. Since your data has conditions, you could take the spiking activity (not the LFADS output) and compute peri-stimulus time histograms (PSTHs) for each neuron. They should hopefully show condition-dependent structure. Then you could do the same thing for the LFADS output and see how closely it is reproducing the PSTHs (or not).
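A minimal sketch of this check, using synthetic binned spike counts (the arrays, shapes, and the noisy `rates` stand-in for LFADS output are all illustrative, not the lfads-run-manager export format):

```python
import numpy as np

rng = np.random.default_rng(0)
n_trials, n_neurons, n_bins, n_conditions = 200, 30, 50, 4
conditions = rng.integers(0, n_conditions, size=n_trials)

spikes = rng.poisson(2.0, size=(n_trials, n_neurons, n_bins))  # binned counts
rates = spikes + rng.normal(0, 0.1, size=spikes.shape)         # stand-in for LFADS rates

def psth(x, conditions, n_conditions):
    """Condition-averaged PSTH, shape (n_conditions, n_neurons, n_bins)."""
    return np.stack([x[conditions == c].mean(axis=0)
                     for c in range(n_conditions)])

psth_spikes = psth(spikes, conditions, n_conditions)
psth_lfads = psth(rates, conditions, n_conditions)

# Per-neuron correlation between the two PSTH sets; values near 1 mean the
# model is reproducing the condition-dependent structure for that neuron.
rs = [np.corrcoef(psth_spikes[:, n, :].ravel(),
                  psth_lfads[:, n, :].ravel())[0, 1]
      for n in range(n_neurons)]
print(min(rs))
```

With real data, neurons whose spike PSTHs separate by condition but whose LFADS PSTHs do not are exactly the failure mode to look for.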

Also, how wide is your binsize for the data? The smaller the bins you use, the longer the sequence length you are feeding into LFADS. So e.g. since you said your data is 1000 ms, if you are using 1 ms bins, then your sequence length is 1000. Longer sequences are harder to model and also take a much longer time to train the network. As a rough rule of thumb, we try to avoid sequences that are longer than 80-100 timepoints, and prefer even shorter if it is feasible for the data.

Generally we try to use as wide bins as we can that would still capture the underlying features/changes in the neural firing rates. You might try modeling the data with wide bins to start with, to allow you to iterate faster, and then try with smaller bins once you feel the model is working well. Similarly, you might try restricting the time window to the time you're most interested in to start with, to keep the sequence length short, and expand it further after the model is working well.
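The rule of thumb above is just a quick bit of arithmetic (illustrative values; the 80-100 timepoint target is the one stated above):

```python
# Sequence length fed to LFADS = trial window / bin width.
window_ms = 1000
for bin_ms in (1, 2, 5, 10, 20):
    seq_len = window_ms // bin_ms
    ok = seq_len <= 100  # rough rule of thumb: keep sequences <= 80-100 bins
    print(f"{bin_ms:2d} ms bins -> {seq_len:4d} timepoints ({'ok' if ok else 'too long'})")
```

So for a 1000 ms window, bins of 10 ms or wider keep the sequence within the suggested range.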

Good luck!

djoshea commented 2 years ago

Chethan hit all of the major things I would have suggested. The only thing that pops out is that you've said the model outputs are structured by trial order (presumably you mean the factor outputs or the initial conditions). I wonder how much real drift there is in your data. Is there actual slowly-changing structure in the collected data, e.g., due to electrode drift? If so, LFADS would certainly try to explain that slowly-changing variation across trials. The PC regression tool computes trial averages (marginalizing over this drift), so it would still be able to find the condition-specific variance underneath, but LFADS might end up explaining mostly trial-to-trial drift if that is indeed the largest source of variance.
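One way to gauge this drift directly from the data, sketched here on simulated per-trial mean rates (all names and the drift model are illustrative assumptions, not part of the run manager): remove condition means, then ask whether the residuals' dominant component still tracks trial order.

```python
import numpy as np

rng = np.random.default_rng(0)
n_trials, n_neurons, n_conditions = 300, 40, 4
conditions = rng.integers(0, n_conditions, size=n_trials)

# Simulated per-trial mean rates: condition effect + slow linear drift + noise.
cond_effect = rng.normal(size=(n_conditions, n_neurons))[conditions]
drift = 0.01 * np.arange(n_trials)[:, None] * rng.normal(size=(1, n_neurons))
x = cond_effect + drift + 0.1 * rng.normal(size=(n_trials, n_neurons))

# Subtract each condition's mean, leaving condition-independent variation.
resid = x.copy()
for c in range(n_conditions):
    resid[conditions == c] -= x[conditions == c].mean(axis=0)

# Project residuals onto their top PC and correlate the score with trial index.
u, s, vt = np.linalg.svd(resid - resid.mean(axis=0), full_matrices=False)
score = u[:, 0] * s[0]
r = np.corrcoef(score, np.arange(n_trials))[0, 1]
print(abs(r))
```

A large |r| means slow drift dominates the condition-independent variance, which is exactly the signal LFADS would be tempted to spend its capacity on.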

You might also try turning off learning of the readin matrices, which might help to preserve that condition-specific structure (i.e., set c_do_train_readin to false). The PC regression tool sets the initial value of these matrices, and if you turn off learning, they will be fixed to those values.

didch1789 commented 2 years ago

Thanks for the fast and kind reply!

1) I haven't looked at all neurons' firing rates, but I've compared the condition-averaged PSTHs of each neuron that showed a high correlation with the LFADS output (inferred rates, not factors). A few neurons showed different activity profiles depending on the condition, but the LFADS output gave the same profile across conditions. 2) I've also tried several bin sizes (including 2, 5, and 10 ms, which give 500, 200, and 100 bins), and the results didn't differ much. 3) I've also tried setting c_do_train_readin to both true and false, again with no big difference.

It seems Dan's point about drift might be the crucial one. Setting c_do_train_readin to false might reduce the effect, but judging from my results it doesn't reduce the trial-by-trial variance enough, at least in my data. Do you think using "ridgeRegressGlobalPCs" for the regression might help?

One more thing: do you have any rule of thumb for the number of trials per condition? It seems that if I use an alignment matrix, the process of getting the initial regression subspace from each session is quite important. So I'm a bit curious about 1) an appropriate number of trials per condition, and 2) how the results will be affected (in principle) by the use of the initial alignment matrix.

I've been spending quite a lot of time with your cool methods recently. Thanks again for the fast and kind reply :)

cpandar commented 2 years ago

Hmm, I think I'd have to see what the PSTHs look like - something like the single-neuron plots we showed in Fig. 2 of the 2018 LFADS paper.

I'm a little confused as to why a read-in matrix is needed - that's only for multi-session stitching, right? How many neurons do you have in a single session? If you have a reasonable number of neurons in a given session (say, 40 or more), I'd focus on getting single-session LFADS to work before messing with multi-session stitching.