hatemhelal opened 10 months ago
This (jax trace time, not IPU compile time) is what I was looking at for water `sparse_eri`. If you comment out `factorial`, `binomial`, and some of the gamma functions, it traces much faster. My hunch is this is all due to `jax.lax.fori_loop` and `jnp.where`.
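For reference, a minimal sketch of the `fori_loop` + `where` style of factorial being described (the name `factorial_loop` and the loop bound are made up, this is not the actual pyscf-ipu implementation):

```python
import jax
import jax.numpy as jnp

def factorial_loop(n, max_n=32):
    """Hypothetical sketch of the fori_loop + where pattern suspected above."""
    def body(i, acc):
        # Multiply acc by i only while i <= n; otherwise pass acc through.
        return jnp.where(i <= n, acc * i, acc)
    # Always runs max_n iterations regardless of n, selecting with where.
    return jax.lax.fori_loop(1, max_n + 1, body, jnp.int32(1))
```

Every call pays for the full fixed-length loop plus a select per iteration, which is the kind of structure that can blow up trace/compile time once it is nested under `vmap`.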
Wrt factorial: note that 12! < 2^32 and 20! < 2^64, so you can represent at most 13 factorial values (0! through 12!) in uint32, and 21 values (0! through 20!) in uint64. Why not precompute those values in numpy and then do `fact[i]`? A similar trick applies to `factorial2`. Not sure about `binom`, as that depends on the input numbers; similarly, not sure about the remaining special functions.
If you need values larger than 20!, storing them as an integer will overflow.
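A minimal sketch of the lookup-table idea (the names `FACT` and `fact` are made up; x64 must be enabled so jax does not silently downcast the uint64 table):

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # keep uint64 precision in jax

# Precompute the table once in plain numpy: 0!..12! fit in uint32,
# 0!..20! fit in uint64 (21! overflows 2**64).
FACT = np.concatenate([[1], np.cumprod(np.arange(1, 21))]).astype(np.uint64)

def fact(i):
    # A constant-time gather replaces the traced loop.
    return jnp.asarray(FACT)[i]
```

The table is built at trace time in numpy, so the only thing that ends up in the jax graph is a single gather.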
We might also find a re-derivation in terms of Gamma rather than factorial (although I see that binomial may not pop out like that). Is there a sympy derivation anywhere?
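On the Gamma idea: the binomial coefficient does have a standard expression via the log-Gamma function, C(n, k) = exp(lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1)). A hedged sketch (`binom_gamma` is a made-up name, not an existing pyscf-ipu function):

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def binom_gamma(n, k):
    # C(n, k) = Gamma(n+1) / (Gamma(k+1) * Gamma(n-k+1));
    # evaluating in log-space avoids overflow, and rounding recovers
    # the exact integer for moderate n.
    return jnp.round(jnp.exp(gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0)))
```

This trades the loop for three `gammaln` evaluations, at the cost of floating-point rounding for large arguments.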
Looks like the culprit is here: https://github.com/graphcore-research/pyscf-ipu/blob/023e11487762731117d73d14bc98b2f74226f800/pyscf_ipu/experimental/special.py#L80
TODO: write vertex which implements this function and see if it works.
> (jax trace time not IPU compile time)

I think it is IPU compile time? The long delay happens after this line:

```
WARNING:absl:Compiling _eri_primitives (140311549046512) for 16 args.
2023-10-03 16:46:03.563604: I external/org_tensorflow/tensorflow/compiler/plugin/poplar/xla_client/pjrt/ipu_pjrt_client.cc:506] IPU PJRT client: compiling device Poplar executable.
```
Reproducer on branch `105-compilation-time-with-vmap`.
The compilation progress bar shows "graph construction", and I ended up terminating the test case after waiting 10 minutes.