Open BalzaniEdoardo opened 1 year ago
Another idea is to use pytrees, so that each neuron could have its own tree of parameters. This is more flexible and more in line with JAX, and we could use tree maps to evaluate the dot products with the model parameters.
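As a rough illustration of the tree-map dot product mentioned above, here is a minimal sketch. The tree keys, shapes, and the `tree_dot` helper are hypothetical and only assume that the design matrices and the weights are stored as pytrees with the same structure:

```python
import jax
import jax.numpy as jnp

# Hypothetical design matrices and weights for one neuron, stored as
# pytrees with matching structure (names and shapes are illustrative).
n_samples = 5
X = {
    "stim": jnp.ones((n_samples, 3)),
    "coupl": 2.0 * jnp.ones((n_samples, 2)),
}
w = {
    "stim": jnp.array([0.1, 0.2, 0.3]),
    "coupl": jnp.array([1.0, -1.0]),
}


def tree_dot(X, w):
    """Sum over leaves of X_leaf @ w_leaf, giving one value per sample."""
    per_leaf = jax.tree_util.tree_map(jnp.dot, X, w)
    return sum(jax.tree_util.tree_leaves(per_leaf))


linear_term = tree_dot(X, w)  # shape (n_samples,)
print(linear_term)
```

Each leaf contributes its own partial dot product, so adding a new predictor only means adding a key to both trees.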
As part of making it all compatible with pytrees, there are several discussions of input-related issues in #41:
- params to be a tuple of length 2, why not just separate that into two explicit arguments?
- _check_input_dimensionality allows for x and y to be None, which seems weird. It's also GLM-specific right now, so maybe move it?
- _check_and_convert_params
If we switch to pytrees, we should also modify the proximal operator for group lasso; the mask could be applied using tree_map, which would simplify the code.
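To make the idea concrete, here is a minimal sketch of a group-lasso proximal step written with tree_map. It is not the existing implementation: it assumes that each leaf of the parameter tree is one group (so no mask is needed), and `prox_group_lasso_tree`, `strength`, and the sqrt(group size) weighting are illustrative choices:

```python
import jax
import jax.numpy as jnp


def prox_group_lasso_tree(params, strength):
    """Block soft-thresholding applied leaf by leaf (one leaf = one group)."""

    def prox_leaf(x):
        norm = jnp.linalg.norm(x)
        # Assumption: threshold weighted by sqrt(group size), a common convention.
        thresh = strength * jnp.sqrt(x.size)
        # Guard against division by zero for all-zero groups.
        factor = jnp.maximum(1.0 - thresh / jnp.maximum(norm, 1e-12), 0.0)
        return factor * x

    return jax.tree_util.tree_map(prox_leaf, params)


# Hypothetical parameter tree with two groups.
params = {"stim": jnp.array([3.0, 4.0]), "coupl": jnp.array([0.1, 0.1])}
shrunk = prox_group_lasso_tree(params, strength=1.0)
print(shrunk)
```

In this example the "coupl" group has a small norm and is set to zero, while the "stim" group is only shrunk towards zero.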
Here is a snippet of code that is helpful for computing operations on trees with different depths:
import jax
import jax.numpy as jnp

# deep tree: two outer levels above the {"stim", "coupl"} sub-trees
x = {
    "outer_1": {
        "inner_1": {"stim": jnp.array([0]), "coupl": jnp.array([1])},
        "inner_2": {"stim": jnp.array([2]), "coupl": jnp.array([3])},
    },
    "outer_2": {
        "inner_1": {"stim": jnp.array([4]), "coupl": jnp.array([5])},
        "inner_2": {"stim": jnp.array([6]), "coupl": jnp.array([7])},
    },
}
# shallow tree, with the same structure as the innermost sub-trees of x
b = {"stim": jnp.array([10]), "coupl": jnp.array([20])}

# treat any sub-tree with the structure of b as a leaf, then add b to it
# leaf by leaf (jnp.add is elementwise, unlike the built-in sum)
res = jax.tree_util.tree_map(
    lambda subtree: jax.tree_util.tree_map(jnp.add, subtree, b),
    x,
    is_leaf=lambda xsub: jax.tree_util.tree_structure(xsub) == jax.tree_util.tree_structure(b),
)
print(res)
I thought of a way to reduce a deep tree to a sub-tree with a given structure. I don't want to forget it, so I am posting it here. The issue is similar to the one before, but the solution is different and can probably be polished.
Let's say we have a nested tree and another tree with a different depth. We wish to perform an operation on the deeper tree and, once the operation is performed, reduce the result to the same structure as the shallower one.
A concrete example is: we have an output tree: y = {neuron1: data_y1, neuron2: data_y2...}, where data_yi are spike counts, and we have an input tree specific for each neuron, X = {input1: {neuron1: data_x1, neuron2: data_x2}, input2: {neuron3: data_x3, neuron4: data_x4}}.
The operation could be transforming the input into a rate for each neuron (a transformation on the leaves of X), followed by a tree_map operation over the leaves of y.
Here is a code snippet that performs the operation:
import jax

tree1 = {'aa': {'a': 1, 'b': 2}, 'bb': {'a': 3, 'b': 4}, 'cc': {'a': 5, 'b': 6}}
tree2 = {"a": 100, "b": 1000}

# operation to be performed on the leaves
operation_on_leaves = lambda y: y**2

# condition to assess which layer constitutes a leaf
assess_leaves = lambda x: jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(tree2)

# apply the operation, then flatten to a list of trees such that every
# tree in the list has the same structure as tree2
flat_array = lambda x: jax.tree_util.tree_flatten(
    jax.tree_util.tree_map(operation_on_leaves, x), is_leaf=assess_leaves
)[0]

# finally, apply the reduction step across the list of trees,
# returning {'a': 35, 'b': 56}
print(jax.tree_util.tree_map(lambda *x: sum(x), *flat_array(tree1)))
How do we represent neuron-specific model matrices?
One way is to use a tensor, $X = [X]_{tij}$, where $t$ indexes the samples, $i$ the parameters, and $j$ the neurons.
This works, but most of the predictors will be shared: every experimental input will probably have the same set of basis for all the neurons, and the coupling filters between a neuron and all the other neurons will likely share the same basis. Is there a way to represent this efficiently?
Using a mask could be one way, although it is a bit messier to implement. Fitting one neuron at a time is another way; it is less efficient computationally but doesn't require creating the massive tensor.
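The two options can be compared on a toy problem. The sketch below is only illustrative: it assumes a single shared design matrix `X` of shape (samples, predictors), a weight matrix `W` with one column per neuron, and a hypothetical boolean `mask` marking which predictors each neuron uses.

```python
import jax.numpy as jnp
import numpy as np

n_samples, n_predictors, n_neurons = 4, 3, 2

# Shared design matrix: one column per predictor, reused by every neuron.
X = jnp.arange(n_samples * n_predictors, dtype=jnp.float32).reshape(
    n_samples, n_predictors
)
# One weight column per neuron.
W = jnp.array([[1.0, 0.5], [2.0, -1.0], [3.0, 0.25]])
# Hypothetical mask: mask[i, j] is True if neuron j uses predictor i.
mask = np.array([[True, True], [True, False], [False, True]])

# Masked approach: a single matrix product, no (t, i, j) tensor needed.
linear_masked = X @ (W * mask)  # shape (n_samples, n_neurons)

# One-neuron-at-a-time approach: select the relevant columns per neuron.
linear_loop = jnp.stack(
    [X[:, mask[:, j]] @ W[:, j][mask[:, j]] for j in range(n_neurons)],
    axis=1,
)

print(jnp.allclose(linear_masked, linear_loop))
```

Both give the same linear term; the masked version keeps one shared matrix in memory, while the loop trades speed for a smaller per-neuron problem.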