graphcore-research / pyscf-ipu

PySCF on IPU
https://github.com/graphcore-research/pyscf-ipu#readme
Apache License 2.0

Excessive padding used in `eri_primitives` #117

Open hatemhelal opened 1 year ago

hatemhelal commented 1 year ago

The padding introduced to the c_term here:

https://github.com/graphcore-research/pyscf-ipu/blob/023e11487762731117d73d14bc98b2f74226f800/pyscf_ipu/experimental/integrals.py#L215

could instead be derived from the input primitives, as long as we are careful to only use vmap over primitives of the same total angular momentum (e.g. by evaluating the ERIs shell by shell).
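A minimal sketch of the grouping step, in plain Python and assuming (as suggested above) that the extent of the `c_term` factors depends only on total angular momentum. `Primitive` and `group_by_total_l` are hypothetical names for illustration, not pyscf-ipu classes or functions, and the exponents are placeholders.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Primitive:
    """Hypothetical stand-in for a Cartesian Gaussian primitive (not the pyscf-ipu class)."""
    lmn: Tuple[int, int, int]  # Cartesian angular momentum exponents (l, m, n)
    alpha: float               # Gaussian exponent

    @property
    def total_l(self) -> int:
        return sum(self.lmn)


def group_by_total_l(prims: List[Primitive]) -> Dict[int, List[Primitive]]:
    """Bucket primitives by total angular momentum L = l + m + n.

    Within one bucket every primitive has the same L, so arrays whose extent
    depends only on L have one common shape and the bucket can be stacked and
    passed to jax.vmap without padding to LMAX.
    """
    groups: Dict[int, List[Primitive]] = defaultdict(list)
    for p in prims:
        groups[p.total_l].append(p)
    return dict(groups)


# Illustrative input: two s-type (L=0) and two p-type (L=1) primitives.
prims = [
    Primitive((0, 0, 0), 1.0),
    Primitive((1, 0, 0), 2.0),
    Primitive((0, 1, 0), 2.0),
    Primitive((0, 0, 0), 0.5),
]
groups = group_by_total_l(prims)
print({l: len(ps) for l, ps in groups.items()})  # {0: 2, 1: 2}
```

For the ERIs themselves the same idea would apply to primitive quartets, keyed by the tuple (La, Lb, Lc, Ld) rather than a single L.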

This issue tracks investigating the removal of the padding used within `eri_primitives`. Note that a similar pattern is used in the evaluation of the nuclear attraction integrals, which could also be improved.

AlexanderMath commented 1 year ago

Note: for the water STO-3G test case the maximum angular momentum is LMAX = 1, but we pad to LMAX = 4. This means the broadcast Ci Cj Ck has (4*4+1)^3 = 4913 entries instead of (1+1)^3 = 8.

As we discussed, I think there's a way to circumvent padding to L_MAX in Jax without resorting to C++.

Problem: different primitives have different L (in our case L=0 for hydrogen and L=1 for oxygen). The resulting Ci, Cj, Ck each have shape 1 for L=0 and 3 for L=1. The output of the broadcast Ci Cj Ck can then take shapes (1,1,1), (1,1,3), (1,3,1), (3,1,1), (3,3,1), (3,1,3), ..., (3,3,3).
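To make the combinatorics concrete, this plain-Python sketch enumerates every shape the broadcast Ci Cj Ck can take when each factor has extent 1 (L=0) or 3 (L=1), and counts how many shapes share each total size. `EXTENT` and `broadcast_shapes` are illustrative names, not part of pyscf-ipu.

```python
import itertools
import math
from collections import Counter
from typing import Dict, List, Sequence, Tuple

# Extent of each factor Ci, Cj, Ck per angular momentum, as described above:
# 1 for L=0 (hydrogen) and 3 for L=1 (oxygen).
EXTENT: Dict[int, int] = {0: 1, 1: 3}


def broadcast_shapes(ls: Sequence[int] = (0, 1), n_factors: int = 3) -> List[Tuple[int, ...]]:
    """Every shape the broadcast product Ci Cj Ck can take."""
    return [tuple(EXTENT[l] for l in combo) for combo in itertools.product(ls, repeat=n_factors)]


shapes = broadcast_shapes()
print(len(shapes))  # 8
print(shapes[:4])   # [(1, 1, 1), (1, 1, 3), (1, 3, 1), (1, 3, 3)]

# Many shapes share the same number of elements.
sizes = Counter(math.prod(s) for s in shapes)
print(sorted(sizes.items()))  # [(1, 1), (3, 3), (9, 3), (27, 1)]
```

Only four distinct sizes (1, 3, 9, 27) occur across the eight shapes, so grouping by size needs fewer batches than grouping by exact shape.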

Current solution: pad everything to L=4. This works, but increases memory, compute and possibly trace time by a factor of several hundred.

Other solution: batch together calls whose outputs have the same size. For example, do the (1,1,1) calls together, then the (1,1,3), (1,3,1) and (3,1,1) calls together (each has 3 elements), and so on. For inspiration, this is done here in ~50 lines of Jax. The (count, size) pairs look like [(13271, 1), (32711, 3), (57121, 9), ...], which correspond to the cases (1,1,1,1), then (1,1,3,1) and (1,3,1,1), and so on.
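A minimal sketch of batching by size in plain Python, assuming each call's work can be expressed on its flattened Ci Cj Ck array so that calls of equal size can share one batched evaluation. `batch_by_size`, `counts_and_sizes`, `run_batched` and `batched_fn` are hypothetical names, not pyscf-ipu functions; in Jax, `batched_fn` would be a `jax.vmap`-ed kernel applied to a `jnp.stack` of each group.

```python
from collections import defaultdict
from math import prod
from typing import Callable, Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def batch_by_size(shapes: Sequence[Shape]) -> Dict[int, List[int]]:
    """Map each flattened size to the indices of the calls that produce it."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for idx, shape in enumerate(shapes):
        groups[prod(shape)].append(idx)
    return dict(groups)


def counts_and_sizes(shapes: Sequence[Shape]) -> List[Tuple[int, int]]:
    """(count, size) pairs ordered by size, in the same form as the list quoted above."""
    groups = batch_by_size(shapes)
    return sorted(((len(ix), size) for size, ix in groups.items()), key=lambda cs: cs[1])


def run_batched(
    inputs: Sequence[List[float]],
    shapes: Sequence[Shape],
    batched_fn: Callable[[List[List[float]]], List[float]],
) -> List[float]:
    """Call batched_fn once per size group and scatter results back into call order.

    batched_fn receives a list of equally sized flat inputs; it stands in for
    jax.vmap(kernel) applied to a stacked (count, size) array.
    """
    out: List[float] = [0.0] * len(inputs)
    for size, idx in batch_by_size(shapes).items():
        batch = [inputs[i] for i in idx]
        assert all(len(x) == size for x in batch), "input length must match prod(shape)"
        for i, result in zip(idx, batched_fn(batch)):
            out[i] = result
    return out


shapes = [(1, 1, 1, 1), (1, 1, 3, 1), (1, 3, 1, 1), (3, 3, 1, 1), (1, 1, 1, 1)]
print(counts_and_sizes(shapes))  # [(2, 1), (2, 3), (1, 9)]

inputs = [[1.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [1.0] * 9, [2.0]]
print(run_batched(inputs, shapes, lambda batch: [sum(x) for x in batch]))
# [1.0, 6.0, 15.0, 9.0, 2.0]
```

Because `jax.jit` traces once per distinct input shape, this approach would compile one kernel per size group instead of padding every call up to the LMAX=4 size.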

@awf Happy to clarify in person. TLDR: Looks like for this case we should be able to get performant Jax code.