neuronsimulator / nrn

NEURON Simulator
http://nrn.readthedocs.io

RxD calls to species.nodes is excessively slow #2950

Open Hjorthmedh opened 4 months ago

Hjorthmedh commented 4 months ago

Context

The call myspecies.nodes to get all compartments that have myspecies is very slow. The call h.finitialize() is also very slow.

Overview of the issue

In our minimal example a single call takes 11 seconds with one neuron and 110 seconds in a network of 10 neurons. The cost appears to scale linearly with the number of neurons, even though we are only requesting the list of compartments for a single neuron.

h.finitialize() is also incredibly slow.

We expected the function call to be much faster and to be independent of the number of neurons. This is especially important since we want to run large-scale networks (10000+ neurons).
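
To put the scaling concern in numbers, here is a back-of-the-envelope extrapolation. It assumes that the linear trend suggested by the two reported measurements (11 s for 1 neuron, 110 s for 10 neurons) continues to hold at 10000 neurons, which has not been measured.

```python
# Timings reported above for the slow call to species.nodes.
seconds_for_1_neuron = 11.0
seconds_for_10_neurons = 110.0

# Per-neuron cost implied by the two data points (linear-scaling assumption).
seconds_per_neuron = (seconds_for_10_neurons - seconds_for_1_neuron) / (10 - 1)

# Extrapolate to the target network size; this is a projection, not a measurement.
target_neurons = 10_000
projected_seconds = seconds_per_neuron * target_neurons
projected_hours = projected_seconds / 3600

print(f"{seconds_per_neuron:.1f} s per neuron")
print(f"projected time for {target_neurons} neurons: {projected_hours:.1f} h")
```

Under that assumption, a single slow call would take about 30 hours at the target network size.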

NEURON setup

Minimal working example - MWE

This example uses the morphologies in https://github.com/Hjorthmedh/BasalGangliaData/tree/main/data/neurons/striatum/dspn (the code uses glob to collect the SWC files matching */morphology/*.swc).

import neuron
from neuron import h
import neuron.crxd as rxd

import bluepyopt.ephys as ephys
import glob
import time

# h.load_file("stdlib.hoc")
# h.load_file("import3d.hoc")

def load_morphology(self):
    cell = h.Import3d_SWC_read()
    cell.input("c91662.swc")
    i3d = h.Import3d_GUI(cell, False)
    i3d.instantiate(self)

    return cell

def minimal_example():

    sim = ephys.simulators.NrnSimulator()

    cell_list = []
    morph_dir = "/home/hjorth/HBP/BasalGangliaData/data/neurons/striatum/dspn/*/morphology/*"
    morphs = glob.glob(morph_dir + "*.swc")

    species_list = []
    region_list = []

    num_morphs = 10

    for idx, swc_file in enumerate(morphs[:num_morphs]):

        print(f"Loading morphology {swc_file} ({idx})")

        morph = ephys.morphologies.NrnFileMorphology(swc_file,
                                                     do_replace_axon=False)
        cell = ephys.models.CellModel('simple_cell',
                                      morph=morph)
        cell.instantiate(sim=sim)
        cell_list.append(cell)

        print("Creating regions", flush=True)
        region = rxd.Region(cell.icell.dend, nrn_region="i")
        region_list.append(region)

        print("Creating species", flush=True)
        # Separate loop variable so the enumerate() index above is not shadowed
        for species_idx in range(200):
            species_name = f"species{species_idx}"
            spec = rxd.Species(region,
                               d=0,
                               initial=1,
                               charge=0,
                               name=species_name)
            species_list.append(spec)

            #soma.insert("caldyn_ms")

    duration = []

    for idx, spec in enumerate(species_list):
        # This step is slow

        if idx == 0:
            print(f"Calling nodes on {spec} -- This is slow!", flush=True)
        else:
            print(f"Calling nodes on {spec}", flush=True)

        start_time = time.time()
        spec.nodes
        end_time = time.time()
        dur = end_time - start_time
        duration.append(dur)
        print(f"nodes call done {dur}")

    print(f"Max duration: {max(duration)} for {num_morphs} neurons")

    print("Init")
    h.finitialize()
    # neuron.run(100)

import cProfile
prof_file = "profile.txt"
cProfile.runctx("minimal_example()", None, locals(), filename=prof_file)

import pstats
from pstats import SortKey
p = pstats.Stats(prof_file)
p.strip_dirs().sort_stats(SortKey.CUMULATIVE).print_stats(100)

Logs

Calling nodes on species199
nodes call done 0.0003521442413330078
**Max duration: 107.14342880249023 for 10 neurons**
Init
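
The log shows one very expensive call while the last call returns in well under a millisecond. A small harness like the following can separate a one-time setup cost from the steady-state cost of a call. `LazyNodes` is a hypothetical toy stand-in with a lazily built cache, not the rxd implementation, and the 0.05 s sleep only simulates expensive setup.

```python
import time


class LazyNodes:
    """Toy object whose first access triggers expensive, one-time setup."""

    def __init__(self):
        self._nodes = None

    @property
    def nodes(self):
        if self._nodes is None:
            time.sleep(0.05)  # simulated one-time initialization
            self._nodes = list(range(1000))
        return self._nodes


def time_accesses(obj, repeats=5):
    """Return the wall-clock duration of each of `repeats` accesses to obj.nodes."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        obj.nodes
        durations.append(time.perf_counter() - start)
    return durations


durations = time_accesses(LazyNodes())
first, later = durations[0], durations[1:]
print(f"first call: {first:.4f} s, slowest later call: {max(later):.6f} s")
```

Reporting the first call separately from the rest, rather than only the maximum, makes it easier to see whether the cost is paid once per model or once per species.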

The init call takes AGES. Included below is the output from the python profiler.

Fri Jun 28 12:24:01 2024    profile.txt

         602884481 function calls (602882440 primitive calls) in 1993.205 seconds

   Ordered by: cumulative time
   List reduced from 339 to 100 due to restriction <100>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000 1993.279 1993.279 {built-in method builtins.exec}
     2000    2.548    0.001 1839.758    0.920 species.py:1830(__del__)

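Species.__del__ (species.py:1830) is called 2000 times and accounts for about 92% of the cumulative run time. Nearly all of that is spent in section1d.py _delete, which alone takes about 87% of the total. (Mostly during initialize?)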

   264000 1577.023    0.006 1729.590    0.007 section1d.py:218(_delete)
        1    0.082    0.082  149.173  149.173 minimal_example_setup.py:21(minimal_example)
  2001000   54.289    0.000  111.427    0.000 species.py:2060(_update_region_indices)
10000/8000    0.205    0.000  110.628    0.014 species.py:2266(nodes)
   132000    2.574    0.000  106.803    0.001 node.py:62(_remove)
   528000   97.567    0.000  104.129    0.000 function_base.py:5173(delete)
        2    0.000    0.000   81.657   40.829 rxd.py:481(_setup)
      8/6    0.011    0.001   81.657   13.610 initializer.py:20(_do_init)

_update_node_data is also run excessively, especially during the myspecies.nodes call.

     6000    0.150    0.000   74.866    0.012 species.py:1970(_update_node_data)
   396000    1.786    0.000   74.681    0.000 section1d.py:196(_update_node_data)
        3    0.064    0.021   73.907   24.636 rxd.py:1852(_init)
 132066000   46.190    0.000   56.575    0.000 section1d.py:255(indices)
    10005    0.018    0.000   50.080    0.005 rxd.py:604(_update_node_data)
   132000    0.056    0.000   34.124    0.000 section1d.py:252(__del__)
   396000   17.895    0.000   32.157    0.000 geometry.py:68(result)
   396000   11.753    0.000   26.164    0.000 geometry.py:36(_volumes1d)
     2000    0.021    0.000   25.053    0.013 species.py:1818(_do_init5)
   264000    5.284    0.000   24.358    0.000 section1d.py:225(<listcomp>)
        3    1.147    0.382   22.373    7.458 rxd.py:665(_setup_matrices)
     6000    0.453    0.000   21.426    0.004 species.py:2007(_register_cptrs)
   396000    1.784    0.000   20.964    0.000 section1d.py:308(_register_cptrs)
  2409222   19.250    0.000   19.250    0.000 {built-in method builtins.getattr}
 26400000    4.604    0.000   19.074    0.000 section1d.py:188(__ne__)
     8000    0.129    0.000   17.788    0.002 species.py:2180(_assign_parents)
   528000    0.813    0.000   17.645    0.000 section1d.py:399(_assign_parents)
   397320   15.709    0.000   16.221    0.000 section1d.py:33(__getitem__)
  2834400    2.214    0.000   15.278    0.000 function_base.py:1461(interp)
 26400000    8.707    0.000   14.471    0.000 section1d.py:183(__eq__)
 33515200   14.259    0.000   14.259    0.000 section1d.py:169(_sec)
   396000    1.869    0.000   14.023    0.000 geometry.py:117(_neighbor_areas1d)
     2000    0.007    0.000   12.007    0.006 species.py:1729(_do_init2)
139287600   11.012    0.000   11.012    0.000 rxdsection.py:16(nseg)
   528000    0.299    0.000   10.510    0.000 species.py:2169(_has_region_section)
   543010    1.143    0.000   10.204    0.000 {built-in method builtins.any}
     6000    7.886    0.001    9.471    0.002 species.py:1875(_ion_register)
 15264000    3.494    0.000    9.077    0.000 species.py:2170(<genexpr>)
     2000    0.073    0.000    7.311    0.004 species.py:1735(<listcomp>)
   132000    3.748    0.000    7.238    0.000 section1d.py:154(__init__)
  2834400    6.891    0.000    6.891    0.000 {built-in method numpy.core._multiarray_umath.interp}
  3890400    6.444    0.000    6.444    0.000 {built-in method numpy.asarray}
   464000    1.812    0.000    5.783    0.000 species.py:2172(_region_section)
     8000    0.291    0.000    5.317    0.001 species.py:2199(_finitialize)
     6000    0.092    0.000    5.095    0.001 species.py:2148(_setup_diffusion_matrix)
   396000    3.382    0.000    5.004    0.000 section1d.py:321(_setup_diffusion_matrix)
   528000    4.131    0.000    4.774    0.000 numeric.py:136(ones)
   396000    3.306    0.000    4.443    0.000 function_base.py:24(linspace)
    10000    0.128    0.000    4.313    0.000 species.py:2290(<listcomp>)
   660000    0.635    0.000    4.185    0.000 section1d.py:292(nodes)
 52800000    3.671    0.000    3.671    0.000 {built-in method builtins.id}
     2000    0.055    0.000    3.525    0.002 species.py:1645(_do_init1)
     1320    3.342    0.003    3.343    0.003 section1d.py:45(__delitem__)
        1    0.003    0.003    3.100    3.100 initializer.py:7(_do_ion_register)
     2000    0.018    0.000    3.028    0.002 species.py:1534(__init__)
 23810400    2.881    0.000    2.881    0.000 {method 'arc3d' of 'nrn.Section' objects}
   396000    1.960    0.000    2.875    0.000 geometry.py:72(<listcomp>)
   396000    1.888    0.000    2.854    0.000 geometry.py:39(<listcomp>)
   396000    1.884    0.000    2.841    0.000 geometry.py:71(<listcomp>)
   396000    1.874    0.000    2.831    0.000 geometry.py:120(<listcomp>)
   660000    0.930    0.000    2.802    0.000 section1d.py:296(<listcomp>)
   396000    1.871    0.000    2.799    0.000 geometry.py:40(<listcomp>)
   396000    1.862    0.000    2.779    0.000 geometry.py:121(<listcomp>)
 23810400    2.761    0.000    2.761    0.000 {method 'diam3d' of 'nrn.Section' objects}
   263340    2.606    0.000    2.669    0.000 section1d.py:41(__setitem__)
 29878920    2.466    0.000    2.474    0.000 {built-in method builtins.isinstance}
  3088000    0.885    0.000    2.153    0.000 rxdsection.py:47(L)
        3    1.524    0.508    1.524    0.508 rxd.py:700(<listcomp>)
        1    0.001    0.001    1.266    1.266 rxd.py:1936(_init_concentration)
        3    0.005    0.002    1.142    0.381 rxd.py:379(_setup_memb_currents)
     6000    0.080    0.000    1.137    0.000 species.py:2157(_setup_currents)
  1219200    1.129    0.000    1.129    0.000 geometry.py:47(<listcomp>)
  1219200    1.125    0.000    1.125    0.000 geometry.py:79(<listcomp>)
  2834400    0.801    0.000    1.051    0.000 type_check.py:302(iscomplexobj)
   396000    0.463    0.000    1.050    0.000 section1d.py:259(_setup_currents)
     8000    0.075    0.000    0.783    0.000 species.py:2218(_transfer_to_legacy)
       10    0.000    0.000    0.736    0.074 models.py:253(instantiate)
       10    0.018    0.002    0.736    0.074 models.py:226(instantiate_morphology)
  2011200    0.718    0.000    0.718    0.000 section1d.py:68(add_values)
       10    0.711    0.071    0.716    0.072 morphologies.py:111(instantiate)
   528000    0.621    0.000    0.707    0.000 section1d.py:299(_transfer_to_legacy)
  1625600    0.310    0.000    0.617    0.000 node.py:196(concentration)
  1154740    0.589    0.000    0.589    0.000 {method 'hoc_internal_name' of 'nrn.Section' objects}
   528000    0.533    0.000    0.533    0.000 {built-in method numpy.empty}
  2032000    0.497    0.000    0.497    0.000 node.py:369(__init__)
  4863605    0.457    0.000    0.457    0.000 {method 'extend' of 'list' objects}
  1188000    0.285    0.000    0.407    0.000 rxdsection.py:24(nrn_region)
   396003    0.163    0.000    0.364    0.000 section1d.py:50(__iter__)
   134000    0.097    0.000    0.359    0.000 node.py:47(_allocate)
4211759/4211750    0.356    0.000    0.356    0.000 {built-in method builtins.len}
        3    0.339    0.113    0.339    0.113 rxd.py:224(_list_to_pyobject_array)
  2834400    0.332    0.000    0.332    0.000 function_base.py:1457(_interp_dispatcher)
  2834400    0.325    0.000    0.325    0.000 type_check.py:205(_is_type_dispatcher)
   132000    0.248    0.000    0.309    0.000 section1d.py:53(values)
  1625600    0.307    0.000    0.307    0.000 node.py:240(value)
  2599882    0.301    0.000    0.301    0.000 {built-in method builtins.hasattr}
   792003    0.283    0.000    0.283    0.000 {built-in method numpy.zeros}
     6000    0.242    0.000    0.274    0.000 species.py:2152(_setup_c_matrix)
  2376000    0.272    0.000    0.272    0.000 {method 'n3d' of 'nrn.Section' objects}
   528000    0.181    0.000    0.271    0.000 section1d.py:81(_parent)
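
The percentages quoted for this profile can be computed directly from the profiler data instead of by hand. The sketch below profiles a toy workload and reports each function's share of the total profiled time. `slow_cleanup` and `workload` are hypothetical stand-ins, and the helper reads the `stats` dictionary and `total_tt` attribute of `pstats.Stats`, so treat the tuple layout as an assumption about pstats internals.

```python
import cProfile
import pstats
import time


def slow_cleanup():
    # Toy stand-in for an expensive destructor-like step.
    time.sleep(0.05)


def workload():
    for _ in range(3):
        slow_cleanup()


def cumulative_shares(stats):
    """Map 'file:line(function)' to its cumulative time as a fraction of the total."""
    total = stats.total_tt
    shares = {}
    for (filename, lineno, funcname), (cc, nc, tt, ct, callers) in stats.stats.items():
        shares[f"{filename}:{lineno}({funcname})"] = ct / total
    return shares


profiler = cProfile.Profile()
profiler.enable()
workload()
profiler.disable()

stats = pstats.Stats(profiler).strip_dirs()
shares = cumulative_shares(stats)
for name, share in sorted(shares.items(), key=lambda kv: kv[1], reverse=True)[:5]:
    print(f"{share:6.1%}  {name}")
```

To apply the same helper to the saved profile from this report, `pstats.Stats("profile.txt")` could be passed in place of the in-memory profiler.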

pramodk commented 4 months ago

@adamjhn or @ramcdougal: do you have a suggestion for this issue?

ramcdougal commented 4 months ago

Looking into this...

First two insights:

Quite frankly, this is probably the first time anyone has tried to run this with 2000 species.

Here's a version reproducing the problem that doesn't require bluepyopt:

from neuron import h, rxd
import time

# I'm using c91662 from
# https://raw.githubusercontent.com/NeuroBox3D/NeuGen/master/NeuGen/cellData/CA1/amaral/c91662.CNG.swc

h.load_file("import3d.hoc")

class Cell:
    def __init__(self):
        cell = h.Import3d_SWC_read()
        cell.input("c91662.swc")
        i3d = h.Import3d_GUI(cell, False)
        i3d.instantiate(self)

def minimal_example(NUM_MORPHS=10, SPECIES_PER_CELL=200):
    species_list = []
    region_list = []

    cell_list = [Cell() for _ in range(NUM_MORPHS)]

    for cell in cell_list:
        print("Creating regions", flush=True)
        region = rxd.Region(cell.dend, nrn_region="i")
        region_list.append(region)

        print("Creating species", flush=True)
        for idx in range(SPECIES_PER_CELL):
            species_name = f"species{idx}"
            spec = rxd.Species(region,
                               d=0,
                               initial=1,
                               charge=0,
                               name=species_name)
            species_list.append(spec)

    duration = []

    for idx, spec in enumerate(species_list):
        # This step is slow

        if idx == 0:
            print(f"Calling nodes on {spec} -- This is slow!", flush=True)
        else:
            print(f"Calling nodes on {spec}", flush=True)

        start_time = time.perf_counter()
        spec.nodes
        end_time = time.perf_counter()
        dur = end_time - start_time
        duration.append(dur)
        print(f"nodes call done {dur}")

    print(f"Max duration: {max(duration)} for {NUM_MORPHS} neurons")

    print("Init")
    start_time = time.perf_counter()
    h.finitialize(-65)
    end_time = time.perf_counter()
    print(f"Initialization time: {end_time - start_time} seconds")

if __name__ == "__main__":
    minimal_example()

Hjorthmedh commented 4 months ago

Hi! Do you have any suggestions for how to solve this?

ramcdougal commented 4 months ago

A couple quick tips for now:

Short-term solutions we can help with:

Longer-term fixes we should do:

wthun commented 4 months ago

Thanks! We changed our code to retrieve nodes only after all species were defined, but the first call to .nodes is still too slow to be used at scale. The time for a single call also still seems to be proportional to the global number of nodes rather than to the number present in the individual cell. How complicated would it be to fix this at a lower level? (It could be useful to have a tree representation, or similar, of the nodes at the C++ level instead of continuously fetching Python lists.)

As for the deallocation, it seems to not happen when all species are involved in rxd.Reactions, so that's not been a problem in practice.
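
The scaling concern in this comment (a per-cell query costing time proportional to all nodes in the model) is the difference between scanning a global list and looking up a prebuilt index. The sketch below illustrates that idea in plain Python. `Node`, `cell_id`, and both helper functions are hypothetical and do not reflect how rxd stores nodes internally.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    cell_id: int  # which cell the node belongs to (hypothetical field)
    index: int    # position of the node within that cell


def nodes_by_scan(all_nodes, cell_id):
    """O(total nodes): inspects every node in the model on every query."""
    return [n for n in all_nodes if n.cell_id == cell_id]


def build_index(all_nodes):
    """One O(total nodes) pass that groups nodes by cell."""
    index = defaultdict(list)
    for n in all_nodes:
        index[n.cell_id].append(n)
    return index


# 10 cells with 660 nodes each (sizes chosen arbitrarily for the example).
all_nodes = [Node(cell, i) for cell in range(10) for i in range(660)]
index = build_index(all_nodes)

# Once the index exists, a per-cell query is a dict lookup whose cost
# does not depend on how many other cells are in the model.
cell_3_nodes = index[3]
print(len(cell_3_nodes), cell_3_nodes == nodes_by_scan(all_nodes, 3))
```

The trade-off is that the index has to be kept in sync whenever nodes are added or removed, which is where a lower-level implementation would need care.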

ctrl-z-9000-times commented 2 months ago

Hi,

I noticed another issue. Segment geometry is slow to compute and is computed multiple times per segment.

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   396000   17.895    0.000   32.157    0.000 geometry.py:68(result) <- This is "surfacearea1d"
   396000   11.753    0.000   26.164    0.000 geometry.py:36(_volumes1d)
   396000    1.869    0.000   14.023    0.000 geometry.py:117(_neighbor_areas1d)
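
One common remedy for values that are recomputed several times per segment is to memoize them. The sketch below uses `functools.lru_cache` on a toy cylinder-geometry function and counts how often the expensive body actually runs. `segment_geometry` and its cylinder formulas are illustrative assumptions, not the code in geometry.py, and a real cache would also need invalidation when diameters or lengths change.

```python
import math
from functools import lru_cache

compute_count = 0  # counts how many times the geometry is actually computed


@lru_cache(maxsize=None)
def segment_geometry(diam, length):
    """Return (lateral surface area, volume) of a cylinder; cached per (diam, length)."""
    global compute_count
    compute_count += 1
    area = math.pi * diam * length
    volume = math.pi * (diam / 2) ** 2 * length
    return area, volume


segments = [(1.0, 10.0), (2.0, 5.0), (0.5, 20.0)]

# Three passes stand in for separate consumers (areas, volumes, neighbor areas).
for _ in range(3):
    for diam, length in segments:
        area, volume = segment_geometry(diam, length)

print(f"{len(segments) * 3} lookups, {compute_count} computations")
```

Computing area and volume together in one cached call means each segment's geometry is evaluated once, however many consumers ask for it.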