BalzaniEdoardo opened 3 weeks ago
There was an example of a rank-deficient design matrix in `plot_06_calcium_imaging.py`, but it was removed in #247. Pasting the code here in case it's useful.
# %%
# To ensure the computation accounts for the model's intercept term,
# we add a constant column of ones before calculating the rank.
# Below is a utility function for adding the intercept column.
# (X and heading_basis come from the earlier cells of the tutorial.)
import jax
import nemos as nmo
import numpy as np


def add_intercept(X):
    """Add an intercept term to the design matrix.

    Convert the matrix to float64, drop rows containing NaNs, and prepend
    a column of ones for the intercept.
    """
    # convert to float64 for rank computation precision
    X = np.asarray(X, dtype=np.float64)
    # drop samples (rows) that contain NaNs
    X = X[nmo.tree_utils.get_valid_multitree(X)]
    return np.hstack([np.ones((X.shape[0], 1)), X])


print(f"Number of columns: {X.shape[1] + 1}")  # num coefficients + intercept
print(f"Matrix rank: {np.linalg.matrix_rank(add_intercept(X))}")
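# %%
# To see the rank deficiency, set -1 for the intercept, +1 for each
# heading-basis coefficient, and 0 for every other coefficient:
w = np.ones(X.shape[1] + 1)
w[0] = -1
w[1 + heading_basis.n_basis_funcs:] = 0
# %%
# Then [1, X] @ w vanishes (up to floating-point error), i.e. the
# heading-basis columns sum to the intercept column:
np.max(np.abs(np.dot(add_intercept(X), w)))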
# %%
# This implies that infinitely many parameter vectors produce the same firing
# rate or, equivalently, that an unregularized GLM has infinitely many
# equivalent solutions.
# Define some random coefficients:
coef = np.random.randn(X.shape[1] + 1)
# the firing rate is softplus([1, X] @ coef);
# adding w (or any multiple of it) to the coefficients does not change the rate
firing_rate = jax.nn.softplus(np.dot(add_intercept(X), coef))
firing_rate_2 = jax.nn.softplus(np.dot(add_intercept(X), coef + w))
# check that the rates match
np.allclose(firing_rate, firing_rate_2)
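For anyone without the tutorial's data at hand, here is a minimal, self-contained sketch of the same rank deficiency. It assumes a hypothetical one-hot encoding of a binned heading as a stand-in for `heading_basis` (any basis whose columns sum to one in every row behaves the same way), uses simulated data rather than the tutorial's, and swaps `jax.nn.softplus` for a NumPy `softplus` helper. It also checks that the null direction found by an SVD matches the hand-built `w`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the heading basis: a one-hot encoding of a binned
# heading, so the columns of this block sum to exactly one in every row.
n_samples, n_basis = 200, 5
heading_bin = rng.integers(0, n_basis, size=n_samples)
heading_features = np.eye(n_basis)[heading_bin]
# a few unrelated (simulated) predictors
other_features = rng.normal(size=(n_samples, 3))
X = np.hstack([heading_features, other_features])


def add_intercept(X):
    """Prepend a column of ones (simplified: no NaN dropping)."""
    X = np.asarray(X, dtype=np.float64)
    return np.hstack([np.ones((X.shape[0], 1)), X])


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return np.logaddexp(0.0, x)


X_int = add_intercept(X)
n_coef = X_int.shape[1]
rank = np.linalg.matrix_rank(X_int)
print(f"columns: {n_coef}, rank: {rank}")

# Same null direction as in the snippet: -1 on the intercept,
# +1 on each heading-basis coefficient, 0 elsewhere.
w = np.ones(n_coef)
w[0] = -1.0
w[1 + n_basis:] = 0.0
print("max |[1, X] @ w|:", np.max(np.abs(X_int @ w)))

# The right singular vector with the smallest singular value spans the
# (here one-dimensional) null space, so it is parallel to w up to sign.
_, _, vt = np.linalg.svd(X_int)
null_vec = vt[-1]
cos_angle = abs(null_vec @ w) / np.linalg.norm(w)
print("|cos(angle(null_vec, w))|:", cos_angle)

# Shifting the coefficients along w leaves the rate unchanged.
coef = rng.normal(size=n_coef)
rate_1 = softplus(X_int @ coef)
rate_2 = softplus(X_int @ (coef + 2.5 * w))
print("rates match:", np.allclose(rate_1, rate_2))
```

The SVD check is a convenient way to find such directions without knowing the basis structure in advance; with more than one dependency, the null space is spanned by the last `n_coef - rank` rows of `vt`.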
When we add this back, remember to link to it from the `## Design matrix` section of `tutorials/plot_06_calcium_imaging.py`.
Background note on the identifiability of the model, including:
Tutorial on Basis and identifiability
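For the identifiability note, a possible illustration of why the ambiguity is specific to the *un-regularized* GLM. This is a minimal sketch, assuming a ridge (L2) penalty that leaves the intercept unpenalized and a Poisson negative log-likelihood with a softplus rate; the one-hot "heading basis", the simulated counts, and the `ridge_penalty` helper are hypothetical and not nemos API. Along the null direction `w` the likelihood is flat, while the penalty is a parabola in the step size `t`, so the penalized objective has a single minimizer on that line.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical one-hot "heading basis" plus an intercept column: rank deficient.
n_samples, n_basis = 500, 4
heading_bin = rng.integers(0, n_basis, size=n_samples)
X_int = np.hstack([np.ones((n_samples, 1)), np.eye(n_basis)[heading_bin]])

# Null direction: -1 on the intercept, +1 on each basis coefficient.
w = np.concatenate([[-1.0], np.ones(n_basis)])


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return np.logaddexp(0.0, x)


# Simulated spike counts from a Poisson GLM with a softplus link.
coef = rng.normal(size=n_basis + 1)
y = rng.poisson(softplus(X_int @ coef))


def poisson_nll(c):
    """Poisson negative log-likelihood, dropping the constant log(y!) term."""
    rate = softplus(X_int @ c)
    return np.sum(rate - y * np.log(rate))


def ridge_penalty(c, strength=1.0):
    """Assumed L2 penalty on the non-intercept coefficients only."""
    return strength * np.sum(c[1:] ** 2)


ts = np.linspace(-3.0, 3.0, 61)
nll = np.array([poisson_nll(coef + t * w) for t in ts])
pen = np.array([ridge_penalty(coef + t * w) for t in ts])
print("NLL spread along w:", nll.max() - nll.min())

# Along the line the penalty is sum_j (coef_j + t)^2 over the basis
# coefficients, minimized at t* = -mean(coef[1:]).
t_star = -coef[1:].mean()
print("penalized objective is minimized along w at t =", t_star)
```

Here `t_star` only resolves the ambiguity along `w`; the full penalized fit still balances likelihood and penalty in the other directions, but because the likelihood is flat along `w`, the penalty alone selects the point on that line.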