keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

The model is traced for each `steps_per_execution` step (Jax backend) #20411

Open nicolaspi opened 2 hours ago

nicolaspi commented 2 hours ago

With the JAX backend, the model is traced once for every step in `steps_per_execution`, so JIT compilation time and memory usage grow roughly in proportion to `steps_per_execution`.

This gist demonstrates the issue on MobileNetV2:

- `steps_per_execution = 1`: model retracing count 1, memory usage 622 MB, compilation overhead 4.11 seconds.
- `steps_per_execution = 5`: model retracing count 5, memory usage 1610 MB, compilation overhead 8.10 seconds.
- `steps_per_execution = 10`: model retracing count 10, memory usage 3185 MB, compilation overhead 14.54 seconds.
- `steps_per_execution = 20`: model retracing count 20, memory usage 6337 MB, compilation overhead 30.39 seconds.
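
For readers without the gist at hand, here is a minimal sketch of the tracing behaviour, assuming the multi-step function is built with a plain Python loop inside a jitted function. It is not the Keras implementation: `step`, `make_multi_step`, `count_traces`, and `counter` are hypothetical names, and the toy step only sums a batch instead of running a model.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented every time Python executes `step`,
# which only happens while JAX is tracing.
counter = {"traces": 0}

def step(state, batch):
    # Toy stand-in for one training step (assumption, not Keras code).
    counter["traces"] += 1
    return state + jnp.sum(batch)

def make_multi_step(steps_per_execution):
    @jax.jit
    def multi_step(state, batches):
        # A Python loop is unrolled at trace time, so `step` is traced
        # once per iteration and the compiled program grows with it.
        for i in range(steps_per_execution):
            state = step(state, batches[i])
        return state
    return multi_step

def count_traces(steps_per_execution):
    counter["traces"] = 0
    batches = jnp.ones((steps_per_execution, 4))
    result = make_multi_step(steps_per_execution)(jnp.zeros(()), batches)
    return counter["traces"], float(result)
```

In this sketch, `count_traces(n)` reports `n` traces of `step`, which mirrors the retracing counts listed above. The memory and timing figures depend on the real model and are not reproduced here.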
fchollet commented 2 hours ago

Thanks for the report. I would guess that the way to improve compilation performance here is to use something like `jax.lax.scan` (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) in place of a Python for loop when building the multi-step function.
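
As a rough illustration of that suggestion, the sketch below replaces a Python loop with `jax.lax.scan`. It is an assumption-laden toy, not the Keras multi-step function: `step`, `multi_step_scan`, `run_scan`, and `counter` are hypothetical names, and the step only sums a batch. It shows the tracing behaviour only and says nothing about memory use in a real training step.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented whenever `step` is traced.
counter = {"traces": 0}

def step(state, batch):
    # Toy stand-in for one training step (assumption, not Keras code).
    counter["traces"] += 1
    new_state = state + jnp.sum(batch)
    # lax.scan expects (carry, per-step output); no per-step output here.
    return new_state, None

@jax.jit
def multi_step_scan(state, batches):
    # scan traces the body for a single iteration and loops over the
    # leading axis of `batches` inside the compiled program.
    final_state, _ = jax.lax.scan(step, state, batches)
    return final_state

def run_scan(steps_per_execution):
    counter["traces"] = 0
    batches = jnp.ones((steps_per_execution, 4))
    result = multi_step_scan(jnp.zeros(()), batches)
    return counter["traces"], float(result)
```

In a fresh process, `run_scan(20)` should report a trace count well below 20, since the body is not traced once per step. Repeated calls may report fewer traces because JAX caches earlier tracing work.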

nicolaspi commented 2 hours ago

I am preparing a PR for that issue. I tried using scan, but it is not the solution because of the memory footprint.