Fixes the FutureWarning reported in https://github.com/Sm00thix/IKPLS/issues/24#issuecomment-2194939864 (related to #24).
The culprit was that argnum 1 (argname `i`) was accidentally declared static to JAX's JIT compiler, even though `i` is not always static.
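To illustrate the difference, here is a minimal sketch of a static versus a traced index argument. The function `scale_column` and its parameters are hypothetical stand-ins, not code from IKPLS; they only model a jitted function whose argnum 1 is an index `i` that is sometimes a tracer (here, the counter of `jax.lax.fori_loop`).

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in (not from IKPLS): scale column i of X by `factor`.
def scale_column(X, i, factor):
    # X[:, i] and X.at[:, i] both accept a traced integer i,
    # so i does not need to be a compile-time constant.
    return X.at[:, i].set(X[:, i] * factor)


# i left dynamic: one compiled version serves every value of i, and i may
# itself be a traced value, e.g. the loop counter of jax.lax.fori_loop.
scale_column_jit = jax.jit(scale_column)

# i marked static (argnum 1): only valid for concrete, hashable Python values,
# and each distinct i triggers a recompilation. Passing a traced i here is the
# kind of mistake to avoid; it may warn or fail depending on the JAX version.
scale_column_static = jax.jit(scale_column, static_argnums=(1,))


def body(i, X):
    # Inside fori_loop, i is a tracer, so only the non-static version fits.
    return scale_column_jit(X, i, 2.0)


X = jnp.arange(6.0).reshape(2, 3)
out = jax.lax.fori_loop(0, 3, body, X)  # every column scaled by 2 once
fixed = scale_column_static(X, 1, 3.0)  # fine: i is a concrete Python int
```

Leaving `i` dynamic costs nothing in correctness here, and since static arguments become part of JAX's compilation cache key, it also avoids one compilation per distinct index value.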
After the fix, I ran a few benchmarks with the different JAX implementations to check for changes in runtime. The runtimes appear unaffected compared to the original benchmarks.