hyanwong / giglib

MIT License

Speed up forward simulations #86

Open hyanwong opened 8 months ago

hyanwong commented 8 months ago

Profiling the test_inversion code with the cProfiler gives this as the first 20 or so lines:

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   15.958   15.958 gigutil.py:260(run_more)
        5    0.014    0.003   13.953    2.791 gigutil.py:65(new_population)
      804    0.050    0.000   13.895    0.017 gigutil.py:298(add_inheritance_paths)
      804    0.019    0.000   13.071    0.016 gigutil.py:279(find_comparable_points)
      804    0.644    0.001   13.007    0.016 graph.py:353(find_mrca_regions)
      103    0.006    0.000    6.077    0.059 __init__.py:1(<module>)
    15459    0.073    0.000    3.770    0.000 interval.py:525(__sub__)
    33908    0.038    0.000    3.670    0.000 tables.py:15(asdict)
    33908    0.056    0.000    3.632    0.000 dataclasses.py:1217(asdict)
     3507    0.068    0.000    3.563    0.001 dict.py:221(combine)
57157/50750    0.950    0.000    3.555    0.000 interval.py:409(__and__)
330920/33908    1.145    0.000    3.541    0.000 dataclasses.py:1241(_asdict_inner)
    24511    0.106    0.000    3.006    0.000 graph.py:569(__getitem__)
   223449    1.281    0.000    2.759    0.000 interval.py:98(from_atomic)
    32136    0.131    0.000    2.720    0.000 graph.py:206(iedges_for_child)
    12308    0.113    0.000    2.224    0.000 dict.py:291(__setitem__)
    805/3    0.009    0.000    1.989    0.663 <frozen importlib._bootstrap>:1022(_find_and_load)
    799/3    0.006    0.000    1.989    0.663 <frozen importlib._bootstrap>:987(_find_and_load_unlocked)
    774/3    0.007    0.000    1.988    0.663 <frozen importlib._bootstrap>:664(_load_unlocked)
    680/3    0.005    0.000    1.987    0.662 <frozen importlib._bootstrap_external>:877(exec_module)
   1026/3    0.002    0.000    1.984    0.661 <frozen importlib._bootstrap>:233(_call_with_frames_removed)
        7    0.000    0.000    1.945    0.278 tables.py:426(graph)
        7    0.000    0.000    1.945    0.278 graph.py:32(__init__)
        7    0.041    0.006    1.944    0.278 graph.py:45(_validate)
        2    0.000    0.000    1.734    0.867 tables.py:1(<module>)
    15459    0.160    0.000    1.444    0.000 interval.py:512(__invert__)
298266/297015    0.754    0.000    1.409    0.000 copy.py:128(deepcopy)
    23375    0.042    0.000    1.340    0.000 dict.py:34(__init__)
   149097    0.153    0.000    1.270    0.000 interval.py:398(__iter__)
    14028    0.086    0.000    1.219    0.000 dict.py:270(__getitem__)
   281354    0.580    0.000    1.162    0.000 interval.py:38(__init__)
   149097    0.141    0.000    1.123    0.000 interval.py:399(<genexpr>)
    27244    0.084    0.000    1.020    0.000 sorteddict.py:280(__setitem__)
...

Quite a lot of time is spent calling .asdict() (line 15 of tables.py), possibly because it is called whenever a table row is accessed?
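To illustrate why asdict() shows up so heavily: dataclasses.asdict recurses into every field (the _asdict_inner rows in the profile) and, at least in the Python version profiled here, deep-copies the leaf values (the copy.py deepcopy row). Below is a minimal sketch comparing it with a shallow conversion and with plain attribute access. EdgeRow and shallow_asdict are hypothetical names made up for this example, not giglib's actual classes, and the timings will vary by machine and Python version.

```python
import dataclasses
import timeit


@dataclasses.dataclass(frozen=True)
class EdgeRow:
    # Hypothetical stand-in for a table row; not giglib's real row class.
    child: int
    parent: int
    child_left: int
    child_right: int
    parent_left: int
    parent_right: int


def shallow_asdict(row):
    # A single pass over the fields: no recursion and no deepcopy.
    # Only safe when the field values are immutable (ints here).
    return {f.name: getattr(row, f.name) for f in dataclasses.fields(row)}


row = EdgeRow(1, 0, 0, 10, 0, 10)

n = 20_000
t_deep = timeit.timeit(lambda: dataclasses.asdict(row), number=n)
t_shallow = timeit.timeit(lambda: shallow_asdict(row), number=n)
t_attr = timeit.timeit(lambda: row.child, number=n)
print(f"asdict: {t_deep:.4f}s  shallow: {t_shallow:.4f}s  attribute: {t_attr:.4f}s")
```

If the rows only hold immutable values, either the shallow conversion or direct attribute access gives the same information without the recursive copy.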

And in terms of the actual time taken for each inner function:

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   223449    1.575    0.000    3.387    0.000 interval.py:98(from_atomic)
57157/50750    1.343    0.000    4.265    0.000 interval.py:409(__and__)
330920/33908    1.305    0.000    4.069    0.000 dataclasses.py:1241(_asdict_inner)
298266/297015    0.852    0.000    1.629    0.000 copy.py:128(deepcopy)
   281354    0.806    0.000    1.464    0.000 interval.py:38(__init__)
      804    0.755    0.001   15.536    0.019 graph.py:353(find_mrca_regions)
  1358334    0.679    0.000    0.889    0.000 interval.py:175(empty)
  2033597    0.553    0.000    0.553    0.000 {built-in method builtins.getattr}
  2202653    0.526    0.000    0.547    0.000 {built-in method builtins.isinstance}
     2522    0.496    0.000    0.895    0.000 tables.py:73(<listcomp>)
      685    0.398    0.001    0.398    0.001 {method 'read' of '_io.BufferedReader' objects}
    33241    0.364    0.000    0.430    0.000 {method 'sort' of 'list' objects}
   481046    0.361    0.000    0.680    0.000 interval.py:148(lower)
   474037    0.355    0.000    0.533    0.000 const.py:39(__neg__)
   462253    0.349    0.000    0.661    0.000 interval.py:157(upper)
2066464/2062927    0.341    0.000    0.343    0.000 {built-in method builtins.len}
    54123    0.317    0.000    0.631    0.000 interval.py:549(__lt__)
    21109    0.304    0.000    1.115    0.000 dict.py:7(_sort)
...
hyanwong commented 8 months ago

If we move the find_mrca code into tables.py (and define iedges_for_child as a function of a table) then we shouldn't hit the asdict() method as much: this method is only used to wrap edges for a GIG, not for a table row.

hyanwong commented 8 months ago

You can test a simulation by creating a file (e.g. simulate.py) in the top-level directory, containing lines such as the following:

from tests.test_gigutil import TestDTWF_one_break_no_rec_inversions_slow

test = TestDTWF_one_break_no_rec_inversions_slow()
test.seq_len=1000  # to match the old code
test.default_gens = 20  # increase number of gens for more consistent profiling
test.test_inversion()

Then try the following on the command-line:

python -m cProfile -s cumulative simulate.py
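The second listing format used in this thread ("Ordered by: internal time") can be produced on the command line with -s tottime instead of -s cumulative. Both orderings can also be generated from a single run using the standard pstats module. The following is a rough sketch only: workload() is a placeholder for the call to test.test_inversion(), and report() is a helper name invented for this example.

```python
import cProfile
import io
import pstats


def workload():
    # Placeholder; in simulate.py this would be test.test_inversion().
    return sum(i * i for i in range(50_000))


# Profile one run, then view the same data under two sort orders.
profiler = cProfile.Profile()
profiler.enable()
result = workload()
profiler.disable()


def report(sort_key, limit=20):
    # Render the top `limit` rows of the profile into a string.
    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    stats.sort_stats(sort_key).print_stats(limit)
    return buf.getvalue()


cumulative_report = report(pstats.SortKey.CUMULATIVE)  # time including callees
internal_report = report(pstats.SortKey.TIME)          # time inside each function
print(cumulative_report)
print(internal_report)
```

Profiling once and re-sorting avoids the run-to-run variation that comes from executing the simulation twice.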

It's not quite a like-for-like comparison because it is using a different random number seed (well, a different order of RNG calls). But on my old desktop, the new code in #90 takes the run from 77 seconds to 62 seconds and avoids some of the asdict() calls. Most of the time is now spent in the interval library, so if there is a way to speed that up, e.g. using numpy and/or numba, that would be very useful (interestingly, the numba docs give an example of creating an interval class, but it's quite involved). Kevin Thornton pointed me to https://pypi.org/project/intervaltree/ - I don't know if this is faster than the Portion library we are currently using. I suspect that some of the more complex features like IntervalDicts will be hard to find elsewhere.

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    936/1    0.053    0.000   61.395   61.395 {built-in method builtins.exec}
        1    0.004    0.004   61.395   61.395 simulate.py:1(<module>)
        1    0.007    0.007   60.068   60.068 test_gigutil.py:199(test_inversion)
        1    0.000    0.000   54.978   54.978 gigutil.py:282(run_more)
       10    0.026    0.003   53.721    5.372 gigutil.py:65(new_population)
     1804    0.090    0.000   53.607    0.030 gigutil.py:318(add_inheritance_paths)
     1804    0.039    0.000   49.789    0.028 gigutil.py:299(find_comparable_points)
     1804    2.748    0.002   49.674    0.028 tables.py:751(find_mrca_regions)
315241/277150    5.185    0.000   18.462    0.000 interval.py:409(__and__)
    52008    0.297    0.000   17.425    0.000 interval.py:525(__sub__)
    12600    0.214    0.000   16.180    0.001 dict.py:221(combine)
  1098284    6.657    0.000   14.458    0.000 interval.py:98(from_atomic)
    48716    0.400    0.000    8.788    0.000 dict.py:291(__setitem__)
   719104    0.675    0.000    6.816    0.000 interval.py:398(__iter__)
   719104    0.717    0.000    6.198    0.000 interval.py:399(<genexpr>)
    52008    0.479    0.000    5.950    0.000 interval.py:512(__invert__)
    50400    0.323    0.000    5.501    0.000 dict.py:270(__getitem__)
  1373414    2.192    0.000    5.407    0.000 interval.py:38(__init__)
    82441    0.129    0.000    5.000    0.000 dict.py:34(__init__)
    11018    1.946    0.000    4.579    0.000 tables.py:109(__getattr__)
   132852    0.354    0.000    4.339    0.000 sorteddict.py:280(__setitem__)
    54210    0.044    0.000    4.066    0.000 tables.py:24(asdict)
    54210    0.063    0.000    4.021    0.000 dataclasses.py:1299(asdict)
496856/54210    1.081    0.000    3.897    0.000 dataclasses.py:1323(_asdict_inner)
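Since interval intersection (__and__) and the from_atomic calls beneath it dominate the profile above, one option worth measuring is to represent a set of intervals as a sorted list of plain tuples and intersect two such lists with a linear sweep. This is only a sketch under stated assumptions: intervals are half-open [left, right) with integer coordinates, each list is sorted and non-overlapping, and intersect() is a made-up helper rather than anything from Portion or giglib. Whether it actually beats Portion here is untested, and it does not replace IntervalDict.

```python
def intersect(a, b):
    """Intersect two sorted lists of disjoint half-open (left, right) intervals."""
    out = []
    i = j = 0
    while i < len(a) and j < len(b):
        # The overlap of the two current intervals, if there is one.
        left = max(a[i][0], b[j][0])
        right = min(a[i][1], b[j][1])
        if left < right:
            out.append((left, right))
        # Advance whichever interval finishes first; the other may still
        # overlap the next interval in the opposite list.
        if a[i][1] < b[j][1]:
            i += 1
        else:
            j += 1
    return out


overlap = intersect([(0, 10), (20, 30)], [(5, 25)])
print(overlap)
```

The sweep is linear in the total number of intervals and creates no intermediate objects other than the output tuples, which is where the saving over a general-purpose interval class would have to come from.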
hyanwong commented 8 months ago

Slowness as we increase the number of generations should be fixed by continual simplifying (see https://github.com/hyanwong/GeneticInheritanceGraphLibrary/issues/64#issuecomment-1971411875). Note that there are a very large number of breakpoints in the test_inversion example, because we have one breakpoint per generation which equates to a recombination rate of 1e-3 in a 1000bp genome. In other words, we are simulating an entire chromosome here.
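The arithmetic behind that figure, as a small sketch. It assumes the rate is meant per base pair per generation, and uses the seq_len and default_gens values from the simulate.py example earlier in this thread; the variable names are illustrative only.

```python
import math

seq_len = 1000        # bp, as set via test.seq_len in simulate.py
breaks_per_gen = 1    # one breakpoint per generation
default_gens = 20     # as set via test.default_gens in simulate.py

# Rate per base pair per generation (assumed interpretation of "1e-3").
rate = breaks_per_gen / seq_len

# Breakpoints accumulated along one line of descent over the whole run.
breaks_along_lineage = breaks_per_gen * default_gens

print(f"rate = {rate:g} per bp per generation")
print(f"breakpoints along one lineage after {default_gens} generations: {breaks_along_lineage}")
```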

hyanwong commented 3 months ago

I think the two major ways in which we could speed up forward simulation are by implementing:

It would be worth implementing these and measuring the speedup we get, especially as we go to longer simulation timescales. We might hope to start asymptoting to a constant cost per forward-simulated generation.