Open ywx649999311 opened 1 month ago
hmmm. Are you sure it's the tinygp version, and not the JAX version that makes this difference? Can you try to bisect the version combination where this goes wrong?
I think it starts to go wrong with:
I tried a collection of different JAX versions, all the way back to 0.3.3, and none of them had a major impact on the benchmark. This makes me wonder whether the problem could be caused by the jax.debug call introduced in tinygp 0.2.3.
The JAX version has only a tiny impact on performance, a factor of <2 at N < 50, and the impact starts at about jax==0.4.2.
Interesting - thanks for tracking that down! Can you try updating the benchmark you're running to use assume_sorted=True? It should be sufficient to pass that as a keyword argument when building the GaussianProcess object with the quasisep kernel. If that is set, the debug check should never be executed. It would be a little surprising if that check were having such a large effect, but maybe it is. Let's get that figured out!
I think it is more complicated than what I thought earlier. Setting assume_sorted=True helped, but it is not the cause of the overall factor-of-ten slowdown. Below are the benchmarks for two version combinations.
jax==0.4.31 and tinygp==0.3.0:
jax==0.4.33 and tinygp==0.3.0:
You were right, the JAX version has a much bigger effect! Any thoughts?
P.S. Sorry for the misleading information above. I tried to make comparisons starting from the earliest version of tinygp (0.2.2) that I had tested in the past and worked my way up, but I had to stop at jax==0.4.28 because JAX dropped support for the Python version (3.9) that I was using for the tests.
Thanks for tracking this down! My next guess is that this has to do with a change in the CPU runtime behavior that was introduced in JAX v0.4.32. Can you try setting the following environment variable and running the benchmark again:
XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
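One way to set this, as a minimal sketch in Python rather than in the shell, is to modify the environment before JAX is imported:

```python
import os

# Append the flag to any existing XLA_FLAGS so other settings survive.
# This must happen before `import jax`: XLA parses the flags at startup,
# so setting them after JAX has been imported has no effect.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_cpu_use_thunk_runtime=false"
).strip()

print(os.environ["XLA_FLAGS"])
```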
If that is the source of the issue, I don't know exactly how to solve it in the long term, but at least then we can start digging in. Thanks again!
Setting the flag you recommended does not solve the issue by itself. I also have to keep the intra_op_parallelism_threads flag disabled to get the benchmark close to what it was before.
Here are my flag settings:
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_cpu_multi_thread_eigen=false"
    # + " --intra_op_parallelism_threads=1"
    + " --xla_cpu_use_thunk_runtime=false"
)
Here is the resulting benchmark:
I think you must have set the intra_op_parallelism_threads flag for a reason, and I am not sure it is reasonable to disable it. Perhaps this is not an apples-to-apples comparison.
Hi @dfm,
I noticed that there is a factor-of-ten increase in the runtime of the quasisep kernels in the newest version of tinygp. So I re-ran the benchmark notebook from your documentation to check, and indeed the quasisep kernel is about 10 times slower than celerite2. I am attaching the reproduced benchmark plot. Note that I didn't re-run the benchmark on GPU. Any idea what might be causing it?
I ran your notebook on two platforms (an M2 Mac and Google Colab), and the results are basically consistent. The plot shown was produced using the results obtained with:
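For context, the timing pattern such a benchmark typically relies on can be sketched in plain Python; the workload here is a stand-in, not the actual tinygp or celerite2 call:

```python
import time

def bench(fn, *args, repeats=20):
    """Best wall-clock time in seconds over several repeats."""
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best

# Stand-in workload; in the real benchmark this would be the (jitted)
# log-likelihood call. For JAX, run the function once beforehand and call
# .block_until_ready() on the result so compilation time and async
# dispatch are excluded from the measurement.
workload = lambda n: sum(i * i for i in range(n))
print(f"best of 20: {bench(workload, 10_000):.2e} s")
```

Taking the best of several repeats reduces noise from other processes, which matters when comparing runs across library versions.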
Thanks for your attention!