certik opened this issue 1 year ago
Curious, I would've expected `jax` to be faster given that it executes asynchronously (which should effectively make the line `out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]` run in parallel, while NumPy would execute it sequentially since each call is eager and blocking).
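One way to get parallelism here without relying on asynchronous dispatch is to stop looping over heads in Python and batch them into a single array operation, so the backend's multithreaded matmul (BLAS or XLA) sees all heads at once. The sketch below assumes a standard scaled dot-product `attention` with an additive causal mask, which is roughly what the list comprehension above appears to call; `softmax`, `per_head`, and `batched` are hypothetical helper names, not code from the repository. It only checks that the per-head loop and the batched version agree; it makes no claim about which one is faster on a given machine.

```python
import numpy as np


def softmax(x):
    # Subtract the row max for numerical stability.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def attention(q, k, v, mask):
    # Assumed scaled dot-product attention for ONE head: inputs are [n_seq, d_head].
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v


def per_head(qkv_heads, mask):
    # Hypothetical helper: one eager call per head, mirroring the list comprehension.
    return [attention(q, k, v, mask) for q, k, v in zip(*qkv_heads)]


def batched(q, k, v, mask):
    # Hypothetical helper: q, k, v are [n_head, n_seq, d_head]; one batched matmul for all heads.
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    return softmax(scores + mask) @ v  # mask [n_seq, n_seq] broadcasts over heads


n_head, n_seq, d_head = 4, 5, 8
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((n_head, n_seq, d_head)) for _ in range(3))
causal_mask = (1 - np.tri(n_seq)) * -1e10  # large negative value above the diagonal

out_loop = np.stack(per_head((list(q), list(k), list(v)), causal_mask))
out_batched = batched(q, k, v, causal_mask)
print(np.allclose(out_loop, out_batched))
```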
Not sure how `jax` handles multiple CPUs. I know you can manually expose multiple CPU devices with the environment variable `export XLA_FLAGS="--xla_force_host_platform_device_count=8"`, but that didn't yield a speedup for me.
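As far as I understand, `--xla_force_host_platform_device_count` only splits the host into several virtual CPU devices; ordinary `jax.numpy` calls still run as a single computation on one of them, so the flag alone would not be expected to speed anything up. Work has to be spread across the devices explicitly, for example with `jax.pmap`. A minimal sketch, assuming the flag is not already set elsewhere and using a toy batched matmul as a stand-in for the per-head attention:

```python
import os

# Set this before importing JAX (before the CPU backend is initialized), or it has no effect.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
import numpy as np

cpus = jax.devices("cpu")
print(len(cpus))  # number of virtual CPU devices

# Toy stand-in for per-head work: one 64x64 matrix per device.
x = jnp.ones((len(cpus), 64, 64))

# A plain jnp call: a single computation on one device.
single = jnp.einsum("bij,bjk->bik", x, x)

# pmap maps the leading axis over the given devices, one slice per device.
parallel = jax.pmap(lambda a: a @ a, devices=cpus)(x)

# Compare on the host to avoid mixing arrays that live on different devices.
print(np.allclose(np.asarray(single), np.asarray(parallel)))
```

Whether this is actually faster than the single-device version depends on the workload size; for small per-head matrices the overhead may dominate.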
Relevant link: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
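Because JAX dispatches asynchronously, a timer stopped right after a call may measure only the dispatch, not the computation. The JAX documentation recommends calling `block_until_ready()` on the result, and the first call to a `jax.jit` function should be excluded because it includes compilation. A sketch of a fairer NumPy vs JAX timing loop, where the 512x512 matmul is an arbitrary stand-in workload and `best_time` is a hypothetical helper:

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def best_time(fn, repeat=5):
    """Hypothetical helper: fastest wall-clock time (seconds) of fn() over `repeat` runs."""
    times = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)


a_np = np.random.default_rng(0).standard_normal((512, 512)).astype(np.float32)
a_jax = jnp.asarray(a_np)
matmul = jax.jit(lambda a: a @ a)

# Warm-up: the first call to a jitted function includes compilation time.
matmul(a_jax).block_until_ready()

t_np = best_time(lambda: a_np @ a_np)
# block_until_ready() waits for the asynchronously dispatched result.
t_jax = best_time(lambda: matmul(a_jax).block_until_ready())
print(f"NumPy: {t_np * 1e3:.2f} ms, JAX: {t_jax * 1e3:.2f} ms")
```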
With #10, I get the following timings with NumPy on my Apple M1 Max:
And with JAX:
So JAX is slower. According to `htop`, JAX uses roughly 1.3 CPU cores, while NumPy uses almost 6. Is NumPy automatically parallel on macOS?
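A likely explanation: NumPy hands matrix multiplication to whatever BLAS library it was built against (for example OpenBLAS, MKL, or Apple's Accelerate), and these are usually multithreaded, whereas elementwise NumPy operations run on a single thread. The sketch below prints NumPy's build configuration, then times the same matmul in fresh subprocesses with default threading and with BLAS threads capped at 1 through environment variables. Which of the listed variables takes effect depends on the BLAS backend in a given build; the matrix size and repeat count are arbitrary choices.

```python
import os
import subprocess
import sys

import numpy as np

# Shows which BLAS/LAPACK libraries this NumPy build links against.
np.show_config()

SNIPPET = """
import time
import numpy as np
a = np.random.default_rng(0).standard_normal((1024, 1024))
a @ a  # warm-up
start = time.perf_counter()
for _ in range(5):
    a @ a
print(time.perf_counter() - start)
"""

# Thread-count variables read by common BLAS backends at import time.
THREAD_VARS = ["OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", "VECLIB_MAXIMUM_THREADS"]


def run(threads=None):
    # Run SNIPPET in a fresh interpreter so the env vars apply before NumPy is imported.
    env = dict(os.environ)
    if threads is not None:
        for var in THREAD_VARS:
            env[var] = str(threads)
    result = subprocess.run(
        [sys.executable, "-c", SNIPPET], env=env, capture_output=True, text=True, check=True
    )
    return float(result.stdout)


t_default = run()
t_single = run(threads=1)
print(f"default threads: {t_default:.3f} s, 1 thread: {t_single:.3f} s")
```

If the single-thread run is several times slower, the extra cores seen in `htop` most likely come from the BLAS library rather than from NumPy itself.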
Here is my Conda environment: