JiaYaobo / fenbux

A Simple Statistical Distribution Library in JAX
https://jiayaobo.github.io/fenbux/
Apache License 2.0

Speed vs Distrax #7

Open adam-hartshorne opened 8 months ago

adam-hartshorne commented 8 months ago

I have run the following minimal example against Distrax (https://github.com/google-deepmind/distrax), and your library doesn't seem to be as fast. I am running this with jax 0.4.23, CUDA 12.2, and Python 3.10 on a GeForce RTX 4090.

It might be worth looking into why.

import timeit

def setup_code():
    return '''
from jax import jit
from jax import random as jr
from fenbux import logpdf
from fenbux.univariate import Normal
import distrax

key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)

# Note: jr.normal can produce negative sd values, which makes the
# log-densities NaN; this does not affect the timing comparison.
mean = jr.normal(x_key, (1000000, 2))
sd = jr.normal(y_key, (1000000, 2))
y_k = jr.normal(z_key, (1000000, 2))

def fenbux_test(mean, sd, y_k):
    return logpdf(Normal(mean=mean, sd=sd), y_k).sum()

def distrax_test(mean, sd, y_k):
    return distrax.Normal(loc=mean, scale=sd).log_prob(y_k).sum()

jit_fenbux_test = jit(fenbux_test)
jit_distrax_test = jit(distrax_test)
'''

# Timing fenbux_test
fenbux_time = timeit.timeit('jit_fenbux_test(mean, sd, y_k).block_until_ready()',
                            setup=setup_code(), number=1000)

# Timing distrax_test
distrax_time = timeit.timeit('jit_distrax_test(mean, sd, y_k).block_until_ready()',
                             setup=setup_code(), number=1000)

print("Fenbux Test Time:", fenbux_time)
print("Distrax Test Time:", distrax_time)

Fenbux Test Time: 0.10123697502422146
Distrax Test Time: 0.08472020699991845
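One caveat with this harness: the setup string never calls the jitted functions, so the first of the 1000 timed iterations also pays XLA compilation cost. A minimal sketch of a warm-up, reusing the names above, that could be appended to the setup string so the measurement covers only steady-state execution:

# Compile once outside the timed region.
jit_fenbux_test(mean, sd, y_k).block_until_ready()
jit_distrax_test(mean, sd, y_k).block_until_ready()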

JiaYaobo commented 8 months ago

Hi @adam-hartshorne, thanks for your report. I profiled a little on CPU (M1 Pro chip) and got 3.54 ms (fenbux) vs 3.57 ms (distrax), and then I dug into the jaxpr of your code:

from jax import make_jaxpr

print(make_jaxpr(jit_fenbux_test)(mean, sd, y_k))
print(make_jaxpr(jit_distrax_test)(mean, sd, y_k))
[Screenshots, 2024-01-17: the jaxprs of jit_fenbux_test and jit_distrax_test]

I think the problem is that the implementation of normal_logpdf in the dist_math.normal module is not well optimized and results in a more complex jaxpr, so I adjusted the normal_logpdf code to match distrax and profiled the program on a Titan RTX as below:

from jax import jit, make_jaxpr
from jax import random as jr
from fenbux import logpdf
from fenbux.univariate import Normal
import distrax

key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)

mean = jr.normal(x_key, (1000000, 2))
sd = jr.normal(y_key, (1000000, 2))
y_k = jr.normal(z_key, (1000000, 2))

def fenbux_test(mean, sd, y_k):
    return logpdf(Normal(mean=mean, sd=sd), y_k).sum()

def distrax_test(mean, sd, y_k):
    return distrax.Normal(loc=mean, scale=sd).log_prob(y_k).sum()

jit_fenbux_test = jit(fenbux_test)
jit_distrax_test = jit(distrax_test)

# IPython magics; -r 10 repeats each measurement 10 times.
%timeit -r 10 jit_fenbux_test(mean, sd, y_k).block_until_ready()
%timeit -r 10 jit_distrax_test(mean, sd, y_k).block_until_ready()

And the results are:

[Screenshot, 2024-01-17: %timeit results on the Titan RTX]

And keep in mind, fenbux treats its distribution parameters as pytrees, so in fenbux's jaxpr you can see an extra function, tree_map_dist_at, being called every time. Apart from this extra function, fenbux's jaxpr now matches distrax's, as shown below (a small snippet for checking the pytree behavior follows the screenshots):

[Screenshot, 2024-01-17: fenbux jaxpr after the fix]

and

[Screenshot, 2024-01-17: distrax jaxpr]
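To see the pytree behavior directly, something like the following should work; this is a sketch that only assumes fenbux distributions are registered as JAX pytrees, which the note above implies:

import jax
from fenbux.univariate import Normal

# If the distribution is a pytree, its parameters show up as leaves;
# this is what lets fenbux tree-map logpdf over the parameters.
dist = Normal(mean=0.0, sd=1.0)
print(jax.tree_util.tree_leaves(dist))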

I'll open a PR to modify the normal_logpdf function soon.

And, about comparing speed with other libraries such as tensorflow-probability or distrax: fenbux is expected to always be faster if you simply jit the methods of the distributions, like jit(dist.log_prob) (as I wrote in the README). And if you compare them wrapped in a function exactly like what you did, with plain array/tensor inputs as parameters, fenbux will exactly match the speed of these libraries under the same jaxpr-level optimization! A sketch of the jit-the-method pattern is below.
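For reference, a minimal sketch of that pattern; the functools.partial wrapper for fenbux is my illustration (fenbux's own API is logpdf(dist, x), not a bound method):

import functools
from jax import jit
from jax import random as jr
import distrax
from fenbux import logpdf
from fenbux.univariate import Normal

key = jr.PRNGKey(0)
x_key, y_key, z_key = jr.split(key, 3)
mean = jr.normal(x_key, (1000000, 2))
sd = jr.normal(y_key, (1000000, 2))
y_k = jr.normal(z_key, (1000000, 2))

# Build the distributions once, then jit only the evaluation.
fb_dist = Normal(mean=mean, sd=sd)
dx_dist = distrax.Normal(loc=mean, scale=sd)

jit_fb_logpdf = jit(functools.partial(logpdf, fb_dist))  # fenbux: logpdf(dist, x)
jit_dx_logprob = jit(dx_dist.log_prob)                   # distrax: bound method

jit_fb_logpdf(y_k).block_until_ready()   # warm-up / compile
jit_dx_logprob(y_k).block_until_ready()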

Finally, thanks again for reminding me that some functions are not optimized enough to keep the jaxpr as simple as possible. I'll dig into these in the next version of fenbux and profile performance on GPU. Does that make sense?

JiaYaobo commented 8 months ago

https://github.com/JiaYaobo/fenbux/pull/8 optimizes the implementation of normal_logpdf.
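For reference, a direct-formula normal log-pdf in JAX looks roughly like this; this is a generic sketch of the technique, not necessarily the exact code in the PR:

import jax.numpy as jnp

def normal_logpdf(x, mean, sd):
    # log N(x; mean, sd) = -0.5 * ((x - mean) / sd) ** 2 - log(sd) - 0.5 * log(2 * pi)
    z = (x - mean) / sd
    return -0.5 * z ** 2 - jnp.log(sd) - 0.5 * jnp.log(2 * jnp.pi)

Computing the density in this fused form keeps the jaxpr to a handful of elementwise primitives, which matches what distrax's log_prob lowers to, per the screenshots above.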