might not be a productive use of time to read this, be warned
motivation
memory bottleneck
Because everything is elementwise and there is no big matmul, the models are memory-bandwidth limited on GPU and don't come close to using all the FLOPS. https://docs.nvidia.com/deeplearning/performance/dl-performance-gpu-background/index.html#element-op
use the AI compiler
To get around the bottleneck we can use the AI compiler (torch.compile), which can fuse the elementwise operations into fewer kernels.
AI compiler doesn't work with memoization
Our models can't be compiled right now. The logic that checks whether a value is already in the cache is probably causing what PyTorch people call a "graph break", since we need to execute Python code between tensor operations.
To avoid graph breaks, I believe we have to go around memoization/recursion and actually write a loop?
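For concreteness, here is a minimal sketch of the memoized recursive style we have today. The names mirror the example at the end (mp, assume, pols_if, pols_death), but the dict cache and the pols_if recursion are illustrative, not the real implementation:

cache = {}  # hypothetical cache keyed by (formula name, timestep)

def pols_if(t):
    # this membership test is plain Python control flow between tensor
    # ops -- exactly the kind of thing that forces a graph break
    if ("pols_if", t) in cache:
        return cache[("pols_if", t)]
    if t == 0:
        result = mp.pols_if_init
    else:
        result = pols_if(t - 1) - pols_death(t - 1)
    cache[("pols_if", t)] = result
    return result

def pols_death(t):
    if ("pols_death", t) not in cache:
        cache[("pols_death", t)] = pols_if(t) * assume.mort_rate
    return cache[("pols_death", t)]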
source code generation
But writing the model out as a loop by hand would really suck, so we will have to generate the loop.
It would maybe be easier to just generate the unrolled loop instead (sketched below), but 22 formulas across 277 timesteps would be something like 6,000 lines of code. You can't edit that by hand in any productive way, so we will probably have to actually generate the loop, just for ergonomics.
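Two unrolled timesteps, just to make the scale concrete (same illustrative formulas as the final example):

# t = 0
pols_if_0 = mp.pols_if_init
pols_death_0 = pols_if_0 * assume.mort_rate
# t = 1
pols_if_1 = pols_if_0 - pols_death_0
pols_death_1 = pols_if_1 * assume.mort_rate
# ...one block like this per timestep, 277 times over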
constraints
I don't want to deal with formulas that take anything other than a single integer parameter, so I will probably enforce that; a sketch of the check follows.
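A minimal sketch of that enforcement, assuming the formulas are reachable as plain functions (the formulas argument and the reliance on annotations are assumptions):

import inspect

def check_single_int_param(formulas):
    # reject any formula that doesn't take exactly one parameter;
    # integer-ness can only be verified via an annotation (or at call
    # time), so we check the annotation when one exists
    for func in formulas:
        params = list(inspect.signature(func).parameters.values())
        if len(params) != 1:
            raise TypeError(f"{func.__name__} must take exactly one parameter")
        ann = params[0].annotation
        if ann is not inspect.Parameter.empty and ann is not int:
            raise TypeError(f"{func.__name__}'s parameter must be an int")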
implementation details
cache_graph.graph is currently unused; probably use something like that.
Enforce that timestep references never go back more than 1, i.e. only t-1, never t-2 or earlier. Enforce that the only argument ever passed is the timestep. A sketch of these checks over the recorded graph follows.
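Assuming cache_graph.graph records call edges as ((caller, t), (callee, t_callee)) pairs -- that exact shape is a guess -- the checks could look something like this:

def validate_graph(graph_edges):
    for (caller, t), (callee, t_callee) in graph_edges:
        # only the timestep may ever be passed as an argument
        if not isinstance(t_callee, int):
            raise TypeError(f"{callee} was called with a non-timestep argument")
        # references may only reach back a single timestep
        if t - t_callee not in (0, 1):
            raise ValueError(
                f"{caller}(t={t}) calls {callee}(t={t_callee}): "
                "only t and t-1 references are allowed"
            )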
algorithm:
Check for data dependencies on t-1: every function that is ever called as func(t-1) goes into the t_prev_list.
Collect the graph for t=0, topologically sort it, source-to-source compile it, ending with func_t_prev = func_t for every func in the t_prev_list.
t=0 is handled separately because of the if t == 0 initialization condition on pols_if.
Collect the graph for t=1, sort, compile. Calls to functions at time t-1 will reference func_t_prev. We expect no timestep-related conditionals to be in play here, unlike at t=0. A sketch of this emit step follows.
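A minimal sketch of emitting the loop body, assuming graph collection already yields topologically sorted formula names and right-hand-side templates (both assumptions about earlier stages):

def emit_loop_body(sorted_nodes, exprs, t_prev_list):
    # one assignment per formula, in dependency order; templates refer
    # to a t-1 value as <name>_prev
    lines = [f"{name} = {exprs[name]}" for name in sorted_nodes]
    # shift current values into the *_prev slots so the next iteration
    # can reference them
    lines += [f"{name}_prev = {name}" for name in t_prev_list]
    return lines

# hypothetical usage with the two formulas from the example below
body = emit_loop_body(
    sorted_nodes=["pols_if", "pols_death"],
    exprs={
        "pols_if": "pols_if_prev - pols_death_prev",
        "pols_death": "pols_if * assume.mort_rate",
    },
    t_prev_list=["pols_if", "pols_death"],
)
print("\n".join("        " + line for line in body))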
The whole run function can be parameterized by the number of timesteps, and that determines the number of iterations in the loop. At the end of the day, the output of the compiler will look something like this:
class MyClass:
    def __init__(self):
        # same code as before
        mp = ...
        ...

    def run(self, max_t: int):
        # t = 0 is emitted outside the loop because of the
        # if t == 0 initialization condition
        pols_if = mp.pols_if_init
        pols_death = pols_if * assume.mort_rate
        pols_if_prev = pols_if
        pols_death_prev = pols_death
        for _ in range(max_t):
            # t >= 1: calls at t-1 became *_prev references
            # (the recursion shown here is illustrative)
            pols_if = pols_if_prev - pols_death_prev
            pols_death = pols_if * assume.mort_rate
            pols_if_prev = pols_if
            pols_death_prev = pols_death
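Once generated, run is straight-line tensor code, so it should compile cleanly. A sketch of checking that, where fullgraph=True makes torch.compile raise instead of silently falling back if a graph break remains:

import torch

model = MyClass()
compiled_run = torch.compile(model.run, fullgraph=True)
compiled_run(277)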