graphcore-research / pyscf-ipu

PySCF on IPU
https://github.com/graphcore-research/pyscf-ipu#readme
Apache License 2.0

Investigate long compilation time with `eri_primitives` when using vmap #105

Open hatemhelal opened 10 months ago

hatemhelal commented 10 months ago

Reproducer on branch 105-compilation-time-with-vmap

git checkout 105-compilation-time-with-vmap
export JAX_PLATFORMS=cpu,ipu
export TF_POPLAR_FLAGS="--show_progress_bar=true"
pytest -k test/test_integrals_ipu.py::test_water_eri

The compilation progress bar shows "graph construction" and I ended up terminating the test case after waiting 10 minutes.

AlexanderMath commented 10 months ago

This (jax trace time not IPU compile time) is what I was looking at for water sparse_eri.

If you comment out factorial, binomial and some of the gamma functions, it traces much faster. My hunch is that this is all due to `jax.lax.fori_loop` and `jnp.where`.

Wrt factorial: notice that 12! < 2^32 < 13! and 20! < 2^64 < 21!, so uint32 can represent only the 13 values 0! to 12! (and uint64 only the 21 values 0! to 20!). Why not precompute those numbers in numpy and then do fact[i]? A similar trick applies to factorial2. Not sure about binom, since this depends on the input numbers; similarly, not sure about the remaining special functions.

If you need values larger than 20!, storing them as integers will overflow.
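A minimal sketch of the lookup-table idea, assuming uint64 storage (where 20! is the largest factorial that fits). The names `FACTORIAL_TABLE`, `FACTORIAL2_TABLE`, `factorial_lookup` and `factorial2_lookup` are hypothetical and not taken from pyscf-ipu:

```python
import math

import numpy as np

# Assumption: uint64 storage. 20! < 2**64 < 21!, so only 0!..20! fit.
MAX_N = 20

# Built once in plain Python/NumPy, so no loop needs to be traced.
FACTORIAL_TABLE = np.array(
    [math.factorial(n) for n in range(MAX_N + 1)], dtype=np.uint64
)

# Double factorial n!! = n * (n - 2) * (n - 4) * ... (the empty product is 1).
FACTORIAL2_TABLE = np.array(
    [math.prod(range(n, 0, -2)) for n in range(MAX_N + 1)], dtype=np.uint64
)


def factorial_lookup(n):
    """Return n! by indexing the table (n may be an int or an integer array)."""
    return FACTORIAL_TABLE[n]


def factorial2_lookup(n):
    """Return n!! by indexing the table (n may be an int or an integer array)."""
    return FACTORIAL2_TABLE[n]
```

To index with traced values inside a jitted JAX function, the tables would presumably need wrapping with `jnp.asarray` first, and 64-bit integers would need x64 mode enabled. Neither has been tried against this reproducer.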

awf commented 10 months ago

We might also find a re-derivation in terms of Gamma rather than factorial (although I see that binomial may not pop out like that). Is there a sympy derivation anywhere?
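For reference, a Gamma-based formulation would rest on the identities n! = Γ(n+1) and C(n, k) = Γ(n+1) / (Γ(k+1) Γ(n-k+1)). Below is a minimal sketch using only Python's standard `math` module. The function names are hypothetical, and whether the ERI derivation actually reduces to this form is still an open question. In JAX the analogous building block would presumably be `jax.scipy.special.gammaln`.

```python
import math


def factorial_via_gamma(n):
    """n! expressed as Gamma(n + 1); returns a float."""
    return math.gamma(n + 1)


def binomial_via_lgamma(n, k):
    """C(n, k) computed from log-Gamma, so intermediate values stay small."""
    log_binom = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return math.exp(log_binom)
```

The results are floating point, so they are only approximately integral. Rounding is needed wherever an exact integer is expected.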

AlexanderMath commented 9 months ago

Looks like the culprit is here: https://github.com/graphcore-research/pyscf-ipu/blob/023e11487762731117d73d14bc98b2f74226f800/pyscf_ipu/experimental/special.py#L80

TODO: write a vertex that implements this function and see if it works.

awf commented 9 months ago

> (jax trace time not IPU compile time)

I think it is IPU compile time? The long delay happens after this line:

WARNING:absl:Compiling _eri_primitives (140311549046512) for 16 args.
2023-10-03 16:46:03.563604: I external/org_tensorflow/tensorflow/compiler/plugin/poplar/xla_client/pjrt/ipu_pjrt_client.cc:506] IPU PJRT client: compiling device Poplar executable.