BlueBrain / voxcell

Tools to work with voxel based brain atlases.
Apache License 2.0

Harvest the categorical power in `CellCollection.save_sonata` #14

Closed by eleftherioszisis 1 year ago

eleftherioszisis commented 2 years ago

Use `pandas.factorize` to take advantage of categorical columns and avoid running sort/unique on them. Unlike `numpy.unique`, `pandas.factorize` does not sort by default.
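To illustrate the difference, here is a minimal sketch on a toy column (the values and variable names are made up for illustration; this is not the code of the change itself). It compares `numpy.unique(..., return_inverse=True)` with `pandas.factorize` on plain strings and on a categorical column.

```python
import numpy as np
import pandas as pd

# Toy column; the labels are illustrative only.
values = np.array(["PYR", "INT", "PYR", "INT", "PYR"], dtype=object)

# numpy.unique sorts the data to find the unique values.
np_uniques, np_codes = np.unique(values, return_inverse=True)

# pandas.factorize uses a hash table and, with the default sort=False,
# returns the unique values in order of first appearance.
pd_codes, pd_uniques = pd.factorize(values)

# For a categorical column the integer codes already exist, so factorize
# can work on those small integers instead of on the strings.
cat_codes, cat_uniques = pd.factorize(pd.Series(values, dtype="category"))

print(np_uniques.tolist(), np_codes.tolist())
print(pd_uniques.tolist(), pd_codes.tolist())
print(np.asarray(cat_uniques).tolist(), cat_codes.tolist())
```

In all three cases indexing the unique values with the codes reconstructs the original column; only the order of the unique values and the work needed to find them differ.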

Benchmarks:

Test dataset with categorical columns (N rows >> M unique values per column):

         morph_class region  layer synapse_class      etype      mtype
0                PYR    AAA      1           EXC  GEN_etype  GEN_mtype
1                PYR    AAA      1           EXC  GEN_etype  GEN_mtype
2                PYR    AAA      1           EXC  GEN_etype  GEN_mtype
3                PYR    AAA      1           EXC  GEN_etype  GEN_mtype
4                PYR    AAA      1           EXC  GEN_etype  GEN_mtype
...              ...    ...    ...           ...        ...        ...
74129106         INT      y      1           INH  GIN_etype  GIN_mtype
74129107         INT      y      1           INH  GIN_etype  GIN_mtype
74129108         INT      y      1           INH  GIN_etype  GIN_mtype
74129109         INT      y      1           INH  GIN_etype  GIN_mtype
74129110         INT      y      1           INH  GIN_etype  GIN_mtype

[74129111 rows x 6 columns]
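A scaled-down frame with the same shape of data can be sketched as follows. How the real 74M-row dataset was built is not stated here, so the construction, the row count and the choice of columns are assumptions for illustration only.

```python
import numpy as np
import pandas as pd

n_rows = 1_000  # assumption: scaled down from the ~74 million rows above
half = n_rows // 2

# Hypothetical construction: two blocks of repeated labels stored as categoricals.
df = pd.DataFrame(
    {
        "morph_class": ["PYR"] * half + ["INT"] * half,
        "synapse_class": ["EXC"] * half + ["INH"] * half,
        "mtype": ["GEN_mtype"] * half + ["GIN_mtype"] * half,
    }
).astype("category")

# Each categorical column is a small integer code array plus a short table
# of categories, so M stays tiny no matter how large N becomes.
for name in df.columns:
    codes = df[name].cat.codes
    categories = df[name].cat.categories
    print(name, codes.dtype, len(categories))
```

With only a handful of categories pandas stores the codes as small integers, which is what makes factorizing them cheap compared to sorting N strings.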

Before:

cProfile.run("v.save_sonata('old.h5')", sort="tottime")
         1841 function calls (1820 primitive calls) in 165.646 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        5  149.990   29.998  149.990   29.998 {method 'argsort' of 'numpy.ndarray' objects}
        5    6.121    1.224  158.594   31.719 arraysetops.py:323(_unique1d)

After:

cProfile.run("v.save_sonata('new.h5')", sort="tottime")
         3616 function calls (3565 primitive calls) in 3.849 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       15    1.929    0.129    1.930    0.129 dataset.py:36(make_new_dset)
        4    1.213    0.303    1.213    0.303 {method 'factorize' of 'pandas._libs.hashtable.Int8HashTable' objects}
        1    0.321    0.321    0.321    0.321 {method 'factorize' of 'pandas._libs.hashtable.Int16HashTable' objects}

Test dataset with 200M unique strings (N=M):

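import cProfile  # needed for the cProfile.run calls below

import voxcell

v = voxcell.CellCollection("test")
# One string property in which every value is unique
v.add_properties({"myprop": [str(u) for u in range(200_000_000)]})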

Before:

cProfile.run("v.save_sonata('old.h5')", sort="tottime")
         259 function calls (252 primitive calls) in 204.472 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1  111.898  111.898  111.898  111.898 {method 'argsort' of 'numpy.ndarray' objects}
        2   55.541   27.770   64.819   32.409 dataset.py:36(make_new_dset)
        1   13.771   13.771  204.472  204.472 <string>:1(<module>)

After:

cProfile.run("v.save_sonata('new.h5')", sort="tottime")
         486 function calls (478 primitive calls) in 159.944 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1   86.519   86.519   86.519   86.519 {method 'factorize' of 'pandas._libs.hashtable.StringHashTable' objects}
        2   54.228   27.114   63.130   31.565 dataset.py:36(make_new_dset)
        1    8.902    8.902    8.902    8.902 base.py:60(<setcomp>)

For completeness, I also checked how the `sort` flag of `factorize` affects the timings. With already-categorical columns the difference is not significant. With a list of unique strings, however, `sort=True` is much slower (387.9 s versus 159.9 s for the default run above), and `argsort` reappears in the profile:

cProfile.run("v.save_sonata('test3.h5')", sort="tottime")
         566 function calls (557 primitive calls) in 387.856 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2  151.046   75.523  160.234   80.117 dataset.py:36(make_new_dset)
        2  142.137   71.068  142.137   71.068 {method 'argsort' of 'numpy.ndarray' objects}
        1   66.579   66.579   66.579   66.579 {method 'factorize' of 'pandas._libs.hashtable.StringHashTable' objects}
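The effect of the `sort` flag can be reproduced on a much smaller scale with the sketch below. The array size and the helper name `timed_factorize` are assumptions for illustration; absolute timings will differ from the figures above and depend on the machine.

```python
import time

import numpy as np
import pandas as pd

# Scaled-down stand-in for the 200M unique strings used above.
values = np.array([str(u) for u in range(200_000)], dtype=object)


def timed_factorize(data, sort):
    """Hypothetical helper: factorize `data` and return (codes, uniques, seconds)."""
    start = time.perf_counter()
    codes, uniques = pd.factorize(data, sort=sort)
    return codes, uniques, time.perf_counter() - start


codes_unsorted, uniques_unsorted, t_unsorted = timed_factorize(values, sort=False)
codes_sorted, uniques_sorted, t_sorted = timed_factorize(values, sort=True)

# Both encodings reconstruct the original data; only the order of the
# unique values (and the time spent producing it) differs.
assert (uniques_unsorted[codes_unsorted] == values).all()
assert (uniques_sorted[codes_sorted] == values).all()

print(f"sort=False: {t_unsorted:.3f} s, sort=True: {t_sorted:.3f} s")
```

With `sort=True` the unique values are additionally sorted and the codes remapped, which is extra work on top of the hashing step when every value is unique.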