flatironinstitute / nemos

NEural MOdelS, a statistical modeling framework for neuroscience.
https://nemos.readthedocs.io/
MIT License

GLM predictor representation with pytrees #38

Open BalzaniEdoardo opened 1 year ago

BalzaniEdoardo commented 1 year ago

How do we represent neuron-specific model matrices?

One way is to use a tensor, $X = [X]_{tij}$, where $t$ indexes the samples, $i$ indexes the parameters, and $j$ indexes the neurons.

This works, but most of the predictors will be shared: every experimental input will probably use the same set of basis functions for all the neurons, and the coupling filters between a neuron and all the other neurons will likely share the same basis. Is there a way to represent this efficiently?

Using a mask could be one way, though a bit messier to implement; fitting one neuron at a time is another, computationally less efficient but it avoids creating the massive tensor.
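
For concreteness, here is a minimal sketch of the dense-tensor approach (shapes and names are illustrative, not a proposed API): with a shared tensor, the linear predictors for all neurons are a single einsum.

import jax
import jax.numpy as jnp

# toy dimensions: T samples, K parameters per neuron, N neurons
T, K, N = 100, 5, 3
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (T, K, N))  # dense design tensor X[t, i, j]
w = jax.random.normal(key_w, (K, N))     # per-neuron weights w[i, j]

# linear predictor eta[t, j] = sum_i X[t, i, j] * w[i, j]
eta = jnp.einsum("tij,ij->tj", X, w)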

BalzaniEdoardo commented 10 months ago

Another idea is to use pytrees, so that each neuron could have its own tree of parameters. This is more flexible and more in line with JAX; we could use tree maps to evaluate the dot products with the model parameters.
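
A rough sketch of how that could look (the tree layout here is hypothetical): each neuron carries its own design matrix and weight vector, and a single tree_map evaluates every dot product.

import jax
import jax.numpy as jnp

# hypothetical layout: one design matrix and one weight vector per neuron;
# neurons need not share the same number of predictors
X = {"neuron_1": jnp.ones((100, 5)), "neuron_2": jnp.ones((100, 8))}
w = {"neuron_1": jnp.ones(5), "neuron_2": jnp.ones(8)}

# one tree_map computes the linear predictor of every neuron
eta = jax.tree_util.tree_map(jnp.dot, X, w)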

billbrod commented 9 months ago

As part of making it all compatible with pytrees, there are several discussions of input-related issues in #41.

BalzaniEdoardo commented 9 months ago

If we change to pytrees, we should modify the proximal operator for group lasso; the mask could be handled with tree_map, which would simplify the code.
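
For reference, a hedged sketch of a tree-based group-lasso proximal step (assuming, purely for illustration, that each leaf of the parameter tree is one group; this is not the current implementation):

import jax
import jax.numpy as jnp

def prox_group_lasso(params, reg_strength):
    # block soft-thresholding: shrink each group's norm by reg_strength,
    # zeroing the group if its norm falls below the threshold
    def shrink(w):
        norm = jnp.linalg.norm(w)
        scale = jnp.maximum(0.0, 1.0 - reg_strength / jnp.maximum(norm, 1e-12))
        return scale * w
    # assumes each leaf of `params` is one group
    return jax.tree_util.tree_map(shrink, params)

params = {"stim": jnp.array([3.0, 4.0]), "coupl": jnp.array([0.1, 0.1])}
# "stim" (norm 5) shrinks to 4/5 of its value; "coupl" (norm ~0.14) collapses to 0
print(prox_group_lasso(params, reg_strength=1.0))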

BalzaniEdoardo commented 9 months ago

Here is a snippet of code that is helpful for computing operations on trees with different depths:

import jax

# nested tree: two outer groups, each containing subtrees with the
# same structure as `b` below
x = {
    "outer_1": {
        "inner_1": {"stim": jax.numpy.array([0]), "coupl": jax.numpy.array([1])},
        "inner_2": {"stim": jax.numpy.array([2]), "coupl": jax.numpy.array([3])},
    },
    "outer_2": {
        "inner_1": {"stim": jax.numpy.array([4]), "coupl": jax.numpy.array([5])},
        "inner_2": {"stim": jax.numpy.array([6]), "coupl": jax.numpy.array([7])},
    },
}

b = {"stim": jax.numpy.array([10]), "coupl": jax.numpy.array([20])}

# treat any subtree of `x` whose structure matches `b` as a leaf,
# then add `b` to it leaf-wise
res = jax.tree_util.tree_map(
    lambda subtree: jax.tree_util.tree_map(jax.numpy.add, subtree, b),
    x,
    is_leaf=lambda xsub: jax.tree_util.tree_structure(xsub)
    == jax.tree_util.tree_structure(b),
)
print(res)

BalzaniEdoardo commented 9 months ago

I thought of a way to reduce a deep tree to a sub-tree with a given structure. I don't want to forget it, so I am posting it here. The issue is similar to the one before; the solution is different and could probably be polished.

Let's say we have a nested tree and another tree of a different depth, and we wish to perform an operation on the deeper tree and then reduce the result to the same structure as the shallower one.

A concrete example: we have an output tree, y = {neuron1: data_y1, neuron2: data_y2, ...}, where the data_yi are spike counts, and an input tree with a neuron-specific entry for each input, X = {input1: {neuron1: data_x1, neuron2: data_x2}, input2: {neuron1: data_x3, neuron2: data_x4}}.

The operation could be transforming the input into a rate for each neuron (a transformation on the leaves of X), followed by a tree_map operation over the leaves of y.

Here is a code snippet that performs the operation:

import jax

tree1 = {"aa": {"a": 1, "b": 2}, "bb": {"a": 3, "b": 4}, "cc": {"a": 5, "b": 6}}
tree2 = {"a": 100, "b": 1000}

# operation to be performed on the leaves
operation_on_leaves = lambda y: y**2

# condition to assess which subtrees constitute a leaf
assess_leaves = lambda x: jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(tree2)

# apply the operation, then flatten to a list of trees such that any
# tree in the list has the same structure as tree2
flat_array = lambda x: jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(operation_on_leaves, x), is_leaf=assess_leaves
)[0]

# finally, perform the reduction step, returning {'a': 35, 'b': 56}
print(jax.tree_util.tree_map(lambda *x: sum(x), *flat_array(tree1)))
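
To connect this back to the neuron example above, a toy sketch (hypothetical data, with jnp.exp standing in for the rate nonlinearity) that transforms each input's contribution and then reduces across inputs, so the result has the same structure as y:

import jax
import jax.numpy as jnp

# toy versions of the X and y trees from the example above
X = {
    "input1": {"neuron1": jnp.ones(10), "neuron2": jnp.ones(10)},
    "input2": {"neuron1": 2 * jnp.ones(10), "neuron2": 3 * jnp.ones(10)},
}
y = {"neuron1": jnp.zeros(10), "neuron2": jnp.zeros(10)}

# a subtree of X counts as a leaf when it matches the structure of y
is_neuron_level = lambda sub: (
    jax.tree_util.tree_structure(sub) == jax.tree_util.tree_structure(y)
)

# transform every leaf, then flatten into one tree per input,
# each with the same structure as y
per_input = jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(jnp.exp, X), is_leaf=is_neuron_level
)[0]

# reduce across inputs: `rate` has the same structure as y
rate = jax.tree_util.tree_map(lambda *r: sum(r), *per_input)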