flatironinstitute / nemos

NEural MOdelS, a statistical modeling framework for neuroscience.
https://nemos.readthedocs.io/
MIT License

Architecture of basis/GLM classes #9

Closed: billbrod closed this issue 6 months ago

billbrod commented 1 year ago

It might be useful to restructure the architecture of the basis and GLM classes.

We could make the basis more like a scikit-learn transformer, with a fit method (currently gen_basis_funcs), a transform method (the convolution with the spike train that currently happens in GLM, or the sampling in space as an animal moves through an environment) and a fit_transform method (the two together).
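A minimal sketch of the transformer idea, assuming Gaussian bumps as a stand-in for the real basis math and a causal 1D convolution as the transform step. `TransformerBasis`, its parameters and its kernel shapes are hypothetical, not part of nemos.

```python
import numpy as np


class TransformerBasis:
    """Hypothetical scikit-learn-style basis (not an existing nemos class).

    fit builds the basis kernels (the role gen_basis_funcs plays today);
    transform convolves them with a 1D signal such as binned spike counts.
    """

    def __init__(self, n_basis=5, window_size=20):
        self.n_basis = n_basis
        self.window_size = window_size

    def fit(self, X=None, y=None):
        # Gaussian bumps tiling the window: a stand-in for the real basis math.
        t = np.arange(self.window_size)
        centers = np.linspace(0, self.window_size - 1, self.n_basis)
        width = self.window_size / self.n_basis
        self.kernels_ = np.exp(-0.5 * ((t[None, :] - centers[:, None]) / width) ** 2)
        return self  # sklearn convention: fit returns the estimator itself

    def transform(self, X):
        # Causal convolution: row t of the output only depends on X[: t + 1].
        X = np.asarray(X, dtype=float)
        return np.column_stack(
            [np.convolve(X, k, mode="full")[: X.size] for k in self.kernels_]
        )

    def fit_transform(self, X, y=None):
        return self.fit(X, y).transform(X)


spikes = np.zeros(50)
spikes[10] = 1.0
X = TransformerBasis(n_basis=5, window_size=20).fit_transform(spikes)
```

Because fit only builds the kernels, a spike train with a single spike at bin 10 yields an (50, 5) matrix that is zero before row 10 and reproduces the kernels from row 10 onwards.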

Then what arguments would GLM accept? Probably the output of fit_transform (rather than the basis object itself), and that should be passed during initialization.

Additionally:

gviejo commented 1 year ago

Here is a proposal for the general workflow:

# Maybe some additional preprocessing modules like time warping, zscoring, etc

# Defining the basis
basis1 = RaisedCosineBasis(window_size = 100, n_basis = 5)
basis1.fit_transform(spikes)
basis2 = MyFancyBasis()
basis2.fit_transform(position)

# Fitting the glm
glm = GLM((basis1, basis2))
glm.add_basis(basis3)
glm.fit(spikes)
ahwillia commented 1 year ago

This would be my vision for the GLM of conjunctive position + head direction coding.

# 1d basis with periodic boundary condition.
head_direction_basis = RaisedCosineBasis(
    support=(0, 2*pi),
    n_basis=5,
    periodic=True
) 

# 2d basis with no periodic boundary.
position_basis = MSplineBasis(
    support=((0, 100), (0, 200)), # 100cm x 200cm box
    n_basis=(10, 20),  # 10 x 20 grid of basis functions
    spline_order=3,
)

# Construct cartesian product of basis funcs.
joint_basis = ProductBasis((position_basis, head_direction_basis))

# Apply basis functions.
X = joint_basis.fit_transform((position_measurements, head_direction_measurements))

# Fit GLM model.
glm = GLM()
glm.fit(X, y)   # y = vector of spikes for one neuron

Also we could consider using sklearn.pipeline.Pipeline.
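The snippet above relies on classes that do not exist yet. As a rough sketch of what `ProductBasis` could compute, assuming numpy only: a periodic raised-cosine basis for head direction, triangular "hat" functions standing in for the M-spline position basis (on a 1D track rather than a 2D box, for brevity), and a row-wise Kronecker product for the Cartesian product. All function names here are hypothetical.

```python
import numpy as np


def periodic_raised_cosine(x, n_basis):
    """Raised-cosine bumps on the circle [0, 2*pi); returns (len(x), n_basis)."""
    centers = np.linspace(0, 2 * np.pi, n_basis, endpoint=False)
    # Wrap each angular difference into [-pi, pi) so the basis is periodic.
    d = (x[:, None] - centers[None, :] + np.pi) % (2 * np.pi) - np.pi
    # Rescale so each bump reaches zero at its neighbours' centres.
    d = d * n_basis / 2
    return np.where(np.abs(d) < np.pi, 0.5 * (1 + np.cos(d)), 0.0)


def tent_basis(x, lo, hi, n_basis):
    """Non-periodic 'hat' functions on [lo, hi]; returns (len(x), n_basis)."""
    centers = np.linspace(lo, hi, n_basis)
    step = centers[1] - centers[0]
    return np.clip(1 - np.abs(x[:, None] - centers[None, :]) / step, 0, None)


def row_wise_kron(A, B):
    """Cartesian product of two evaluated bases: every pairwise column product."""
    return np.einsum("ni,nj->nij", A, B).reshape(A.shape[0], -1)


rng = np.random.default_rng(0)
head_direction = rng.uniform(0, 2 * np.pi, 1000)
position = rng.uniform(0, 100, 1000)  # 1D track in cm, for brevity

X = row_wise_kron(
    tent_basis(position, 0, 100, 10),
    periodic_raised_cosine(head_direction, 5),
)
print(X.shape)  # (1000, 50)
```

Since each factor sums to one at every sample, so do the 50 product columns, which is a quick check that the Cartesian product was formed correctly.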

billbrod commented 1 year ago

transform might do different things in different domains: you'll convolve spikes (in time) and visual input, but not position (there you just evaluate the basis at certain locations). How should we handle that: different methods, or completely different classes for different domains? If transform just evaluates, that's basically what our current gen_basis_funcs does, given properly chosen sample_pts.

So probably we should not use the fit/transform language: users will just initialize the basis and then call EITHER evaluate or convolve. The output is a multi-dimensional array that they pass to the GLM.

billbrod commented 1 year ago

Basis's __init__ should accept the number of basis functions and any other hyper-parameters (e.g., MSpline order, whether the raised cosine basis is linear or log). Then the user will call either evaluate(samples) or convolve(samples, window_size) to create the model matrix. All bases will also have a gen_basis_funcs(window_size) which gets called by convolve and can also be called separately by the user for visualization purposes. And there's a shared _basis() (or better name) method that has the actual math in it.

How to compose bases:

head_pos_basis = MSpline(k)
maze_location_basis = BSpline(k)
combined_basis = BasisProduct(head_pos_basis, maze_location_basis)
# this just calls BasisProduct under the hood
combined_basis = head_pos_basis * maze_location_basis 
# and similar for + or BasisSum
combined_matrix = combined_basis.evaluate((head_locs, maze_locs))

All names subject to change.
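One way the `*` and `+` operators could be wired up, as a sketch under assumptions: `TentBasis` is a hypothetical stand-in for `MSpline`/`BSpline`, evaluate takes the inputs as separate positional arguments rather than a single tuple, and each basis records how many inputs it consumes so that nested compositions can route inputs to the right child.

```python
import numpy as np


class Basis:
    """Hypothetical base class; leaves take one input, composites take several."""

    n_inputs = 1

    def __mul__(self, other):
        return BasisProduct(self, other)

    def __add__(self, other):
        return BasisSum(self, other)


class TentBasis(Basis):
    """Stand-in 1D basis: triangular hat functions on [lo, hi]."""

    def __init__(self, n_basis, lo=0.0, hi=1.0):
        self.n_basis, self.lo, self.hi = n_basis, lo, hi

    def evaluate(self, x):
        centers = np.linspace(self.lo, self.hi, self.n_basis)
        step = centers[1] - centers[0]
        x = np.asarray(x, dtype=float)
        return np.clip(1 - np.abs(x[:, None] - centers) / step, 0, None)


class _Composite(Basis):
    def __init__(self, a, b):
        self.a, self.b = a, b
        self.n_inputs = a.n_inputs + b.n_inputs

    def _evaluate_children(self, xs):
        # Hand each child exactly the inputs it needs, recursing into composites.
        k = self.a.n_inputs
        return self.a.evaluate(*xs[:k]), self.b.evaluate(*xs[k:])


class BasisSum(_Composite):
    def evaluate(self, *xs):
        A, B = self._evaluate_children(xs)
        return np.hstack([A, B])  # columns side by side


class BasisProduct(_Composite):
    def evaluate(self, *xs):
        A, B = self._evaluate_children(xs)
        # Every pairwise product of columns (row-wise Kronecker product).
        return np.einsum("ni,nj->nij", A, B).reshape(A.shape[0], -1)


head_pos_basis = TentBasis(4, 0, 2 * np.pi)
maze_location_basis = TentBasis(3, 0, 100)
combined_basis = head_pos_basis * maze_location_basis
head_locs = np.linspace(0, 2 * np.pi, 200)
maze_locs = np.linspace(0, 100, 200)
combined_matrix = combined_basis.evaluate(head_locs, maze_locs)  # (200, 12)
```

Tracking `n_inputs` is one design choice among several; it lets expressions like `a * b + a` work without the user manually nesting tuples of inputs.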

GLM:

We also need ways to visualize the fit GLM, but that is out of scope for this issue.

BalzaniEdoardo commented 1 year ago

Comments on the architecture we decided for the basis:

BalzaniEdoardo commented 1 year ago

I think we will still need a basis handler object that sits hierarchically on top of the basis functions, for the following reasons:

BalzaniEdoardo commented 1 year ago

Basis's __init__ should accept the number of basis functions and any other hyper-parameters (e.g., MSpline order, whether the raised cosine basis is linear or log). Then the user will call either evaluate(samples) or convolve(samples, window_size) to create the model matrix. All bases will also have a gen_basis_funcs(window_size) which gets called by convolve and can also be called separately by the user for visualization purposes. And there's a shared _basis() (or better name) method that has the actual math in it.

How to compose bases:

head_pos_basis = MSpline(k)
maze_location_basis = BSpline(k)
combined_basis = BasisProduct(head_pos_basis, maze_location_basis)
# this just calls BasisProduct under the hood
combined_basis = head_pos_basis * maze_location_basis 
# and similar for + or BasisSum
combined_matrix = combined_basis.evaluate((head_locs, maze_locs))

All names subject to change.

GLM:

* separate classes for different noise models, e.g. Poisson and Gamma (which inherit from GLM)

* `__init__` accepts optimization-related arguments: solver, solver args, link function

* `fit` accepts `model_matrix` (as returned by `basis.evaluate` or `basis.convolve`; always 2d, time points by basis functions) and `spike_data` (time points by neurons) and follows sklearn's [estimator API](https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects)

* if you want to get GLM's predictions for new data (e.g., train and test set), you call `predict` (or `transform`? what does sklearn's cross-validation expect?) with a new `model_matrix` constructed from the same basis objects. (Running the matrix multiplication serves as a sanity check: the number of weights must match the number of basis functions, etc.) This returns the predicted firing rate for data that were not used for fitting.

* `simulate` simulates new spikes based on some initializing activity and the model matrix (can be the one used for fitting or a new one).

  * Edoardo has another method for doing this based on the fact that inter-spike interval (with Poisson noise) is exponentially-distributed and so you can compute the time of the next spike starting from no spikes. This has the advantage of not requiring initializing activity and also being theoretically matched to the proper distribution (assuming enough samples). -- actually, Guillaume disagrees that this is what actual neurons' inter-spike interval looks like, so need to double-check that this is true. @ahwillia do you know?

* `score` is used to evaluate the quality of fit; it's the objective the optimization works on.

  * optimization should always use the log-likelihood (maximizing it, i.e. minimizing the negative log-likelihood), but users might want to use something else, e.g. pseudo-$R^2$, for evaluation purposes. We should provide this (and similar functionality) in separate methods that accept the predicted and actual spike rates.
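A sketch of the prediction and scoring helpers described in the list above, assuming a Poisson GLM with an exponential link. The deviance-based pseudo-$R^2$ used here is one common definition among several, and all function names are hypothetical.

```python
import numpy as np


def predict(model_matrix, weights, intercept=0.0):
    """Firing rate under an (assumed) exponential link function."""
    if model_matrix.shape[1] != weights.shape[0]:
        raise ValueError("number of weights must match number of basis functions")
    return np.exp(model_matrix @ weights + intercept)


def _xlogy(x, y):
    # x * log(y), taken to be 0 wherever x == 0
    safe_y = np.where(x > 0, y, 1.0)
    return np.where(x > 0, x * np.log(safe_y), 0.0)


def poisson_log_likelihood(spikes, rate):
    """Poisson log-likelihood summed over samples, without the log(y!) term.

    The dropped term does not depend on the rate, so it changes neither the
    optimum nor the pseudo-R^2 below. An optimizer would minimize the negative.
    """
    spikes = np.asarray(spikes, dtype=float)
    rate = np.asarray(rate, dtype=float)
    return np.sum(_xlogy(spikes, rate) - rate)


def pseudo_r2(spikes, rate):
    """Deviance-based pseudo-R^2: 1 for a perfect fit, 0 for the mean-rate model."""
    spikes = np.asarray(spikes, dtype=float)
    ll_model = poisson_log_likelihood(spikes, rate)
    ll_null = poisson_log_likelihood(spikes, np.full_like(spikes, spikes.mean()))
    ll_saturated = poisson_log_likelihood(spikes, spikes)
    return (ll_model - ll_null) / (ll_saturated - ll_null)
```

Keeping the metric in a standalone function that takes predicted and observed rates, as the list suggests, means it can be reused on held-out data without touching the optimization objective.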

We also need ways to visualize the fit GLM, but that is out of scope for this issue.

The implementation of the Basis object deviates from what is proposed here in two ways:

  1. it requires the operation (convolve vs. evaluate) to be specified when the basis is defined;
  2. it has a single method, gen_basis, which receives samples as input and returns the basis applied to them.

Regarding (1), there is a substantial advantage in terms of user experience: a user can combine bases without keeping track of which operation has to be performed on which input, because that is decided and specified upfront. As a minor point, it also improves code readability, since the model matrix can be generated with a clean single-line recursion.

(2) needs to change slightly, because it is important for users to have an easy way to inspect the basis functions themselves, not just how the basis is applied to the samples. This is not a problem when the basis is evaluated: to inspect the basis, one can call gen_basis with a linspace spanning the range of the samples, for example in the 1D case:

plt.plot(basis.gen_basis(np.linspace(samples.min(), samples.max(), 1000)).T)

For convolution, gen_basis performs the convolution itself, so one would need to pass an impulse (zeros with a 1 in the middle) to see what the basis used for the convolution looks like.

For more complex bases, such as a mix of convolve and evaluate, I guess it would take a mix of calling gen_basis on impulses and on linspaces...

Overall, we should have a dedicated inspect_basis function that calls gen_basis appropriately. This needs further discussion.
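To illustrate the impulse trick, here is a sketch assuming a convolution-type gen_basis performs a causal 1D convolution and returns an (n_basis, n_samples) array, matching the `.T` in the plotting line above. The exponential kernels and the name `gen_basis_convolve` are hypothetical.

```python
import numpy as np


def gen_basis_convolve(samples, kernels):
    """Stand-in for a convolution-type gen_basis: causal convolution of a 1D
    signal with each kernel (one kernel per row). Returns (n_basis, n_samples)."""
    samples = np.asarray(samples, dtype=float)
    return np.stack(
        [np.convolve(samples, k, mode="full")[: samples.size] for k in kernels]
    )


window = 30
t = np.arange(window)
kernels = np.stack([np.exp(-t / tau) for tau in (2.0, 5.0, 10.0)])  # hypothetical kernels

# Zeros with a single 1 in the middle.
impulse = np.zeros(2 * window)
impulse[window] = 1.0

response = gen_basis_convolve(impulse, kernels)
# From the impulse onwards, each row traces out its kernel exactly,
# so response[:, window:] is what one would plot to inspect the basis.
recovered = response[:, window:]
```

An inspect_basis helper could hide this: build the impulse for convolution-type components and a linspace for evaluation-type ones.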

BalzaniEdoardo commented 1 year ago

Structural changes required: right now the basis type (convolve, evaluate, add, mult) is defined through a string passed at initialization. We should replace that with separate subclasses, one per basis type, which allows cleaner integration of new basis types that might handle inputs of very different types (e.g. images, which may require a 2D convolution).
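A sketch of the subclass-per-type structure, assuming the operation is fixed by the class and composite bases recurse into their children to build the model matrix. The class names, the raised-cosine math on [0, 1] and the (n_basis, n_samples) output orientation are assumptions for illustration; a multiplicative subclass would follow the same pattern as the additive one.

```python
import numpy as np


def raised_cosine(x, n_basis):
    """Hypothetical raised-cosine bumps on [0, 1]; returns (n_basis, len(x))."""
    centers = np.linspace(0, 1, n_basis)
    d = (x[None, :] - centers[:, None]) * np.pi / (centers[1] - centers[0])
    return np.where(np.abs(d) < np.pi, 0.5 * (1 + np.cos(d)), 0.0)


class Basis:
    n_inputs = 1

    def __add__(self, other):
        return AdditiveBasis(self, other)


class EvalBasis(Basis):
    """Operation fixed at definition: evaluate the bumps at the samples."""

    def __init__(self, n_basis):
        self.n_basis = n_basis

    def gen_basis(self, x):
        return raised_cosine(np.asarray(x, dtype=float), self.n_basis)


class ConvBasis(Basis):
    """Operation fixed at definition: causally convolve the samples with kernels."""

    def __init__(self, n_basis, window_size):
        self.n_basis, self.window_size = n_basis, window_size

    def gen_basis(self, x):
        kernels = raised_cosine(np.linspace(0, 1, self.window_size), self.n_basis)
        x = np.asarray(x, dtype=float)
        return np.stack([np.convolve(x, k)[: x.size] for k in kernels])


class AdditiveBasis(Basis):
    def __init__(self, a, b):
        self.a, self.b = a, b
        self.n_inputs = a.n_inputs + b.n_inputs

    def gen_basis(self, *xs):
        # The recursion: each child handles its own inputs with its own operation.
        k = self.a.n_inputs
        return np.vstack([self.a.gen_basis(*xs[:k]), self.b.gen_basis(*xs[k:])])


rng = np.random.default_rng(1)
spikes = rng.poisson(0.1, 500).astype(float)
position = rng.uniform(0, 1, 500)
basis = ConvBasis(4, window_size=20) + EvalBasis(6)
model_matrix = basis.gen_basis(spikes, position)  # (10, 500): 4 conv rows, 6 eval rows
```

Adding, say, a 2D image convolution would then mean writing one new leaf subclass, without touching the composition logic.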

ahwillia commented 1 year ago

Catching up on a few things...

Edoardo has another method for doing this based on the fact that inter-spike interval (with Poisson noise) is exponentially-distributed and so you can compute the time of the next spike starting from no spikes. This has the advantage of not requiring initializing activity and also being theoretically matched to the proper distribution (assuming enough samples). -- actually, Guillaume disagrees that this is what actual neurons' inter-spike interval looks like, so need to double-check that this is true. @ahwillia do you know?

For a homogeneous Poisson process (i.e. a firing rate that is constant over time), the inter-spike intervals will be exponentially distributed. But things get more challenging for inhomogeneous Poisson processes with time-varying rates (which is what we have in a typical GLM). We probably want to stay in discrete time to keep things simple.
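A minimal discrete-time sketch, assuming Poisson spike counts per bin with mean equal to rate times bin width. This handles time-varying rates directly, and for a constant rate the inter-spike intervals come out approximately exponential (geometric, strictly speaking, because of the binning). The rates and bin size are arbitrary illustrative values and `simulate_discrete_time` is a hypothetical name.

```python
import numpy as np


def simulate_discrete_time(rate, dt, rng):
    """Draw spike counts per time bin: count_t ~ Poisson(rate_t * dt).

    Works for time-varying (inhomogeneous) rates, which is the GLM case.
    """
    return rng.poisson(np.asarray(rate, dtype=float) * dt)


rng = np.random.default_rng(0)
dt = 0.001  # 1 ms bins
t = np.arange(200_000) * dt  # 200 s

# Inhomogeneous: rate oscillating between 5 and 35 Hz.
rate = 20 + 15 * np.sin(2 * np.pi * 0.5 * t)
counts = simulate_discrete_time(rate, dt, rng)

# Homogeneous special case: ISIs should be roughly exponential with mean 1/rate.
const_counts = simulate_discrete_time(np.full(t.size, 20.0), dt, rng)
spike_times = t[const_counts > 0]
isis = np.diff(spike_times)
```

In a GLM simulation the rate in each bin would itself depend on the previously simulated spikes, which is why the bins would be drawn sequentially rather than all at once as here.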

ahwillia commented 1 year ago

Just throwing this out there, but I am not sure we need anything other than evaluate to be implemented by the basis functions, since convolution can be implemented in a few lines of additional code, which we could package inside a function. If I understand correctly, this would substantially simplify the basis function API as users would not have to specify at initialization time whether they are evaluating or convolving the basis.

In my mind, we do not want users implementing the convolution themselves, so all of this will be hidden from them. All the users need to do is create the basis instance and pass it off to whatever model we provide.

Also, I'm not sure I see applications for any convolution higher than 4D. It might be prudent to start by limiting convolutions to 1D or 2D, as this would cover the majority of applications.

# Demo for 1d basis convolution
basis = RaisedCosineBasis(k)
filters = basis.evaluate(np.arange(0, window_size, timebin_size))
convolve_1d_basis(filters, time_series)  # implemented in utils
# Demo for 2d basis convolution
basis = BSplineBasis(k) * BSplineBasis(k)
xx = np.arange(0, window_size, binsize)   # discretize first dimension
yy = np.arange(0, window_size, binsize)   # discretize second dimension
filters = basis.evaluate(
    np.vstack((np.repeat(xx, yy.size), np.tile(yy, xx.size)))
).reshape((xx.size, yy.size, k ** 2))  # k x k product basis -> k**2 functions
convolve_2d_basis(filters, image_data)  # not implemented yet
# ND convolution?
basis = BSplineBasis(k) ** num_dims
g = np.arange(0, window_size, binsize)
eval_pts = np.column_stack([z.ravel() for z in np.meshgrid(*[g for _ in range(num_dims)], indexing="ij")])
filters = basis.evaluate(eval_pts).reshape([g.size for _ in range(num_dims)] + [k ** num_dims])
convolve_nd_basis(filters, nd_tensor_data)  # not implemented yet
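The 1D demo above assumes a `convolve_1d_basis` helper in utils. One possible numpy version is sketched below, assuming `filters` is (window length, number of basis functions) and that the convolution should be causal. The name comes from the snippet, but this signature and behaviour are assumptions.

```python
import numpy as np


def convolve_1d_basis(filters, time_series):
    """Causal convolution of a 1D time series with each filter.

    filters:     (window_len, n_basis) array, one column per basis function.
    time_series: (n_time,) array, e.g. binned spike counts.
    Returns an (n_time, n_basis) model matrix whose row t only uses
    time_series[: t + 1].
    """
    time_series = np.asarray(time_series, dtype=float)
    n_time = time_series.size
    return np.column_stack(
        [np.convolve(time_series, f, mode="full")[:n_time] for f in filters.T]
    )


filters = np.array([[1.0, 0.0], [0.5, 1.0], [0.25, 0.0]])  # 3-bin window, 2 filters
X = convolve_1d_basis(filters, [0, 1, 0, 0, 2])
```

For spike-history filters one would typically also shift the output by one bin, so that the count in the current bin is not used to predict itself.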
BalzaniEdoardo commented 1 year ago

I think that's a good idea. A dedicated module that runs convolutions (1D and 2D first) would simplify our classes, and it could be implemented so that it takes care of the trial structure while running the convolution.

Question: should convolve_nd_basis() accept a pynapple TimeSeries and IntervalSet, as well as a numpy.array equivalent of that structure? I personally think it would be a nice general format: it would not require pynapple, but it would be compatible with it.

ahwillia commented 1 year ago

should convolve_nd_basis() accept a pynapple TimeSeries and IntervalSet, as well as a numpy.array equivalent of that structure

Perhaps it should only accept jax arrays for performance. I don't think users will call / use this convolve function. It will just be for our internal use for fitting models.

BalzaniEdoardo commented 1 year ago

All sounds good. After discussing with Billy, we opted to add another public method to the Basis object (other than evaluate) that is used for exploring the basis functions.

The method will take the window size as input and return the basis evaluated on a linspace/meshgrid of that length covering the domain of the basis function.

Summarizing, the list of TODOs would be:

ahwillia commented 1 year ago

Sounds good. I think I'm still not convinced we need a "get_basis" function. But please forge ahead with what you think is best!

On Tue, Jun 20, 2023, 1:25 PM Edoardo Balzani wrote:

All sounds good. After discussing with Billy, we opted to add another public method to the Basis object (other than evaluate) that is used for exploring the basis functions.

The method will take the window size as input and return the basis evaluated on a linspace/meshgrid of that length covering the domain of the basis function.

Summarizing, the list of TODOs would be:

  • get rid of the 'evaluate' vs 'convolve' kwarg
  • replace get_model_matrix with evaluate
  • add a "get_basis" method (any opinions/suggestions for the name?) that takes the window size and returns the basis evaluated on a grid of points covering the domain
  • create a function for 1D and 2D convolution that receives jax arrays (T x N x D: T = number of samples, N = number of trials, D = 1 or 2 depending on the convolution type) and returns the convolution; kwargs: 'full', 'same', 'valid'?
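For the last TODO, a numpy sketch of the 1D case (D = 1), assuming the input is T x N (samples by trials) and each trial is convolved independently so that kernels never straddle trial boundaries. The `mode` argument is passed straight to `np.convolve`; `convolve_trials_1d` is a hypothetical name, and a JAX version could swap in `jax.numpy.convolve`, which accepts the same `mode` values.

```python
import numpy as np


def convolve_trials_1d(data, kernels, mode="valid"):
    """Convolve every trial independently.

    data:    (T, N) array, T samples per trial, N trials.
    kernels: (W, K) array, K kernels of length W.
    mode:    'full', 'same' or 'valid', passed through to np.convolve.
    Returns a (T_out, N, K) array, with T_out depending on `mode`.
    """
    data = np.asarray(data, dtype=float)
    out = [
        [np.convolve(data[:, n], kernels[:, k], mode=mode) for k in range(kernels.shape[1])]
        for n in range(data.shape[1])
    ]
    # Nested as [trial][kernel] -> (N, K, T_out); move time to the front.
    return np.transpose(np.array(out), (2, 0, 1))


rng = np.random.default_rng(0)
data = rng.normal(size=(50, 3))     # 50 samples, 3 trials
kernels = rng.normal(size=(10, 4))  # 4 kernels of length 10
valid = convolve_trials_1d(data, kernels, "valid")  # (41, 3, 4)
```

Looping over trials keeps the boundary handling obvious; a production version would likely vectorize this, but the per-trial semantics would stay the same.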


billbrod commented 1 year ago

The idea with get_basis is that it calls evaluate with a specific argument. This is what gets used in the convolution function, and it should also be called if the user wants to visualize their set of bases. For most bases this will be np.arange(0, window_size), but not for all (e.g. not for the raised cosine basis), and our rationale is that we want to provide a convenient way to do this correctly.
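A sketch of why get_basis is worth having, assuming a hypothetical raised-cosine basis whose math is defined on the fixed domain [0, 1]. Evaluating it on `np.arange(0, window_size)` leaves almost every sample outside that domain, whereas get_basis spreads `window_size` points across the domain. `RaisedCosineSketch` and its `domain` attribute are illustrative, not the nemos API.

```python
import numpy as np


class RaisedCosineSketch:
    """Hypothetical basis whose math lives on the fixed domain [0, 1]."""

    domain = (0.0, 1.0)

    def __init__(self, n_basis):
        self.n_basis = n_basis

    def evaluate(self, x):
        # Returns (len(x), n_basis): raised-cosine bumps centred on a grid in [0, 1].
        centers = np.linspace(*self.domain, self.n_basis)
        step = centers[1] - centers[0]
        d = (np.asarray(x, dtype=float)[:, None] - centers) * np.pi / step
        return np.where(np.abs(d) < np.pi, 0.5 * (1 + np.cos(d)), 0.0)

    def get_basis(self, window_size):
        # Sample the basis domain with window_size points, whatever its units.
        return self.evaluate(np.linspace(*self.domain, window_size))


basis = RaisedCosineSketch(5)
B = basis.get_basis(100)
naive = basis.evaluate(np.arange(100))
```

With `n_basis=5` and `window_size=100`, get_basis returns a (100, 5) array whose rows sum to one, while `evaluate(np.arange(100))` is zero from the third sample on.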