spenrich opened this issue 4 years ago (status: Open)
The short answer is that, unfortunately, at this time XLA GPU is not great at code generation for tight loops like those in `odeint`. The body of the `while_loop` is compiled into one or more GPU kernels, which incurs significant launch overhead because control flow goes back to the CPU in each iteration.
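For intuition, here is a small, illustrative benchmark sketch (not from the original thread): a jitted `lax.while_loop` with a very cheap body and many iterations, which one could time on CPU vs. GPU to see the per-iteration overhead described above. The loop body and iteration count are arbitrary choices.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def tight_loop(x):
    # A long while_loop with a very cheap body: on GPU, the per-iteration
    # host round-trip and kernel-launch overhead can dominate the runtime.
    def cond(carry):
        i, _ = carry
        return i < 10_000

    def body(carry):
        i, v = carry
        return i + 1, v + 0.001 * jnp.sin(v)

    _, v = lax.while_loop(cond, body, (0, x))
    return v


x = jnp.ones(2)
tight_loop(x).block_until_ready()  # warm-up / compile
start = time.perf_counter()
tight_loop(x).block_until_ready()
print(f"one call: {time.perf_counter() - start:.4f} s")
```

Running the same script with `JAX_PLATFORM_NAME=cpu` set in the environment forces CPU execution, which makes the comparison straightforward.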
@shoyer is anyone working on improving `while_loop`?
Yes, there are several ongoing streams of work to improve `while_loop`.
@shoyer, apologies for the off-topic question. Could you share links to the PRs or branches for the ongoing work, if they're publicly available?
Sorry, I don't have any details that I can share at this time.
The following MWE trains a simple neural ODE model with gradient descent to match a 2-D dynamical system (the Van der Pol oscillator) using sampled data along a single trajectory. Each iteration of the training loop runs much more slowly on my GPU than on my CPU (roughly estimated with `tqdm` at 17 iterations/sec on GPU vs. upwards of 800 iterations/sec on CPU). Any first impressions about what might be going on? I can look into more detailed profiling if need be.
Versions: jax 0.2.6, jaxlib 0.1.57+cuda102, cuda 10.2
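A minimal sketch of the kind of MWE described above (this is an illustrative reconstruction, not the original code posted in the issue; the MLP architecture, learning rate, trajectory length, and data generation are assumed):

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def vdp(x, t, mu=1.0):
    # Van der Pol oscillator: the "true" 2-D dynamics being fit.
    return jnp.array([x[1], mu * (1.0 - x[0] ** 2) * x[1] - x[0]])


def init_params(key, sizes=(2, 32, 2)):
    # Small MLP parameters (layer sizes are an illustrative choice).
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (fan_in, fan_out)) / jnp.sqrt(fan_in)
        params.append((w, jnp.zeros(fan_out)))
    return params


def mlp_dynamics(x, t, params):
    # Learned vector field; odeint expects the signature f(y, t, *args).
    h = x
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return h @ w + b


t = jnp.linspace(0.0, 10.0, 100)
x0 = jnp.array([1.0, 0.0])
x_true = odeint(vdp, x0, t)  # sampled "data" along a single trajectory


def loss(params):
    x_pred = odeint(mlp_dynamics, x0, t, params)
    return jnp.mean((x_pred - x_true) ** 2)


@jax.jit
def step(params, lr=1e-2):
    # Plain gradient descent on the trajectory-matching loss.
    grads = jax.grad(loss)(params)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


params = init_params(jax.random.PRNGKey(0))
for i in range(100):
    params = step(params)
```

Each call to `step` differentiates through `odeint`, so the per-iteration cost is dominated by the adaptive-step `while_loop` inside the solver, which is where the GPU launch overhead discussed above shows up.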