QuantEcon / lecture-jax

Lectures on Quantitative Economics Using JAX
https://jax.quantecon.org/

[cake_eating_numerical] Numba time less than JAX #185

Open mmcky opened 3 weeks ago

mmcky commented 3 weeks ago

The lecture cake_eating_numerical was removed in #175 because its JAX execution time was longer than Numba's. We need to review the implementations in this lecture as a sense check.

A copy of the lecture is here:

cake_eating_numerical.md

and the timings are at the bottom of this preview:

https://6661340ed985e83faa0cb785--incomparable-parfait-2417f8.netlify.app/cake_eating_numerical

We would expect JAX to outperform Numba unless there is a good reason, which we should then explain.

@kp992 do you have time to look into this lecture?

TODO:

  1. review the implementations and confirm why Numba's execution time is lower than JAX's
  2. submit a PR updating and re-enabling this lecture

kp992 commented 3 weeks ago

Sure, will take a look.

kp992 commented 3 weeks ago

Hi @mmcky, I checked the difference in timings, and the main reason is the difference in the algorithms used. JAX itself is well optimized, but the JAX implementation finds the maximum by brute-force search over the grid, whereas the Numba implementation uses the brent_max function, which is currently unavailable in JAX.
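The cost gap between the two approaches can be illustrated without either library. The sketch below is plain Python and is not the lecture's code: `objective` is a hypothetical one-step cake-eating problem (utility `sqrt(c)` plus a stand-in continuation value `0.9 * sqrt(1 - c)`), and `grid_max` / `golden_max` are illustrative names. It uses golden-section search, a simpler bracketing relative of Brent's method; Brent's method (which brent_max, as its name suggests, is based on) adds parabolic interpolation steps and typically needs even fewer evaluations.

```python
import math

BETA = 0.9  # hypothetical discount factor


def objective(c, x=1.0):
    # Hypothetical one-step objective: utility sqrt(c) from eating c today
    # plus a stand-in continuation value BETA * sqrt(x - c).
    return math.sqrt(c) + BETA * math.sqrt(x - c)


def grid_max(f, a, b, n):
    """Brute force: evaluate f at n evenly spaced points, keep the best."""
    best_x, best_f = a, f(a)
    for i in range(1, n):
        x = a + (b - a) * i / (n - 1)
        fx = f(x)
        if fx > best_f:
            best_x, best_f = x, fx
    return best_x, n  # n function evaluations


def golden_max(f, a, b, tol=1e-8):
    """Golden-section search for the maximizer of a unimodal f on [a, b]."""
    invphi = (math.sqrt(5) - 1) / 2  # 1 / golden ratio, about 0.618
    c, d = b - invphi * (b - a), a + invphi * (b - a)
    fc, fd = f(c), f(d)
    evals = 2
    while b - a > tol:
        if fc > fd:  # maximizer lies in [a, d]; old c becomes the new d
            b, d, fd = d, c, fc
            c = b - invphi * (b - a)
            fc = f(c)
        else:        # maximizer lies in [c, b]; old d becomes the new c
            a, c, fc = c, d, fd
            d = a + invphi * (b - a)
            fd = f(d)
        evals += 1
    return (a + b) / 2, evals


x_brute, n_brute = grid_max(objective, 0.0, 1.0, n=1000)
x_gold, n_gold = golden_max(objective, 0.0, 1.0)
exact = 1 / (1 + BETA**2)  # closed-form maximizer of this toy objective
print(f"grid:   error {abs(x_brute - exact):.1e} using {n_brute} evaluations")
print(f"golden: error {abs(x_gold - exact):.1e} using {n_gold} evaluations")
```

Grid search needs about n evaluations to get within roughly 1/n of the maximizer, while each golden-section step shrinks the bracket by a factor of about 0.618, so a few dozen evaluations reach a 1e-8 tolerance. Inside a Bellman iteration this maximization runs once per grid point per iteration, so the difference compounds.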

kp992 commented 3 weeks ago

If brent_max were available in JAX, we could beat the Numba timings.
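One way to explore this is a bounded scalar maximizer written only in JAX primitives. The sketch below is hypothetical: `golden_max` and `best_consumption` are illustrative names, not part of JAX or QuantEcon, it uses golden-section search rather than full Brent, and the objective is the same toy stand-in for one Bellman step. It runs a fixed number of iterations with `jax.lax.fori_loop` so the search can be jit-compiled and vmapped over a grid of cake sizes.

```python
import jax
import jax.numpy as jnp

beta = 0.9  # hypothetical discount factor


def golden_max(f, a, b, n_iter=60):
    """Hypothetical golden-section maximizer of a unimodal f on [a, b].

    A fixed iteration count (instead of a tolerance-based while loop) keeps
    the control flow static, so the function can be jitted and vmapped.
    """
    invphi = (5 ** 0.5 - 1) / 2  # 1 / golden ratio, about 0.618

    def body(_, bracket):
        a, b = bracket
        c = b - invphi * (b - a)
        d = a + invphi * (b - a)
        # For simplicity both interior points are re-evaluated each step;
        # a tuned version would reuse one of them, as in the serial algorithm.
        keep_left = f(c) > f(d)  # True: maximizer lies in [a, d]
        return jnp.where(keep_left, a, c), jnp.where(keep_left, d, b)

    a, b = jax.lax.fori_loop(0, n_iter, body, (a, b))
    return (a + b) / 2


def best_consumption(x):
    # Hypothetical one-step problem for cake size x: eat c today and value
    # the remainder with a stand-in continuation value beta * sqrt(x - c).
    objective = lambda c: jnp.sqrt(c) + beta * jnp.sqrt(x - c)
    return golden_max(objective, jnp.zeros_like(x), x)


x_grid = jnp.linspace(0.1, 2.0, 5)
c_star = jax.jit(jax.vmap(best_consumption))(x_grid)

# Closed-form maximizer of this toy objective: c* = x / (1 + beta**2)
print(jnp.max(jnp.abs(c_star - x_grid / (1 + beta**2))))
```

The trade-off is that under vmap every state runs all `n_iter` steps in lockstep, with no early exit per state; in exchange the whole grid is handled by one compiled computation, with far fewer objective evaluations per state than a dense grid search.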

mmcky commented 3 weeks ago

Thanks @kp992, that is really helpful. Algorithms matter :-).

@jstac (Smit) has identified the cause of the timing difference here, and there is a good explanation of why the Numba version runs faster than the JAX one. It is a good example of how algorithms matter (just as much as technology). What do you think about making this point in the lecture?