Neuroinflab / kCSD-python

Kernel Current Source Density
https://kcsd-python.readthedocs.io/en/latest/
Other
15 stars 23 forks source link

Do not use matrix inversion #95

Closed abukaj closed 5 years ago

abukaj commented 5 years ago

Explicit matrix inversion is considered numerically harmful (solving the linear system directly is both faster and more accurate), but KCSD.values() uses it.

I have constructed a ground-truth CSD that demonstrates the issue (code below).

Figure_4 Figure_1

Figure_5 Figure_2

Figure_6 Figure_3

The pictures have been generated with the following code:

import numpy as np
import matplotlib.pyplot as plt
from kcsd import KCSD2D, utility_functions, csd_profile
from scipy import integrate, linalg
import matplotlib.cm as cm

def make_plot(xx, yy, zz, title, cmap=cm.bwr):
    """Draw a filled contour map of ``zz`` with a colour scale symmetric about zero.

    The contour levels span ``[-max|zz|, max|zz|]`` so that a diverging
    colormap renders positive and negative values symmetrically.

    Parameters
    ----------
    xx, yy : array_like
        Sample coordinates, passed straight to ``Axes.contourf``.
    zz : array_like
        Values to plot, on the same grid as ``xx``/``yy``.
    title : str
        Title of the axes.
    cmap : matplotlib colormap, optional
        Colormap for the filled contours (default ``cm.bwr``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes the map was drawn on.
    """
    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111)
    ax.set_aspect('equal')
    t_max = np.max(np.abs(zz))
    if t_max == 0:
        # An all-zero field would give 32 identical levels, which contourf
        # rejects (levels must be increasing); use a unit range instead.
        t_max = 1.0
    levels = np.linspace(-t_max, t_max, 32)
    im = ax.contourf(xx, yy, zz, levels=levels, cmap=cmap)
    ax.set_xlabel('X (mm)')
    ax.set_ylabel('Y (mm)')
    ax.set_title(title)
    ticks = np.linspace(-t_max, t_max, 3, endpoint=True)
    fig.colorbar(im, ax=ax, orientation='horizontal', format='%.2g', ticks=ticks)
    return ax

def integrate_2d(csd_at, true_csd, ele_pos, h, csd_lims):
    """Integrate the kernel ``arcsinh(2 h / r)`` against a gridded CSD.

    Computes the double integral (Simpson's rule along each axis)

        sum over grid of  arcsinh(2 * h / r(x, y)) * true_csd(x, y) dx dy,

    where ``r`` is the in-plane distance from the electrode at ``ele_pos`` to
    the grid point ``(x, y)``.  No ``1 / (2 pi sigma)`` factor is applied here.

    Parameters
    ----------
    csd_at : numpy.ndarray
        Stacked coordinate grids ``(csd_x, csd_y)`` in ``np.mgrid`` layout:
        x varies along axis 0 and y along axis 1 of each grid.
    true_csd : numpy.ndarray
        CSD values on that grid, same shape as ``csd_x``.
    ele_pos : sequence of float
        Electrode position ``(x, y)``.
    h : float
        Distance between the electrode plane and the CSD plane.
    csd_lims : sequence of numpy.ndarray
        ``(xlin, ylin)``: 1-D sample coordinates along axis 0 and axis 1.

    Returns
    -------
    float
        The value of the double integral.
    """
    # ``simpson`` exists since SciPy 1.6; ``simps`` was removed in SciPy 1.14.
    simpson = getattr(integrate, 'simpson', None) or integrate.simps
    csd_x, csd_y = csd_at
    xlin = csd_lims[0]
    ylin = csd_lims[1]
    r = np.hypot(ele_pos[0] - csd_x, ele_pos[1] - csd_y)
    # Clamp the distance so a grid point directly under the electrode does
    # not cause a division by zero.
    r = np.maximum(r, 0.0000001)
    integrand = np.arcsinh(2 * h / r) * true_csd
    # Axis 0 is x, axis 1 is y: integrate over x first, then over y.  This
    # keeps each axis paired with its own sample coordinates, so non-square
    # grids work too.
    integral_over_x = simpson(integrand, x=xlin, axis=0)
    return simpson(integral_over_x, x=ylin)

def forward_method(ele_pos, csd_at, true_csd, h=50., conductivity=1.0):
    """Compute the potential a gridded CSD produces at each electrode.

    Each potential is ``integrate_2d(...) / (2 * pi * conductivity)``.

    Parameters
    ----------
    ele_pos : numpy.ndarray
        Electrode coordinates, one ``(x, y)`` row per electrode.
    csd_at : numpy.ndarray
        Stacked coordinate grids in ``np.mgrid`` layout: ``csd_at[0]`` holds x
        (varying along its first axis), ``csd_at[1]`` holds y (varying along
        its second axis).
    true_csd : numpy.ndarray
        CSD values on the ``csd_at`` grid.
    h : float, optional
        Distance between the electrode plane and the CSD plane (default 50.).
    conductivity : float, optional
        Conductivity of the medium in S/m (default 1.0).

    Returns
    -------
    numpy.ndarray
        One potential per electrode, in the order of ``ele_pos``.
    """
    xlin = csd_at[0, :, 0]
    ylin = csd_at[1, 0, :]
    pots = np.zeros(ele_pos.shape[0])
    for ii, pos in enumerate(ele_pos):
        pots[ii] = integrate_2d(csd_at, true_csd,
                                [pos[0], pos[1]], h,
                                [xlin, ylin])
    return pots / (2 * np.pi * conductivity)

