Open nicolaspi opened 2 hours ago
Thanks for the report. I would guess that the way to improve compilation performance here is to use something like scan (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) in place of a for loop when building the multi-step function.
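To make the suggestion concrete, here is a minimal sketch (not the Keras implementation) contrasting a Python for loop with jax.lax.scan under jit. The names `make_step`, `multi_step_loop`, `multi_step_scan`, and `trace_count` are hypothetical and only for illustration; the toy `step` stands in for a real train step. It only illustrates how often the step is traced and says nothing about memory footprint.

```python
import jax
import jax.numpy as jnp

# Counts how many times the step body is traced (Python side effects
# run only during tracing, not when the compiled function executes).
trace_count = {"loop": 0, "scan": 0}

def make_step(key):
    def step(state, x):
        trace_count[key] += 1
        return state * 0.9 + x  # toy stand-in for a real train step
    return step

def multi_step_loop(state, xs):
    # Python for loop: the step is unrolled and traced once per iteration.
    step = make_step("loop")
    for i in range(xs.shape[0]):
        state = step(state, xs[i])
    return state

def multi_step_scan(state, xs):
    # lax.scan: the body is traced as a single loop body, not per iteration.
    step = make_step("scan")

    def body(carry, x):
        return step(carry, x), None

    state, _ = jax.lax.scan(body, state, xs)
    return state

xs = jnp.arange(8, dtype=jnp.float32)
init = jnp.zeros((), dtype=jnp.float32)

out_loop = jax.jit(multi_step_loop)(init, xs)
out_scan = jax.jit(multi_step_scan)(init, xs)

print("loop traces:", trace_count["loop"])
print("scan traces:", trace_count["scan"])
```

With 8 steps, the loop version traces the step 8 times, while the scan version should trace it far fewer times (typically once), which is why scan tends to reduce compilation time as the step count grows.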
I am preparing a PR for that issue. I tried using scan, but it is not the solution because of the memory footprint.
The model is traced once for each of the steps_per_execution steps (using the JAX backend), increasing the JIT compilation time and memory usage proportionally. This gist demonstrates the issue on MobileNetV2: