Open MatDag opened 7 months ago
I suspect the issue is that you're not calling block_until_ready on the first invocation of grad_fun(params), so the subsequent calculations are asynchronously dispatched while the device is still busy. When I change your compilation run to this:

jax.block_until_ready(grad_fun(params))  # First run for compilation

I see more consistent timing in the first several runs of the benchmark.
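To illustrate the point about asynchronous dispatch: a jitted call returns as soon as the work is enqueued on the device, so naive wall-clock timing measures dispatch, not compute. A minimal sketch (the function f and its shapes are hypothetical, just to give the device some work):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Arbitrary device work: elementwise ops followed by a matmul.
    return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((1000, 1000))
f(x)  # trigger compilation once, outside the timed region

# Without blocking, this measures only the time to enqueue the work,
# which can be misleadingly small.
t0 = time.perf_counter()
f(x)
dispatch_time = time.perf_counter() - t0

# block_until_ready waits for the device to finish, so this measures
# the actual compute time.
t0 = time.perf_counter()
jax.block_until_ready(f(x))
compute_time = time.perf_counter() - t0
```

The same pattern applies to the gradient benchmark in the issue: wrap each timed call in jax.block_until_ready so earlier dispatched work cannot bleed into later measurements.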
Thank you for answering. I tried your solution and got the same results, unfortunately.
I'm unable to reproduce this on a Colab T4 GPU runtime with jax v0.4.20 and the block_until_ready call I suggested above.
Can you try updating your jax and jaxlib versions and see whether that affects the result?
Description
Hi. For benchmarking purposes, I need to measure the time spent computing the gradient of a Flax model with respect to the model's parameters. The gradient function is jitted, and a first run is performed for compilation. However, when running this code on a GPU, the first five computations take longer than the last fifteen, even though they compute the same thing. Is there an explanation for this?
Output:
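The original snippet and its output are not reproduced above, but the setup described can be sketched as follows. This is a hedged, minimal reconstruction: a plain linear model with a params dict stands in for the Flax model, and the names loss and grad_fun are hypothetical:

```python
import time
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the Flax model: a linear model whose
# parameters live in a pytree (dict), as Flax params do.
def loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 128))
y = jax.random.normal(key, (256, 1))
params = {"w": jnp.zeros((128, 1)), "b": jnp.zeros((1,))}

# Jit the gradient and run it once so compilation happens outside
# the timed loop; block so the compile run fully finishes.
grad_fun = jax.jit(jax.grad(loss))
jax.block_until_ready(grad_fun(params, x, y))

# Time 20 runs, blocking on each so every measurement covers the
# full device computation rather than just the dispatch.
times = []
for _ in range(20):
    t0 = time.perf_counter()
    jax.block_until_ready(grad_fun(params, x, y))
    times.append(time.perf_counter() - t0)
```

With this pattern, any remaining gap between the first few and the last runs cannot be explained by asynchronous dispatch alone.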
What jax/jaxlib version are you using?
jax v0.4.7, jaxlib v0.4.7+cuda11.cudnn86
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.11, Linux
NVIDIA GPU info