tardis-sn / tardis

TARDIS - Temperature And Radiative Diffusion In Supernovae
https://tardis-sn.github.io/tardis

Dead branch pruning #2690

Closed Sumit112192 closed 2 months ago

Sumit112192 commented 3 months ago

:pencil: Description

Type: :roller_coaster: infrastructure

This PR continues the work of PR #2672. It implements the dead branch pruning method to solve the compile-time constant issue.

@andrewfullard @wkerzendorf

andrewfullard commented 3 months ago

This is a really interesting technique. Can you describe more about how it works in the PR?

Sumit112192 commented 3 months ago

@andrewfullard One of Numba's compilation stages is type inference. Before type inference, Numba prunes branches whose condition tests whether a variable is None, so the pruned branch is never type-checked and cannot cause a type inference error.

Two Numba issues I found useful:
1) https://github.com/numba/numba/issues/4163
2) https://github.com/numba/numba/issues/8603

The dead branch pruning PR:
1) https://github.com/numba/numba/pull/3592
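A minimal sketch of the pruning described above, using a hypothetical function `scale` that is not part of TARDIS. When `factor` is None (passed explicitly or left at its default), Numba should drop the `else` branch before type inference, so `x * factor` is never typed with a None operand. The `try`/`except` fallback is an assumption added only so the sketch also runs as plain Python where Numba is not installed.

```python
try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without Numba (no compilation then).
    def njit(func):
        return func


@njit
def scale(x, factor=None):
    # Hypothetical example function, not TARDIS code.
    if factor is None:
        # Only this branch is kept when factor is None.
        return x
    else:
        # Multiplying by None could not be typed, so this branch must be
        # pruned for the None case to compile.
        return x * factor


print(scale(2.0))       # 2.0
print(scale(2.0, 3.0))  # 6.0
```

Both calls return floats, so the example sidesteps the question of what happens when the two branches produce different types.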

Sumit112192 commented 3 months ago

I set the ENABLE_RPACKET_TRACKING flag to None when it is false, and use that to decide whether to initialize the RPacketTracker.
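The flag handling could look roughly like the sketch below. The helper `to_prunable_flag` and the loop `count_tracked_steps` are hypothetical stand-ins chosen for illustration, not the code in this PR. The `try`/`except` fallback is likewise an assumption so the sketch runs without Numba.

```python
try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs without Numba (no compilation then).
    def njit(func):
        return func


def to_prunable_flag(enabled):
    """Map a false flag to None so `flag is not None` checks can be pruned."""
    return True if enabled else None


@njit
def count_tracked_steps(n_steps, tracking_flag):
    # Hypothetical loop standing in for the Monte Carlo main loop.
    n_tracked = 0
    for _ in range(n_steps):
        if tracking_flag is not None:
            # Tracking work would go here; skipped when the flag is None.
            n_tracked += 1
    return n_tracked


print(count_tracked_steps(5, to_prunable_flag(False)))  # 0
print(count_tracked_steps(5, to_prunable_flag(True)))   # 5
```

Numba compiles one specialization per argument type, so the None call and the True call end up as two separately compiled versions of the same function.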

Sumit112192 commented 3 months ago

Benchmark results using Scalene

Toy code (filename: TARDIS_Master.py or TARDIS_DeadBranchPruning.py):

from tardis import run_tardis
sim1 = run_tardis("../tardis_example.yml")
sim2 = run_tardis("../tardis_example.yml")

Command, run once per file:

scalene TARDIS_Master.py
scalene TARDIS_DeadBranchPruning.py

################################################################################################

Upstream master branch:

The total execution time is around 2 min 26 sec

I don't know what this implies, but montecarlo_main_loop was taking around 3 GB of native memory, probably due to the rpacket_trackers initialization.


#################################################################################################

DeadBranchPruning branch:

The total execution time is around 2 min 7 sec

The memory usage of montecarlo_main_loop has now dropped to around 500 MB.


wkerzendorf commented 3 months ago

So is ENABLE_TRACKING now a compile-time constant? @sklam, do you have insights on this?

Sumit112192 commented 3 months ago

I am unsure if it's a compile-time constant, but I don't have to recompile the code when the value of ENABLE_RPACKET_TRACKING changes.

Sumit112192 commented 3 months ago

While playing with the code, I found a serious limitation of this approach:

import numba

@numba.njit
def aNumbaFunc(flag):
    # The branches assign None and a str.
    if flag is None:
        variable1 = None
    else:
        variable1 = "a different datatype"
    return variable1

The above would work, but

import numba

@numba.njit
def aNumbaFunc(flag):
    # The branches assign an int and a str.
    if flag is None:
        variable1 = 1
    else:
        variable1 = "a different datatype"
    return variable1

This would not.

andrewfullard commented 3 months ago

@Sumit112192, given the limitations noted, do you think this will actually resolve our issue with the packet tracker? Or is it just not useful elsewhere in the code?

Sumit112192 commented 3 months ago

> Given the limitations noted, do you think @Sumit112192 that this will actually resolve our issue with the packet tracker? Or is it just not useful elsewhere in the code.

If I understand correctly, we wanted to turn the tracker off completely when it is not needed. The present case, which involves only the RPacketTracker, seems to work.

When there are multiple trackers, say RPacketTracker and RPacketLastInteractionTracker, we can't bind both to the same variable name.

# List is assumed to be numba.typed.List.
# append() returns None, so its result must not be assigned back.
rpacket_trackers = List()
if ENABLE_RPACKET_TRACKING is not None:
    rpacket_trackers.append(RPacketTracker(length))
else:
    rpacket_trackers.append(RPacketLastInteractionTracker())

The above will not work, but we could name them differently:

if ENABLE_RPACKET_TRACKING is not None:
    rpacket_trackers = List()
    rpacket_trackers.append(RPacketTracker(length))
else:
    rpacket_trackers = None

if ENABLE_LAST_RPACKET_TRACKING is not None:
    rpacket_last_interaction_trackers = List()
    rpacket_last_interaction_trackers.append(RPacketLastInteractionTracker(length))
else:
    rpacket_last_interaction_trackers = None

The above should probably work (I have to test it out on TARDIS).

Sumit112192 commented 3 months ago

So, if we are OK with different names, this should work for the packet trackers.