# Estimation area: the unit square.
xmin, xmax = 0.0, 1.0
ymin, ymax = 0.0, 1.0
# Presumably estimation-area extensions; not used anywhere in this script.
ext_x, ext_y = 0.0, 0.0
# Initial number of basis sources and their initial R for KCSD2D.
n_src_init = 1000
R_init = 1.0
# Distance between the electrode plane and the CSD plane.
h = 50.
conductivity = 1.0  # S/m

def do_kcsd(ele_pos, pots):
    """Build a KCSD2D model for a single frame of electrode potentials.

    ``pots`` is reshaped into an ``(n_electrodes, 1)`` column (one time
    point).  The model uses the module-level settings ``h``, ``conductivity``,
    ``xmin``..``ymax``, ``n_src_init`` and ``R_init``, with ``src_type='gauss'``.
    """
    n_electrodes = len(ele_pos)
    single_frame = pots.reshape((n_electrodes, 1))
    settings = dict(h=h,
                    sigma=conductivity,
                    xmin=xmin,
                    xmax=xmax,
                    ymin=ymin,
                    ymax=ymax,
                    n_src_init=n_src_init,
                    src_type='gauss',
                    R_init=R_init)
    return KCSD2D(ele_pos, single_frame, **settings)

# --- Sampling grid and electrodes ----------------------------------------
# 100 x 100 grid on the unit square; csd_at[0] holds x (varying along
# axis 0) and csd_at[1] holds y (varying along axis 1).
csd_at = np.mgrid[0.:1.:100j,
                  0.:1.:100j]
csd_x, csd_y = csd_at
# The same grid as a list of (x, y) rows.
csd_pos = np.vstack((csd_x.flatten(), csd_y.flatten())).T

# Electrodes sit on every third grid index from 35 to 62 in both directions,
# i.e. a 10 x 10 array in the centre of the grid.
est_to_ele = slice(35, 65, 3)
ele_x, ele_y = csd_at[:, est_to_ele, est_to_ele]
ele_pos = np.vstack((ele_x.flatten(), ele_y.flatten())).T

# --- Ground-truth construction -------------------------------------------
# Fit a model to all-zero potentials only to read its matrices.
# NOTE(review): k_pot / k_interp_cross are presumably the electrode-electrode
# potential kernel and the grid-electrode cross kernel; confirm in KCSD2D.
k = do_kcsd(ele_pos, np.zeros(ele_pos.shape[0]))
kernel = k.k_pot
ck = k.k_interp_cross
# inv(kernel) @ kernel is the identity in exact arithmetic, but not
# numerically.  svd(...)[2][0] is the first row of Vh: the right singular
# vector of the largest singular value (numpy sorts them in descending
# order).  That is the direction most stretched by the explicitly inverted
# product.  The ground truth is ck @ kernel @ that vector; process_estimate
# (a KCSD2D method, presumably reshaping onto the estimation grid) returns a
# 3-D array, and [:, :, 0] keeps the single time point.
true_csd = k.process_estimate(np.dot(ck, np.dot(kernel, linalg.svd(np.dot(np.linalg.inv(kernel), kernel))[2][0].reshape((-1, 1)))))[:, :, 0]

# Potentials produced at the electrodes by the ground truth (forward model).
pots = forward_method(ele_pos, csd_at, true_csd)
true_pots = pots.reshape(ele_x.shape)

# --- Reconstruction --------------------------------------------------------
# Fit kCSD to the forward-modelled potentials; estimate CSD and potential.
k = do_kcsd(ele_pos, pots)
est_csd = k.values('CSD')
kcsd_pot = k.values('POT')

# --- Figures ---------------------------------------------------------------
# NOTE(review): comparisons below assume the kCSD estimation grid coincides
# with csd_at (100 x 100 on the unit square); TODO confirm.
# Ground truth.
make_plot(csd_x, csd_y, true_csd, 'True CSD')
make_plot(ele_x, ele_y, true_pots, 'True POT', cmap=cm.PRGn)

# kCSD estimates; the potential is sampled at the electrode grid indices.
make_plot(csd_x, csd_y, est_csd[:,:,0], 'kCSD CSD')
make_plot(ele_x, ele_y, kcsd_pot[est_to_ele, est_to_ele, 0], 'kCSD POT', cmap=cm.PRGn)

# Residuals (true minus estimate).
make_plot(csd_x, csd_y, true_csd - est_csd[:,:,0], 'True - kCSD CSD')
make_plot(ele_x, ele_y, true_pots - kcsd_pot[est_to_ele, est_to_ele, 0], 'True - kCSD POT', cmap=cm.PRGn)
plt.show()
abukaj commented 5 years ago

The issue persists even with cross-validation:

# Cross-validate over 16 log-spaced R values in [0.01, 10], then re-plot
# the residuals: the CSD only in the central region covered by electrodes.
k.cross_validate(Rs=np.logspace(-2, 1, 16))
est_csd = k.values('CSD')
kcsd_pot = k.values('POT')
centre = (slice(34, 66), slice(34, 66))
csd_residual = true_csd - est_csd[:, :, 0]
make_plot(csd_x[centre], csd_y[centre], csd_residual[centre],
          'True - kCSD CSD (center)')
pot_residual = true_pots - kcsd_pot[est_to_ele, est_to_ele, 0]
make_plot(ele_x, ele_y, pot_residual,
          'True - kCSD POT', cmap=cm.PRGn)

Err_CSD_cv Err_POT_cv