flatironinstitute / nemos

NEural MOdelS, a statistical modeling framework for neuroscience.
https://nemos.readthedocs.io/
MIT License

Add documentation about model identifiability #259

Open BalzaniEdoardo opened 3 weeks ago

BalzaniEdoardo commented 3 weeks ago
billbrod commented 2 weeks ago

There was an example of a rank-deficient matrix in plot_06_calcium_imaging.py, but it was removed in #247. Pasting the code here in case it's useful:

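# %%
# Imports used by this snippet. `X` (the design matrix) and `heading_basis`
# are assumed to be defined earlier in the tutorial.
import jax
import nemos as nmo
import numpy as np

# %%
# To ensure the rank computation accounts for the model's intercept term,
# we can add a constant column of ones before calculating the rank.
# Below is a utility function for adding the intercept column.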

def add_intercept(X):
    """Add an intercept column to a design matrix.

    Converts the matrix to float64, drops samples containing NaNs, and
    prepends a column of ones.
    """
    # convert to float64 for rank computation precision
    X = np.asarray(X, dtype=np.float64)
    # keep only valid samples (rows without NaNs)
    X = X[nmo.tree_utils.get_valid_multitree(X)]
    return np.hstack([np.ones((X.shape[0], 1)), X])

print(f"Number of features: {X.shape[1] + 1}")  # num coefficients + intercept
print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}")

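# %%
# Define a weight vector that is -1 for the intercept, 1 for the heading
# basis coefficients (the first columns of X), and 0 for everything else:

w = np.ones((X.shape[1] + 1))
w[0] = -1
w[1 + heading_basis.n_basis_funcs:] = 0

# %%
# The design matrix (with intercept) maps this vector to numerically zero,
# i.e. w lies in its null space:

np.max(np.abs(np.dot(add_intercept(X), w)))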

# %%
# This implies that infinitely many parameter vectors result in the same firing rate,
# or equivalently, that an un-regularized GLM has infinitely many equivalent solutions.

# define some random coefficients
coef = np.random.randn(X.shape[1] + 1)

# the firing rate is softplus([1, X] * coef)
# adding w to the coefficients does not change the output rate.
firing_rate = jax.nn.softplus(np.dot(add_intercept(X), coef))
firing_rate_2 = jax.nn.softplus(np.dot(add_intercept(X), coef + w))

# check that the rates match
np.allclose(firing_rate, firing_rate_2)
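
For the documentation page, a self-contained toy version may be easier to follow than the tutorial data. The sketch below is not the removed tutorial code: `X_toy`, `w_toy`, `coef_toy`, and the one-hot columns are hypothetical stand-ins for the tutorial's `X`, `w`, `coef`, and the heading basis, chosen only because one-hot columns sum to one on every sample, just like a constant intercept column. It uses plain NumPy (with `np.logaddexp` as a softplus) so it runs without nemos or JAX.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy design matrix: 3 one-hot columns that sum to 1 on every
# sample, followed by 2 unrelated Gaussian regressors.
n_samples, n_onehot, n_other = 200, 3, 2
labels = rng.integers(0, n_onehot, size=n_samples)
onehot = np.eye(n_onehot)[labels]
other = rng.normal(size=(n_samples, n_other))
X_toy = np.hstack([onehot, other])

# Prepend the intercept column of ones.
X_full = np.hstack([np.ones((n_samples, 1)), X_toy])

n_features = X_full.shape[1]
rank = np.linalg.matrix_rank(X_full)
print(f"Number of features: {n_features}")
print(f"Matrix rank: {rank}")

# Null-space direction: -1 on the intercept, +1 on the one-hot columns,
# 0 on the remaining regressors.
w_toy = np.zeros(n_features)
w_toy[0] = -1.0
w_toy[1:1 + n_onehot] = 1.0
residual = np.max(np.abs(X_full @ w_toy))


def softplus(z):
    """Numerically stable softplus, log(1 + exp(z))."""
    return np.logaddexp(0.0, z)


# Any multiple of w_toy can be added to the coefficients without changing
# the predicted rate; 3.7 is an arbitrary scale.
coef_toy = rng.normal(size=n_features)
rate_a = softplus(X_full @ coef_toy)
rate_b = softplus(X_full @ (coef_toy + 3.7 * w_toy))
rates_match = np.allclose(rate_a, rate_b)
print(f"Rates match: {rates_match}")
```

Because any scalar multiple of `w_toy` leaves the rate unchanged, the set of equivalent coefficient vectors forms a whole line, which is what makes the un-regularized problem non-identifiable.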

billbrod commented 2 weeks ago

When we add this, remember to link to it from the ## Design matrix section of tutorials/plot_06_calcium_imaging.py