AllenInstitute / bmtk

Brain Modeling Toolkit
https://alleninstitute.github.io/bmtk/
BSD 3-Clause "New" or "Revised" License

The cosine bump basis functions used in FilterNet seem to differ from the formula in Pillow et al., J Neurosci 2005 #334

Status: Open. Opened by CloudyDory 10 months ago.


The cosine bump basis function for LGN cells in Pillow et al., J Neurosci 2005, is:

$$
B(t) = \begin{cases} \frac{1}{2}\left[\cos\left(\log(t+\tau) - \phi\right) + 1\right], & \phi - \pi < \log(t+\tau) < \phi + \pi \\ 0, & \text{otherwise} \end{cases}
$$
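To make the comparison concrete, here is a direct, minimal transcription of the formula above into NumPy. The function name `pillow_cosine_bump` and the example values of tau and phi are illustrative choices, not taken from the paper or from bmtk.

```python
import numpy as np

def pillow_cosine_bump(t, tau, phi):
    """Raised-cosine bump on a log-time axis, as written above:
    B(t) = (cos(log(t + tau) - phi) + 1) / 2 inside the support, 0 outside."""
    t = np.asarray(t, dtype=float)
    x = np.log(t + tau)  # log-scaled time axis (assumes t + tau > 0)
    inside = (x > phi - np.pi) & (x < phi + np.pi)  # one full cosine period around phi
    return np.where(inside, (np.cos(x - phi) + 1.0) / 2.0, 0.0)

# Example (illustrative values): the peak is where log(t + tau) == phi, i.e. t = exp(phi) - tau.
tau, phi = 1.0, np.log(20.0)
t = np.arange(0, 500)
bump = pillow_cosine_bump(t, tau, phi)
print(t[np.argmax(bump)])  # 19
```

Because the cosine argument is log(t + tau) itself (no scale factor), the support extends by a factor of e^π ≈ 23 on either side of the peak in t + tau; with the example values the bump is nonzero from t = 0 up to about t = 461.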

However, the function makeBasis_StimKernel actually implemented in bmtk/simulator/filternet/lgnmodel/fitfuns.py (lines 28-81 and 97-122) seems to be more complex than that, and it lacks documentation (the tutorial in bmtk/docs/tutorial/07_filter_models.ipynb does not explain it sufficiently). For example, the values of variable b (line 33), variable ylim (line 37), and variable db (line 40) are all hard-coded, do not appear in the original equation, and I can't find an explanation of what they mean. I am also not sure which variable in the original equation kpeaks corresponds to, or why there is an extra normalization step.
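The hard-coded constants can at least be given a concrete meaning if the summary code further down is taken at face value. That code uses 100.3 and 200.3 and divides by 2*db, which suggests (an assumption, not checked against fitfuns.py) that b = 0.3 and ylim = [100, 200]. The sketch below computes db under that assumption and the resulting support of one bump: a bump centred at kpeak is nonzero only while |log(t + offset) - log(kpeak)| < 2*db. The helper `bump_support` is hypothetical.

```python
import numpy as np

# ASSUMPTION: b = 0.3 and ylim = [100, 200], inferred from the constants 100.3 and
# 200.3 in get_temporal_kernel(); these values are not read from bmtk itself.
b = 0.3
ylim = np.array([100.0, 200.0])

# Spacing of the two ylim points on the log axis (what the summary code calls db).
db = np.log(ylim[1] + b) - np.log(ylim[0] + b)

# The summary code divides the log-time offset by 2*db before clipping to [-pi, pi],
# so each bump is nonzero only within +/- 2*db of its center on the log axis.
half_width = 2.0 * db

def bump_support(kpeak):
    """Hypothetical helper: interval of (t + offset) on which a bump peaking at kpeak is nonzero."""
    return kpeak * np.exp(-half_width), kpeak * np.exp(half_width)

for kpeak in (9.67, 20.03):
    lo, hi = bump_support(kpeak)
    print(f"kpeak={kpeak}: nonzero for {lo:.2f} < t + offset < {hi:.2f}")
```

Under these assumptions exp(2*db) = (200.3/100.3)^2, just under 4, so each bump extends over roughly a factor of four on either side of its peak, which would also explain why a larger kpeak gives a wider bump in linear time.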

There is also a demo in bmtk/docs/tutorial/helpers/filternet_images_helpers.py that draws a figure of the basis functions with different parameters. I was able to reproduce that figure with much simpler code that summarizes the computation bmtk actually performs. The computation in get_temporal_kernel() in the following code is indeed different from the original equation in Pillow et al., J Neurosci 2005. What is the reason that bmtk uses a different formula?

```python
import numpy as np
import matplotlib.pyplot as plt

#%% Helper functions
def get_temporal_kernel(t, kpeaks, delay, weights):
    '''
    Inputs:
        t: [length,1] array
        kpeaks: [1,2] array
        delay: [1,2] array
        weights: [2,] array
    Output:
        kernel: [length,] array
    '''
    # Use 1.3 not 0.3 here to compensate for the extra 1-point offset in line 65 of `fitfuns.py`.
    # log() of non-positive arguments yields -inf/nan; those samples are mapped to -pi below.
    with np.errstate(divide='ignore', invalid='ignore'):
        log_t = np.log(t+1.3-delay) - np.log(kpeaks)
    log_t_pi = np.clip(np.pi*log_t/get_temporal_kernel.db2, -np.pi, np.pi)
    log_t_pi[np.isnan(log_t_pi)] = -np.pi
    basis = (np.cos(log_t_pi) + 1) / 2.0  # raised cosine; exactly 0 where the argument is clipped to +/-pi
    basis_norm = basis / np.linalg.norm(basis, ord=2, axis=0)  # unit L2 norm per basis function
    kernel = basis_norm @ weights  # weighted sum of the two basis functions
    return kernel
get_temporal_kernel.db2 = 2.0 * np.diff(np.log([100.3, 200.3]))  # 2*db in `fitfuns.py`

#%% Plot kernels with different parameters
weights = np.array([[30.0, -20.0], [30.0, -1.0], [15.0, -20.0]])
kpeaks = np.array([[3.0, 5.0], [3.0, 30.0], [20.0, 40.0]])
delays = np.array([[0.0, 0.0], [0.0, 60.0], [20.0, 60.0]])

t = np.expand_dims(np.arange(0,150,1), axis=1)  # milliseconds

fig, axes = plt.subplots(3, 3, figsize=(10, 7))
ri = 0  # row index: one row per parameter being varied

# weights scale (and set the sign of) each basis function
for ci in range(weights.shape[0]):
    kernel = get_temporal_kernel(t, np.array([[9.67, 20.03]]), np.array([[0.0, 1.0]]), weights[ci,:])
    idx = np.abs(kernel) > 0.0
    axes[ri, ci].plot(-t[idx]/1000, kernel[idx])
    axes[ri, ci].set_ylim([-3.5, 10.0])
    axes[ri, ci].text(0.05, 0.90, 'weights={}'.format(weights[ci,:]), horizontalalignment='left', verticalalignment='top', transform=axes[ri, ci].transAxes)
axes[0, 0].set_ylabel('effects of weights')
ri += 1

# kpeaks parameters control the spread of both peaks; the second peak must have a bigger spread
for ci in range(kpeaks.shape[0]):
    kernel = get_temporal_kernel(t, kpeaks[[ci],:], np.array([[0.0, 1.0]]), np.array([30.0, -20.0]))
    idx = np.abs(kernel) > 0.0
    axes[ri, ci].plot(-t[idx]/1000, kernel[idx])
    axes[ri, ci].set_xlim([-0.15, 0.005])
    axes[ri, ci].text(0.05, 0.90, 'kpeaks={}'.format(kpeaks[ci,:]), horizontalalignment='left', verticalalignment='top', transform=axes[ri, ci].transAxes)
axes[1, 0].set_ylabel('effects of kpeaks')
ri += 1

# delays shift each basis function later in time
for ci in range(delays.shape[0]):
    kernel = get_temporal_kernel(t, np.array([[9.67, 20.03]]), delays[[ci],:], np.array([30.0, -20.0]))
    idx = np.abs(kernel) > 0.0
    axes[ri, ci].plot(-t[idx]/1000, kernel[idx])
    axes[ri, ci].set_xlim([-0.125, 0.001])
    axes[ri, ci].text(0.05, 0.90, 'delays={}'.format(delays[ci,:]), horizontalalignment='left', verticalalignment='top', transform=axes[ri, ci].transAxes)
axes[2, 0].set_ylabel('effects of delays')

plt.show()
```
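Taking get_temporal_kernel() at face value, its cosine argument π(log(t + 1.3 - d) - log k)/(2·db) can be rewritten as a·log(t + 1.3 - d) - φ with a = π/(2·db) and φ = a·log k. Read this way, it looks like the bump from the paper with an extra scale factor a on the log axis, tau = 1.3 - d, and kpeaks setting the phase (the peak sits where t + 1.3 - d = k); the formula quoted at the top is the special case a = 1. The sketch below checks that reading numerically for single bumps. The names `summary_bump` and `scaled_pillow_bump` are hypothetical, and the normalization step is left out.

```python
import numpy as np

db2 = 2.0 * np.log(200.3 / 100.3)  # equals get_temporal_kernel.db2 up to rounding

def summary_bump(t, kpeak, delay):
    """One un-normalized basis function, following get_temporal_kernel()."""
    with np.errstate(divide="ignore", invalid="ignore"):  # log of <= 0 gives -inf / nan
        log_t = np.log(t + 1.3 - delay) - np.log(kpeak)
    x = np.clip(np.pi * log_t / db2, -np.pi, np.pi)
    x[np.isnan(x)] = -np.pi
    return (np.cos(x) + 1.0) / 2.0

def scaled_pillow_bump(t, kpeak, delay):
    """Paper-style bump with scale a on the log axis: (cos(a*log(t + tau) - phi) + 1) / 2
    when phi - pi < a*log(t + tau) < phi + pi, else 0."""
    a = np.pi / db2
    tau = 1.3 - delay
    phi = a * np.log(kpeak)
    out = np.zeros_like(t, dtype=float)
    ok = t + tau > 0  # log is only defined here; the bump is 0 elsewhere
    u = a * np.log(t[ok] + tau)
    inside = (u > phi - np.pi) & (u < phi + np.pi)
    vals = np.zeros_like(u)
    vals[inside] = (np.cos(u[inside] - phi) + 1.0) / 2.0
    out[ok] = vals
    return out

t = np.arange(0, 150, dtype=float)  # milliseconds, as in the plots above
for kpeak, delay in [(9.67, 0.0), (20.03, 1.0), (40.0, 60.0)]:
    assert np.allclose(summary_bump(t, kpeak, delay), scaled_pillow_bump(t, kpeak, delay))
```

If this reading is right, db fixes the scale a (the width of each bump on the log axis), while kpeaks and delay only move the bump along that axis.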

In addition, the function ff() in lines 108-117 of bmtk/simulator/filternet/lgnmodel/fitfuns.py loops over a NumPy array element by element, although this could easily be vectorized. Why does it use the slower computation?
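For reference, an element-wise loop of this kind can usually be replaced by a single broadcast NumPy expression. The sketch below assumes that ff(x, c, dc) evaluates the clipped raised cosine (cos(clip(π(x - c)/(2·dc), -π, π)) + 1)/2 for every entry of x, with one center per column; that matches the computation in get_temporal_kernel() above but has not been checked line by line against fitfuns.py. `ff_loop` and `ff_vectorized` are hypothetical names, and the inputs are illustrative.

```python
import numpy as np

def ff_loop(x, c, dc):
    """Element-by-element reference version (hypothetical stand-in for ff())."""
    out = np.empty_like(x, dtype=float)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            arg = max(-np.pi, min(np.pi, (x[i, j] - c[j]) * np.pi / dc / 2.0))
            out[i, j] = (np.cos(arg) + 1.0) / 2.0
    return out

def ff_vectorized(x, c, dc):
    """Same result in one expression: c broadcasts across the rows of x."""
    arg = np.clip((x - c[np.newaxis, :]) * np.pi / dc / 2.0, -np.pi, np.pi)
    return (np.cos(arg) + 1.0) / 2.0

# Illustrative inputs: log-time axis for 150 samples, two bump centers, spacing dc.
t = np.arange(150, dtype=float)[:, np.newaxis]
x = np.log(t + 1.3) * np.ones((1, 2))
c = np.log(np.array([9.67, 20.03]))
dc = np.log(200.3 / 100.3)
assert np.allclose(ff_loop(x, c, dc), ff_vectorized(x, c, dc))
```

The vectorized form avoids the Python-level double loop entirely; for kernels of a few hundred samples the loop is cheap either way, so the main gain is readability.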