**zoj613** opened this issue 1 year ago
Profiling graphs with `Scan`s (or any other `Op`s that hold Aesara graphs inside them) is currently complicated by the way in which their inner-graphs are compiled. Simply put, the inner-graphs of `Scan`s are compiled when a thunk is created for an `Apply` node that uses said `Scan` (see here).
This means that all the rewriting and compilation/thunk creation for inner-graphs occurs separately and in a weird lazy, nested fashion (e.g. imagine `Scan`s containing other `Scan`s, as is the case here). These profiling results only show one layer of that compilation stack, but the total times reflect all the rewriting and compilation in the stack.
https://github.com/aesara-devs/aesara/pull/824 attempts to "flatten" all this out so that—for instance—all the timing results can be presented coherently in one profile. Also, there are a few big rewriting issues that are caused by this undesirable nesting, and https://github.com/aesara-devs/aesara/pull/824 attempts to address those. I have some updates to push there, but we may want/need to move our focus to those changes in order to address this issue more easily (or at all).
In the meantime, as the profile output implies, we might need to hack up a means of setting `profile=True` in all the `Scan`s constructed in this example. With that, we can un-nest the profile results and try to figure out exactly where all the time is spent.
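One way to hack that up (a rough, untested sketch, not part of the linked PR) is to wrap `aesara.scan` so that every `Scan` built afterwards is constructed with `profile=True`:

```python
import functools

import aesara

# Rough sketch: force `profile=True` on every `Scan` constructed after
# this point. Assumes the example builds its loops via `aesara.scan` and
# that this patch is applied before any graphs are constructed.
_original_scan = aesara.scan


@functools.wraps(_original_scan)
def _profiled_scan(*args, **kwargs):
    kwargs["profile"] = True
    return _original_scan(*args, **kwargs)


aesara.scan = _profiled_scan
```

Modules that bound the name directly (e.g. via `from aesara import scan`) before the patch won't see it, so this has to run before anything else is imported.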
The top-layer profile output you provided definitively tells us that linking/thunk creation/compilation is slow. Now we need to know whether that's due to compilation using the C backend, rewriting, or both. To start distinguishing between the two, we can use `config.profile_optimizer`.
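For example, before compiling anything:

```python
import aesara

# Time the rewrites themselves in addition to the regular profile output;
# both flags must be set before `aesara.function` is called.
aesara.config.profile = True
aesara.config.profile_optimizer = True
```

The same can be done from the shell with `AESARA_FLAGS="profile=True,profile_optimizer=True"`.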
I've used something like the following in the past to read profiling info on `Scan`s:

```python
import aesara
from aesara.scan.op import Scan

scan_fn = aesara.function(...)

for node in scan_fn.maker.fgraph.apply_nodes:
    # `node.op.fn` is the compiled inner function of a `Scan`; its
    # `profile` attribute is only populated when profiling is enabled.
    if isinstance(node.op, Scan) and getattr(node.op.fn, "profile", None):
        node.op.fn.profile.summary()
```
That will only get first-layer `Scan`s, though; one needs to descend into the inner-graphs of each `Scan` in order to get them all. https://github.com/aesara-devs/aesara/pull/824 contains a `Feature` that tracks all the inner-graphs added to a `FunctionGraph`. We can merge that separately and use it if it helps.
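Until that's merged, a small recursive walk can stand in for it; here's an untested sketch, assuming every compiled inner function exposes its graph via `maker.fgraph` (with `scan_fn` as in the snippet above):

```python
from aesara.scan.op import Scan


def summarize_scan_profiles(fgraph, depth=0):
    """Recursively print the profile summary of every (nested) `Scan`."""
    for node in fgraph.apply_nodes:
        if isinstance(node.op, Scan):
            inner_fn = node.op.fn  # the compiled inner function
            if getattr(inner_fn, "profile", None) is not None:
                print(f"=== Scan at depth {depth}: {node} ===")
                inner_fn.profile.summary()
            # Recurse into the inner-graph to pick up nested `Scan`s.
            summarize_scan_profiles(inner_fn.maker.fgraph, depth + 1)


summarize_scan_profiles(scan_fn.maker.fgraph)
```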
**Description of your problem or feature request**
The warmup test, particularly https://github.com/aesara-devs/aehmc/blob/d54e2d05512d8d3d4aea92b8732854e6794296e8/tests/test_hmc.py#L49, is very slow, and this is due to the compilation time of `warmup_fn`.

**Please provide a minimal, self-contained, and reproducible example.**
**Please provide the full traceback of any errors.**

Running the test with profiling turned on produces the following:
**Please provide any additional information below.**

**Versions and main components**
python -c "import aesara; print(aesara.config)"
):Details
``` floatX ({'float32', 'float64', 'float16'}) Doc: Default floating-point precision for python casts. Note: float16 support is experimental, use at your own risk. Value: float64 warn_float64 ({'pdb', 'ignore', 'warn', 'raise'}) Doc: Do an action when a tensor variable with float64 dtype is created. Value: ignore pickle_test_value (