mmcky opened 3 weeks ago
Sure, will take a look.
Hi @mmcky, I checked the difference in timings, and the main reason is the difference in the algorithms used. JAX itself is highly optimized, but the JAX version finds the maximum with a brute-force search over the grid, whereas the numba version uses the `brent_max` function (Brent's method). Since `brent_max` is currently unavailable in the JAX implementation, JAX falls back to evaluating every point on the grid.
If a `brent_max` equivalent were available in JAX, we could likely beat numba's timings.
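A minimal sketch of the algorithmic point, written in plain Python so it runs without JAX or numba. It compares a brute-force grid search with golden-section search on the same interval. Golden-section search is a simpler stand-in for Brent's method (Brent adds parabolic interpolation on top of golden-section steps), so this is not `brent_max` itself. The functions `grid_max`, `golden_max` and the `objective` below are hypothetical illustrations, not code from the lecture.

```python
import math

# Hypothetical objective (not from the lecture): utility from eating c now
# plus discounted utility from the remaining cake x - c.
BETA = 0.95


def objective(c, x=1.0):
    return math.log(c) + BETA * math.log(x - c)


def grid_max(f, a, b, n):
    """Brute force: evaluate f at n evenly spaced points on [a, b]."""
    best_x, best_val = a, f(a)
    for i in range(1, n):
        x = a + (b - a) * i / (n - 1)
        val = f(x)
        if val > best_val:
            best_x, best_val = x, val
    return best_x, best_val, n  # n function evaluations in total


def golden_max(f, a, b, tol=1e-8):
    """Golden-section search for the maximizer of a unimodal f on [a, b].

    A simpler relative of Brent's method: the bracket shrinks by a constant
    factor (about 0.618) per step, reusing one interior point each time.
    """
    invphi = (math.sqrt(5) - 1) / 2
    c = b - invphi * (b - a)
    d = a + invphi * (b - a)
    fc, fd = f(c), f(d)
    evals = 2
    while b - a > tol:
        if fc > fd:
            # Maximizer lies in [a, d]; the old c becomes the new d.
            b, d, fd = d, c, fc
            c = b - invphi * (b - a)
            fc = f(c)
        else:
            # Maximizer lies in [c, b]; the old d becomes the new c.
            a, c, fc = c, d, fd
            d = a + invphi * (b - a)
            fd = f(d)
        evals += 1
    x = (a + b) / 2
    return x, f(x), evals + 1  # +1 for the final evaluation at x


lo, hi = 1e-6, 1 - 1e-6
gx, gval, g_evals = grid_max(objective, lo, hi, 1001)
bx, bval, b_evals = golden_max(objective, lo, hi)
print(f"grid:   x = {gx:.8f} using {g_evals} evaluations")
print(f"golden: x = {bx:.8f} using {b_evals} evaluations")
print(f"exact:  x = {1 / (1 + BETA):.8f}")
```

With 1001 grid points the grid search is only accurate to about the grid spacing (roughly 1e-3), while golden-section search gets within roughly 1e-8 using a few dozen evaluations. In JAX the grid evaluations are vectorized, so evaluation counts do not translate one-to-one into run time, but the gap in work per maximization is the kind of difference described above.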
Thanks @kp992, that is really helpful. Algorithms matter :-).
@jstac, Smit (@kp992) has identified the cause of the timing difference here, and there is a good explanation of why the numba execution is faster than jax. It is a good example of how algorithms matter (just as much as technology). What do you think about making this point in the lecture?
The lecture `cake_eating_numerical` was removed in #175 because the `jax` execution time was longer than the `numba` execution time. We need to review the implementations in this lecture as a sense check. A copy of the lecture is here:
cake_eating_numerical.md
The timings are at the bottom of this preview:
https://6661340ed985e83faa0cb785--incomparable-parfait-2417f8.netlify.app/cake_eating_numerical
We would expect `jax` to outperform `numba` unless there is a good reason, which we should then explain.
@kp992, do you have time to look into this lecture?
TODO: `numba` execution time is lower than `jax` execution time