LFPy / LFPykit

Freestanding implementations of electrostatic forward models for extracellular measurements of neural activity in multicompartment neuron models.
GNU General Public License v3.0

How to apply LFPykit on more than one neuron #146

Open llandsmeer opened 2 years ago

llandsmeer commented 2 years ago

The README states:

the extracellular potentials can be calculated from a distance-weighted sum of contributions from transmembrane currents of neurons

suggesting that LFPykit is able to handle more than one neuron (which is logical, given that we are looking at LFPs). However, both the documentation and the examples show LFPykit applied to only one neuron, e.g. in the form PointSourcePotential(cell=CellGeometry(...)).

Does this mean that LFPykit can only handle a single neuron? Or is there the possibility to apply it to large networks of multicompartmental neurons (e.g. by just adding all neurons into a single CellGeometry)? Please let me know; we'd really like to apply this to our network simulated using Arbor.

Kind regards, Lennart Landsmeer

github-actions[bot] commented 2 years ago

Hello @llandsmeer , thank you for submitting an issue!

llandsmeer commented 2 years ago

So reading through some more code and Arbor/LFPy issues, the general process seems to be

1) calculate the LFPs for each neuron using something like PointSourcePotential(cell=CellGeometry(...))
2) 'reduce the individual contributions', which just means summing them together?

Is that correct?

espenhgn commented 2 years ago

Hello @llandsmeer; at this moment we haven't yet prepared any example network simulations incorporating extracellular signal predictions via LFPykit in Arbor. Right now there is only this single-cell example provided with Arbor: https://github.com/arbor-sim/arbor/blob/master/python/example/single_cell_extracellular_potentials.py (which is just a slightly reworked version of this notebook: https://github.com/LFPy/LFPykit/blob/master/examples/Example_Arbor_swc.ipynb).

So in this context, what you outline above would be the way to go.

In terms of what we have done with networks and LFPykit: LFPy itself supports computing extracellular signals from networks directly: https://github.com/LFPy/LFPy/blob/master/examples/example_network/example_network.py. In LFPy the above procedure is modified somewhat, as that made sense with NEURON specifically. Before the simulation, we build linear maps between compartments and measurements for all compartments across cells per MPI process. Then, at each time point in the simulation, we build a vector of compartment transmembrane currents and multiply it with the mapping matrix. After the simulation, we sum up across MPI processes and export the summed data. This avoids recording transmembrane currents altogether (which can be very memory demanding).
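
In freestanding form, that per-cell mapping and summation boils down to something like the following sketch (not the LFPy/NEURON implementation itself; the geometries, currents and contact positions are dummy placeholders, and units follow the LFPykit convention of µm, nA, S/m and mV):

import numpy as np
import lfpykit

# dummy geometries and transmembrane currents standing in for simulator output
rng = np.random.default_rng(0)
n_cells, n_seg, n_t = 3, 5, 100
cell_geoms, imem_list = [], []
for i in range(n_cells):
    zpos = np.linspace(0., 100., n_seg + 1)
    geom = dict(
        x=np.zeros((n_seg, 2)) + 50. * i,   # shift cells apart laterally (µm)
        y=np.zeros((n_seg, 2)),
        z=np.c_[zpos[:-1], zpos[1:]],       # per-segment start/end z (µm)
        d=np.full(n_seg, 2.))               # segment diameters (µm)
    cell_geoms.append(geom)
    imem_list.append(rng.standard_normal((n_seg, n_t)))  # fake currents (nA)

# electrode contact points shared by all cells
x_e, y_e, z_e = np.zeros(4), np.zeros(4) + 25., np.linspace(0., 90., 4)

V_e = np.zeros((x_e.size, n_t))             # summed extracellular potential (mV)
for geom, imem in zip(cell_geoms, imem_list):
    cell = lfpykit.CellGeometry(**geom)
    lsp = lfpykit.LineSourcePotential(cell=cell, x=x_e, y=y_e, z=z_e,
                                      sigma=0.3)  # extracellular conductivity (S/m)
    M = lsp.get_transformation_matrix()     # shape (n_contacts, n_seg)
    V_e += M @ imem                         # superpose the per-cell contributions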

Please let me know if you have further questions. I think having an actual use case with networks in Arbor would be very handy.

espenhgn commented 2 years ago

Hi again @llandsmeer. For https://github.com/arbor-sim/arbor/pull/1825 I've now incorporated some improvements, mainly to the notebook https://github.com/LFPy/LFPykit/blob/master/examples/Example_Arbor_swc.ipynb. In there I define a couple of new classes, ArborCellGeometry and ArborLineSourcePotential, to simplify the code required to compute extracellular potentials.

llandsmeer commented 2 years ago

Hi @espenhgn, thanks so much for the detailed explanation and for updating the arbor example notebook with reusable classes!

Some information about the network:

Thanks to your messages and the example I got a minimal version working over here: https://github.com/llandsmeer/iopublic/blob/main/Local%20field%20potential.ipynb

Some notes:

The results are ... not that interesting yet :) [figure: a full network cross-section]

espenhgn commented 2 years ago

Hi @llandsmeer; great that you could incorporate LFPykit, even though the output is not as expected yet. It's also helpful to know the network size/compartment count to gauge the magnitude of the problem.

I did try to run your notebook locally on my M1 Mac but failed, as a CUDA library is needed during the recipe = iopublic.build_recipe(**kwargs) step. Hence I can mostly speculate about what went wrong above.

I see Brent already gave some feedback on my PR, so I may update what I did here with the Arbor example(s).

espenhgn commented 2 years ago

Btw., I thought this plot sort of made sense if this is indeed the extracellular potential of one cell (at index 30):

[figure]

From what you showed above, perhaps you sum over the wrong axis? To me it looks like you should sum over axis=1, not axis=0.

llandsmeer commented 2 years ago

Ah, sorry, I don't think I explained that. Index 30 here is a spatial index, so we have time on the x-axis and one-dimensional space on the y-axis. V_e already contains the summed-up potentials over all cells. I'm pretty sure I'm summing over the right axes (otherwise the output shape would not be correct).

To what extent does it look wrong? I think I forgot to say that the inferior olive cells show threshold oscillations: about 2/3 of the cells in the network always oscillate. They also synchronize, and as far as I can see the global synchronization mode shows up in the LFPs. But this is also the first time I'm doing this kind of analysis, so I might of course be doing something wrong (we're currently working on experimental validation of the LFPs).

llandsmeer commented 2 years ago

Writing this I just realized I might have made a very obvious mistake; I'll try to get back on this ASAP.

espenhgn commented 2 years ago

Writing this I just realized I might have made a very obvious mistake; I'll try to get back on this ASAP.

No worries!

llandsmeer commented 2 years ago

Yes, I had an error: I was placing all segments at position (0, 0, 0). That didn't matter for the cable cell equations, but it obviously changes the LFPs.

Anyway, I fixed that. Now I get this, which looks much more logical: [figure: (x, y) cross-section]

[figure: (x, t) cross-section]

espenhgn commented 2 years ago

Very nice! Did you manage to address the performance/memory issues you encountered in a good way (beyond reducing the number of locations in space)? I did open an issue here https://github.com/LFPy/LFPykit/issues/153. Seems to me that a few operations could be sped up using numba with relative ease.

llandsmeer commented 2 years ago

I just ran some small tests. The results are probably obvious to you, but it was nice to dive into the codebase a bit more. In general, the slow calls are ArborCellGeometry() and lsp.get_transformation_matrix().

1) Directly using a concurrent.futures.ThreadPoolExecutor(max_workers=128) slows things down, with CPU load staying around 1, suggesting some GIL problems.

2) Replacing repeated calls to numpy.row_stack with appending to a list followed by a single np.array(..., dtype=float) conversion speeds up ArborCellGeometry.__init__ by about 2.5 times:

import numpy as np
import lfpykit


class ArborCellGeometry(lfpykit.CellGeometry):
    def __init__(self, p, cables):
        # p: the morphology's point-wise representation (e.g. arbor.place_pwlin);
        # cables: one cable object per CV
        x, y, z, r = [], [], [], []
        CV_ind = np.array([], dtype=int)  # tracks which CV owns each segment
        for i, m in enumerate(cables):
            segs = p.segments([m])
            for seg in segs:
                x.append([seg.prox.x, seg.dist.x])
                y.append([seg.prox.y, seg.dist.y])
                z.append([seg.prox.z, seg.dist.z])
                r.append([seg.prox.radius, seg.dist.radius])
                CV_ind = np.r_[CV_ind, i]
        # single array conversion at the end instead of repeated np.row_stack calls
        x = np.array(x, dtype=float)
        y = np.array(y, dtype=float)
        z = np.array(z, dtype=float)
        d = 2 * np.array(r, dtype=float)
        super().__init__(x=x, y=y, z=z, d=d)
        self._CV_ind = CV_ind

3) It seems we're limited to calling ArborCellGeometry() single-threaded for now, as it deals with Python objects. So I then looked at using a ThreadPoolExecutor to speed up just lsp.get_transformation_matrix ... which only slowed things down, by a factor of about 2x.

4) Within ArborLineSourcePotential.get_transformation_matrix, the call to LineSourcePotential.get_transformation_matrix() takes up 90% of the time; the CV reshuffling takes the rest. Within that call, the loop over measurement points, calling lfpcalc.calc_lfp_linesource(cell, x[j], y[j], z[j], sigma, r_limit) for each of them, is the bottleneck.

5) Within calc_lfp_linesource, the only input argument that is not a float or ndarray is cell, which is immediately destructured into its constituent ndarrays. It calls the functions _deltaS_calc, _h_calc, _r2_calc and _linesource_calc_case{1,2,3}, which all take numpy arrays in and return numpy arrays, suggesting they're fine with numba.jit(nopython=True).

So indeed, adding numba.jit(nopython=True, nogil=True) to some functions in lfpcalc.py seems like a good solution. We then either need to destructure cell before calling the numba function or make it a jitclass. I'll try the first solution now in a fork.
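
As a rough illustration of that first option (a made-up toy point-source kernel stands in for the actual line-source math, and all names are hypothetical), the pattern is a thin wrapper that unpacks the CellGeometry object before handing plain arrays to the jitted kernel:

import numba
import numpy as np

@numba.jit(nopython=True, nogil=True, cache=True)
def _toy_kernel(xmid, ymid, zmid, x, y, z, sigma, r_limit):
    # plain-ndarray kernel: point-source potential per compartment, with the
    # contact-to-source distance clipped at r_limit
    r = np.sqrt((xmid - x)**2 + (ymid - y)**2 + (zmid - z)**2)
    r = np.maximum(r, r_limit)
    return 1. / (4. * np.pi * sigma * r)

def calc_lfp_toy(cell, x, y, z, sigma, r_limit):
    # thin Python wrapper keeping the cell-based signature: destructure `cell`
    # here, so the jitted kernel never sees a Python object
    xmid = cell.x.mean(axis=-1)
    ymid = cell.y.mean(axis=-1)
    zmid = cell.z.mean(axis=-1)
    return _toy_kernel(xmid, ymid, zmid, x, y, z, sigma, r_limit)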

llandsmeer commented 2 years ago

So this is the diff to get this working with numba, applied to the lfpcalc.py that I got from pip install (I think?):

18a19
> import numba
20c21
<
---
> @numba.jit(nopython=True, nogil=True, cache=True)
372a374
>     return _calc_lfp_linesource(xstart, xend, ystart, yend, zstart, zend , x, y, z, sigma, r_limit)
373a376,377
> @numba.jit(nopython=True, nogil=True, cache=True)
> def _calc_lfp_linesource(xstart, xend, ystart, yend, zstart, zend , x, y, z, sigma, r_limit):
387c391
<     mapping = np.zeros(len(cell.x[:, 0]))
---
>     mapping = np.zeros(len(xstart))
440c444
<         print('Adjusting r-distance to root segments')
---
>         # print('Adjusting r-distance to root segments')
476a481
> @numba.jit(nopython=True, nogil=True, cache=True)
484a490
> @numba.jit(nopython=True, nogil=True, cache=True)
492a499
> @numba.jit(nopython=True, nogil=True, cache=True)
500a508
> @numba.jit(nopython=True, nogil=True, cache=True)
507a516
> @numba.jit(nopython=True, nogil=True, cache=True)
510,512c519,522
<     aa = np.array([x - xend, y - yend, z - zend])
<     bb = np.array([xend - xstart, yend - ystart, zend - zstart])
<     cc = np.sum(aa * bb, axis=0)
---
>     ccX = (x - xend) * (xend - xstart)
>     ccY = (y - yend) * (yend - ystart)
>     ccZ = (z - zend) * (zend - zstart)
>     cc = ccX + ccY + ccZ
516a527
> @numba.jit(nopython=True, nogil=True, cache=True)
520c531
<     return abs(r2)
---
>     return np.abs(r2)

You might want to check it for correctness. Note that I had to remove a print statement. Originally, calling get_transformation_matrix on the first 100 LSPs (line-source potentials) took 2.5 seconds. With numba this took 5.3 seconds the first time (because of JIT compilation) and 720 ms on subsequent runs. Using a ThreadPoolExecutor(max_workers=4) still takes 1.28 seconds, and the load average doesn't go above 1, suggesting there are still GIL problems even though it's disabled inside the numba functions.

espenhgn commented 2 years ago

Great stuff! The lfpcalc.py file is indeed packaged with LFPykit (https://github.com/LFPy/LFPykit/blob/master/lfpykit/lfpcalc.py). The reason for assuming a class cell as input is merely historic and in hindsight could hurt performance. Again, I'm not quite sure about the potential gains, but using the jitclass decorator for the classes CellGeometry and LineSourcePotential could help.

There is a test suite in LFPykit which could help pick up errors in some calculations (py.test -v lfpykit/tests)

llandsmeer commented 2 years ago

Yes, the cell class is the thing that's going to complicate matters, I think. For example, this is how LineSourcePotential.get_transformation_matrix looks when we want to parallelize the calc_lfp_linesource invocation via numba.

It works, but it looks very ugly:

        # (snippet from inside LineSourcePotential.get_transformation_matrix;
        # r_limit is defined earlier in the method)
        xstart = self.cell.x[:, 0]
        xend = self.cell.x[:, -1]
        ystart = self.cell.y[:, 0]
        yend = self.cell.y[:, -1]
        zstart = self.cell.z[:, 0]
        zend = self.cell.z[:, -1]
        n = self.x.size

        @numba.jit(nopython=True, nogil=True, cache=True, fastmath=True, parallel=True)
        def speed_this_up(xstart, xend, ystart, yend, zstart, zend, n,
                          x, y, z, sigma, r_limit, totnsegs):
            # jitted helper: fill each row of M in parallel over contact points
            M = np.empty((n, totnsegs))
            for j in numba.prange(n):
                M[j, :] = lfpcalc._calc_lfp_linesource(
                    xstart, xend, ystart, yend, zstart, zend,
                    x=x[j], y=y[j], z=z[j], sigma=sigma, r_limit=r_limit)
            return M

        return speed_this_up(xstart, xend, ystart, yend, zstart, zend, n,
                             self.x, self.y, self.z, self.sigma, r_limit,
                             self.cell.totnsegs)

With the new parallel invocation method I get a load average of 41, which suggests we're actually using multiple cores now :). Still, the total runtime is only about 10 seconds faster:

# original:                                               6 mins 56 seconds
# numba.jit(nopython=True, nogil=True)                    2 mins 6 seconds
# numba.jit(nopython=True, nogil=True, fastmath=True)     2 mins 3 seconds
# numba.jit(..., parallel=True) + prange() in M = ...     1 min 56 seconds

In my (limited) experience, combining OOP (readable, easy-to-use code) and numba (fast numerical code) is not trivial. The jitclass indeed looks like an easy solution at first sight, but sadly it's not that great in numba. You lose the ability to pass options such as nogil=True, parallel=True or cache=True to numba.jit, which you can only work around by editing the numba source code by hand (not that hard, but not a great user experience). Using it as a numba-readable data container might be an option, but I haven't tried that yet. In general it will just lead to slow start-up times.

Running py.test I get one failed test, test_RecExtElectrode_04, where the difference is on the order of 1e-17.

I'll happily open a pull request for the performance improvements, but at the moment I'm not sure what kind of trade-offs I'm allowed to make between code readability and performance.

llandsmeer commented 2 years ago

Btw, the biggest problem I currently have is running out of 128 GB of RAM (!) when calculating the extracellular potentials. Still looking for a solution to that.

espenhgn commented 2 years ago

Hi again! I don't have a clear idea why parallelization does not speed things up more in your case. I can only speculate that it adds overhead to each call to lfpcalc.calc_lfp_linesource, which is fairly fast per call, with or without turning it into an extension. I reapplied some of your suggestions here (LFPykit@fix-153) and see a speedup factor of about 2 using the notebook examples/cProfile.ipynb (still not parallelized):

branch fix-153:

         300548 function calls in 6.441 seconds

   Ordered by: cumulative time
   List reduced from 31 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    6.441    6.441 {built-in method builtins.exec}
        1    0.010    0.010    6.440    6.440 <string>:1(<module>)
      100    0.219    0.002    6.431    0.064 models.py:437(get_transformation_matrix)
   100100    0.277    0.000    6.211    0.000 lfpcalc.py:346(calc_lfp_linesource)
   100100    5.921    0.000    5.934    0.000 lfpcalc.py:377(_calc_lfp_linesource)
   100100    0.013    0.000    0.013    0.000 serialize.py:29(_numba_unpickle)
        1    0.000    0.000    0.000    0.000 {built-in method io.open}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        2    0.000    0.000    0.000    0.000 iostream.py:502(write)
        3    0.000    0.000    0.000    0.000 iostream.py:208(schedule)
        3    0.000    0.000    0.000    0.000 socket.py:480(send)
      100    0.000    0.000    0.000    0.000 {built-in method numpy.empty}
        1    0.000    0.000    0.000    0.000 {method 'read' of '_io.TextIOWrapper' objects}
        3    0.000    0.000    0.000    0.000 threading.py:1126(is_alive)
        2    0.000    0.000    0.000    0.000 iostream.py:439(_schedule_flush)
        1    0.000    0.000    0.000    0.000 _bootlocale.py:33(getpreferredencoding)
        3    0.000    0.000    0.000    0.000 threading.py:1059(_wait_for_tstate_lock)
        1    0.000    0.000    0.000    0.000 codecs.py:319(decode)
        2    0.000    0.000    0.000    0.000 iostream.py:420(_is_master_process)
        1    0.000    0.000    0.000    0.000 codecs.py:309(__init__)

*** Profile printout saved to text file 'prun0'.

vs. branch master

         2102348 function calls in 11.345 seconds

   Ordered by: cumulative time
   List reduced from 39 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   11.345   11.345 {built-in method builtins.exec}
        1    0.011    0.011   11.345   11.345 <string>:1(<module>)
      100    0.356    0.004   11.334    0.113 models.py:437(get_transformation_matrix)
   100100    3.750    0.000   10.978    0.000 lfpcalc.py:346(calc_lfp_linesource)
   100100    1.620    0.000    1.620    0.000 lfpcalc.py:508(_h_calc)
   100100    1.492    0.000    1.492    0.000 lfpcalc.py:477(_linesource_calc_case1)
   100100    1.234    0.000    1.234    0.000 lfpcalc.py:519(_r2_calc)
   100100    1.016    0.000    1.016    0.000 lfpcalc.py:501(_deltaS_calc)
   400400    0.084    0.000    0.821    0.000 <__array_function__ internals>:177(where)
   400400    0.718    0.000    0.718    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
   100100    0.659    0.000    0.659    0.000 lfpcalc.py:493(_linesource_calc_case3)
   100100    0.309    0.000    0.309    0.000 lfpcalc.py:485(_linesource_calc_case2)
   100100    0.073    0.000    0.073    0.000 {built-in method numpy.zeros}
   400400    0.019    0.000    0.019    0.000 multiarray.py:341(where)
   100102    0.004    0.000    0.004    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.000    0.000 {built-in method io.open}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        2    0.000    0.000    0.000    0.000 iostream.py:502(write)
        3    0.000    0.000    0.000    0.000 iostream.py:208(schedule)
      100    0.000    0.000    0.000    0.000 {built-in method numpy.empty}

*** Profile printout saved to text file 'prun0'.

As for refactoring the lfpcalc.py methods, in particular the hidden ones, I'd say these can be refactored with regard to performance rather than readability. I think the main functions such as lfpcalc.calc_lfp_linesource should remain exposed (nopython=False, presumably).

These functions should also not receive a CellGeometry object as input. I don't see any particular reason not to turn cell into numpy array arguments such as cell_x with columns xstart, xend, etc. Somewhat related: in https://github.com/LFPy/LFPykit/issues/158 I would like to fix some issues with how the different probes and sources are represented, so "breaking" backwards compatibility shouldn't be an issue for, say, the next LFPykit release.

Not sure what's going on with the failing test. I also got this if I set fastmath=True for _calc_lfp_linesource in the above PR.

The memory issue is indeed a problem, in part incurred by Arbor. My preferred solution would be to precompute the mapping from the cell geometry using LFPykit, provide this mapping to Arbor (or let Arbor query LFPykit), and record the extracellular signal during the simulation rather than recording transmembrane currents. My take is that for most use cases the number of electrode contact points << the number of CVs, so storing only the signals should consume less memory. The catch is that Arbor presently does not expose the geometry before running the simulation. Furthermore, I don't know how you compute the summed signal at this point. You could consider concatenating the outputs of get_transformation_matrix along the last axis for each cell, and the transmembrane currents along their first axis. Then you could compute the signals as the product M_all @ I_all, with shapes (# contacts, # cells * # CVs per cell) and (# cells * # CVs per cell, # timesteps), respectively. The corresponding matrices would be pretty big, but should not consume all your memory. Assuming 4096 contacts, 200000 CVs in total and 6000 time steps in double precision:

M_all: 4096 * 200000 * 64 bit / 8 bit/B / 1024^3 B/GB ≈ 6.1 GB
I_all: 200000 * 6000 * 64 bit / 8 bit/B / 1024^3 B/GB ≈ 8.9 GB
product: 4096 * 6000 * 64 bit / 8 bit/B / 1024^3 B/GB ≈ 0.18 GB
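
A small sketch of that bookkeeping, with made-up shapes and with M_list/I_list standing in for the per-cell transformation matrices and recorded transmembrane currents:

import numpy as np

# made-up sizes; M_list / I_list stand in for the per-cell outputs of
# get_transformation_matrix() and the recorded transmembrane currents
n_contacts, n_cells, n_cv_per_cell, n_t = 4, 3, 10, 100
rng = np.random.default_rng(0)
M_list = [rng.standard_normal((n_contacts, n_cv_per_cell)) for _ in range(n_cells)]
I_list = [rng.standard_normal((n_cv_per_cell, n_t)) for _ in range(n_cells)]

M_all = np.concatenate(M_list, axis=-1)   # (# contacts, # cells * # CVs per cell)
I_all = np.concatenate(I_list, axis=0)    # (# cells * # CVs per cell, # timesteps)
V_e = M_all @ I_all                       # (# contacts, # timesteps)

# an equivalent running sum never materializes M_all / I_all at once
V_e_alt = sum(M @ I for M, I in zip(M_list, I_list))
assert np.allclose(V_e, V_e_alt)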

espenhgn commented 2 years ago

Parallelization yields somewhat better performance, but the speedup doesn't reflect the number of cores available (M1 Mac) (commit 0470f71772d31a76eca1fb7962340981777df00b):

         2417475 function calls (2403075 primitive calls) in 3.425 seconds

   Ordered by: cumulative time
   List reduced from 443 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    3.425    3.425 {built-in method builtins.exec}
        1    0.024    0.024    3.425    3.425 <string>:1(<module>)
      100    0.003    0.000    3.103    0.031 models.py:438(get_transformation_matrix)
      100    1.837    0.018    1.840    0.018 models.py:460(_get_transform)
62600/62500    0.911    0.000    1.296    0.000 ffi.py:149(__call__)
      100    0.001    0.000    1.214    0.012 dispatcher.py:388(_compile_for_args)
      100    0.002    0.000    1.202    0.012 dispatcher.py:915(compile)
      100    0.000    0.000    1.189    0.012 caching.py:639(load_overload)
      100    0.001    0.000    1.176    0.012 caching.py:650(_load_overload)
      100    0.000    0.000    1.145    0.011 caching.py:404(rebuild)
      100    0.001    0.000    1.145    0.011 compiler.py:210(_rebuild)
      100    0.000    0.000    1.120    0.011 codegen.py:1158(unserialize_library)
      100    0.001    0.000    1.120    0.011 codegen.py:926(_unserialize)
      100    0.000    0.000    0.541    0.005 module.py:29(parse_bitcode)
      200    0.003    0.000    0.501    0.003 codegen.py:1088(_load_defined_symbols)
      400    0.015    0.000    0.492    0.001 codegen.py:1092(<setcomp>)
    21500    0.008    0.000    0.324    0.000 ffi.py:356(__del__)
    21500    0.009    0.000    0.314    0.000 ffi.py:313(close)
      100    0.001    0.000    0.292    0.003 module.py:76(_dispose)
    62600    0.022    0.000    0.197    0.000 ffi.py:73(__exit__)

*** Profile printout saved to text file 'prun0'.

I also get a warning OMP: Info #271: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead. (as well as a bunch of broken tests).

I also tested type hinting, but this did not improve anything (commit cb019a0a7057a93a09fd206b39b2601b6c40dec6). Maybe I incorporated it incorrectly or something.

github-actions[bot] commented 2 years ago

This issue appears to be stale due to non-activity and will be closed automatically