Open adam-hartshorne opened 8 months ago
Hi @adam-hartshorne, thanks for your report. I profiled a little on CPU (an M1 Pro chip) and got 3.54 ms (fenbux) vs 3.57 ms (distrax), and then I dug into the jaxprs produced by your code:

```python
make_jaxpr(jit_febux_test)(mean, sd, y_k), make_jaxpr(jit_distrax_test)(mean, sd, y_k)
```
I think the problem is that the implementation of `normal_logpdf` in the `dist_math.normal` module is not well optimized and results in a more complex jaxpr, so I adjusted the `normal_logpdf` code to match distrax and profiled the program on a Titan RTX as below:
```python
import timeit

import distrax
from jax import jit, make_jaxpr
from jax import random as jr

from fenbux import logpdf
from fenbux.univariate import Normal

key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)
mean = jr.normal(x_key, (1000000, 2))
sd = jr.normal(y_key, (1000000, 2))
y_k = jr.normal(z_key, (1000000, 2))


def febux_test(mean, sd, y_k):
    return logpdf(Normal(mean=mean, sd=sd), y_k).sum()


def distrax_test(mean, sd, y_k):
    return distrax.Normal(loc=mean, scale=sd).log_prob(y_k).sum()


jit_febux_test = jit(febux_test)
jit_distrax_test = jit(distrax_test)

%timeit -r 10 jit_febux_test(mean, sd, y_k).block_until_ready()
%timeit -r 10 jit_distrax_test(mean, sd, y_k).block_until_ready()
```
And the results are:
And keep in mind, fenbux treats distribution parameters as pytrees, so in fenbux's jaxpr you can see an extra function, `tree_map_dist_at`, being called every time. Apart from this extra function, fenbux's jaxpr now matches distrax's (see the two jaxprs below).
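To illustrate the pytree point in isolation, here is a minimal sketch (not fenbux's actual `tree_map_dist_at`, whose internals I have not reproduced): applying a function across pytree-structured parameters introduces a tree traversal that shows up as an extra call wrapping the per-leaf ops in the jaxpr.

```python
import jax
import jax.numpy as jnp

# Distribution parameters stored as a pytree (here a plain dict).
params = {"mean": jnp.zeros(4), "sd": jnp.ones(4)}

# tree_map applies the function to every leaf; this traversal is the
# extra structure that appears in the traced jaxpr.
shifted = jax.tree_util.tree_map(lambda leaf: leaf + 1.0, params)
```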
I'll open a PR to modify the `normal_logpdf` function soon.
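For reference, a minimal sketch of the kind of closed-form normal log-pdf that keeps the jaxpr small (this mirrors the standard formula distrax evaluates; the function name and signature here are illustrative, not fenbux's internal API):

```python
import jax.numpy as jnp

# Constant folded once: 0.5 * log(2 * pi).
_HALF_LOG_TWO_PI = 0.5 * jnp.log(2.0 * jnp.pi)

def normal_logpdf(x, mean, sd):
    # log N(x | mean, sd) = -0.5 * z**2 - log(sd) - 0.5 * log(2 * pi),
    # with z = (x - mean) / sd.
    z = (x - mean) / sd
    return -0.5 * jnp.square(z) - jnp.log(sd) - _HALF_LOG_TWO_PI
```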
And, on comparing speed with other libraries such as tensorflow-probability or distrax: fenbux is expected to always be faster if you simply `jit` the methods of distributions, like `jit(dist.log_prob)` (as I wrote in the README). And if you compare them wrapped in a function exactly as you did, with plain array/tensor inputs as parameters, fenbux will exactly match the speed of these libraries under the same jaxpr-level optimization!
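The difference between the two benchmarking styles can be sketched with a toy distribution object (a stand-in, not fenbux's or distrax's API): jitting the bound method closes over the parameter arrays, so they are baked into one trace as constants, whereas a wrapper function re-receives them as fresh inputs on every call.

```python
import jax
import jax.numpy as jnp

class ToyNormal:
    """Toy stand-in for a distribution object; illustrative only."""

    def __init__(self, mean, sd):
        self.mean = mean
        self.sd = sd

    def log_prob(self, x):
        # Standard normal log-pdf in closed form.
        z = (x - self.mean) / self.sd
        return -0.5 * z**2 - jnp.log(self.sd) - 0.5 * jnp.log(2.0 * jnp.pi)

dist = ToyNormal(jnp.zeros(3), jnp.ones(3))

# Jit the bound method itself: mean and sd become trace constants.
fast_log_prob = jax.jit(dist.log_prob)
out = fast_log_prob(jnp.zeros(3))
```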
Finally, thanks again for reminding me that some functions are not optimized enough to produce the simplest jaxpr. I'll dig into these in the next version of fenbux and profile performance on GPU. Does that make sense?
https://github.com/JiaYaobo/fenbux/pull/8 optimizes the implementation of `normal_logpdf`.
I have run the following MVE versus Distrax (https://github.com/google-deepmind/distrax) and your library doesn't seem to be as fast. I am running this with jax 0.4.23, CUDA 12.2, and Python 3.10 on a GeForce RTX 4090. It might be worth looking into why.
```
Febux Test Time:   0.10123697502422146
Distrax Test Time: 0.08472020699991845
```