graphcore-research / pyscf-ipu

PySCF on IPU
https://github.com/graphcore-research/pyscf-ipu#readme
Apache License 2.0

Compute "ERI redundancies" on IPU #63

Status: Open. AlexanderMath opened this issue 1 year ago.

AlexanderMath commented 1 year ago

#55 outlines several redundancies in the electron repulsion integrals (ERI). We want to determine, on the IPU, which integrals are above a threshold (and therefore need to be computed).

Tasks

  1. Write Python for-loop code that runs in O(N^2), following this post, and computes the right thing (ideally 10 to 20 lines of code in a single file, to minimize complexity).
  2. Port that code to JAX.
  3. Change the backend from CPU to IPU and, if needed, optimize the memory layout.
AlexanderMath commented 1 year ago

@awf The following code demonstrates the Cauchy-Schwarz inequality `|(ab|cd)| <= sqrt(|(ab|ab)|) * sqrt(|(cd|cd)|)` that the screening relies on.

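```python
import numpy as np
import pyscf

# Two carbon atoms (coordinates in Angstrom, PySCF's default unit)
mol = pyscf.gto.Mole(atom=[["C", (0, 0, 0)], ["C", (1, 2, 3)]])
mol.build()
ERI = mol.intor("int2e_sph")  # full (N, N, N, N) tensor of (ab|cd)
N = mol.nao_nr()

# Check |(ab|cd)| <= sqrt|(ab|ab)| * sqrt|(cd|cd)| for every quartet
for a in range(N):
    for b in range(N):
        for c in range(N):
            for d in range(N):
                abcd = np.abs(ERI[a, b, c, d])
                sqrt_abab = np.sqrt(np.abs(ERI[a, b, a, b]))
                sqrt_cdcd = np.sqrt(np.abs(ERI[c, d, c, d]))

                print(abcd, sqrt_abab * sqrt_cdcd)
                assert abcd <= sqrt_abab * sqrt_cdcd + 1e-9  # 1e-9 absolute tolerance
```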
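The quadruple loop visits and prints all N^4 quartets one at a time, which gets slow quickly. Below is a minimal vectorized sketch of the same check with NumPy, assuming the full `(N, N, N, N)` tensor fits in memory. `schwarz_holds` is a hypothetical helper name. Since no molecule is built here, the demo uses a synthetic Gram-matrix tensor instead of a real ERI.

```python
import numpy as np


def schwarz_holds(eri, atol=1e-9):
    """Check |(ab|cd)| <= sqrt|(ab|ab)| * sqrt|(cd|cd)| + atol for all quartets at once.

    `eri` is a full (N, N, N, N) array such as mol.intor("int2e_sph").
    Hypothetical helper, not part of PySCF or pyscf-ipu.
    """
    n = eri.shape[0]
    # View eri as an (N*N, N*N) matrix over pairs; its diagonal is (ab|ab)
    diag = eri.reshape(n * n, n * n).diagonal().reshape(n, n)
    q = np.sqrt(np.abs(diag))  # Schwarz factors Q_ab
    # Outer product Q_ab * Q_cd, broadcast to shape (N, N, N, N)
    bound = q[:, :, None, None] * q[None, None, :, :]
    return bool(np.all(np.abs(eri) <= bound + atol))


# Synthetic stand-in for a real ERI (assumption for the demo only):
# (ab|cd) = sum_k B[k,a,b] * B[k,c,d] is a Gram matrix over pairs,
# so the Cauchy-Schwarz bound holds by construction.
rng = np.random.default_rng(0)
B = rng.standard_normal((6, 5, 5))
eri = np.einsum("kab,kcd->abcd", B, B)
print(schwarz_holds(eri))
```

With a real molecule, `assert schwarz_holds(ERI)` would take the place of the loop. The trade-off is memory: `bound` is a second N^4 array.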

Note: computing the N^2 diagonal entries (ab|ab) at compile time will take under 1 ms using int2e_sph_cpu.cpp (if we add back the PRAGMA_OMP pragmas). This might be useful in certain scenarios.

Note: the IPU code int2e_sph.cpp can run DFT for a fixed N without recompiling. I think we will be able to compute the top 2% of integrals without recompiling. That is, we compile one graph once, and it can then handle any sparsity pattern with at most 2% nonzeros (spending FLOPs as if exactly 2% were nonzero).
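One way to read the "compile once, any pattern with at most 2% nonzeros" idea is to give the kernel a fixed-size index buffer plus a mask, so array shapes never change between sparsity patterns. Here is a minimal NumPy sketch under that assumption. `pad_to_capacity` and the 2% capacity rule are illustrative only, not the pyscf-ipu API.

```python
import numpy as np


def pad_to_capacity(indices, capacity):
    """Pad a (k, 4) array of (a, b, c, d) quartets to a fixed (capacity, 4) buffer.

    Hypothetical helper. Returns the padded buffer and a boolean mask marking the
    k real entries. Because the output shape depends only on `capacity`, a graph
    compiled for that shape could be reused for any pattern with k <= capacity;
    padded slots point at quartet (0, 0, 0, 0) and must be masked out.
    """
    indices = np.asarray(indices, dtype=np.int32).reshape(-1, 4)
    k = indices.shape[0]
    if k > capacity:
        raise ValueError(f"{k} quartets exceed the fixed capacity {capacity}")
    padded = np.zeros((capacity, 4), dtype=np.int32)
    padded[:k] = indices
    mask = np.zeros(capacity, dtype=bool)
    mask[:k] = True
    return padded, mask


N = 10
capacity = 2 * N**4 // 100  # 2% of the N^4 quartets, in integer arithmetic
padded, mask = pad_to_capacity([(0, 0, 0, 0), (1, 0, 1, 0), (2, 1, 0, 0)], capacity)
print(padded.shape, int(mask.sum()))  # (200, 4) 3
```

Downstream, the kernel would evaluate all `capacity` slots and multiply the results by `mask`. That is the "spending flops as if nonzero=2%" trade-off: the cost is fixed, but no recompilation is needed.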

mihaipgc commented 1 year ago

The O(N^2) code from the post looks like this:

```python
import numpy as np
import pyscf

mol = pyscf.gto.Mole(atom=[["C", (0, 0, 0)], ["C", (10, 2, 3)]])
mol.build()
ERI = mol.intor("int2e_sph", aosym="s1")  # full (N, N, N, N) tensor
N = mol.nao_nr()

tolerance = 1e-9

# Ground truth: every quartet with |(ab|cd)| >= tolerance
ERI[np.abs(ERI) < tolerance] = 0
true_nonzero_indices = np.nonzero(ERI.reshape(-1))[0]
true_nonzero_indices_4d = [np.unravel_index(c, (N, N, N, N)) for c in true_nonzero_indices]

screened_indices_4d = []

# Find the largest diagonal integral I_max = max_ab |(ab|ab)|
I_max = 0
for a in range(N):
    for b in range(N):
        abab = np.abs(ERI[a, b, a, b])
        if abab > I_max:
            I_max = abab

# Collect candidate pairs for s1.
# Note: the Schwarz bound |(ab|cd)| <= sqrt((ab|ab) * I_max) would test
# sqrt(abab * I_max) >= tolerance, which keeps more pairs when tolerance < 1.
considered_indices = []
for a in range(N):
    for b in range(N):
        abab = np.abs(ERI[a, b, a, b])
        if abab * I_max >= tolerance:
            considered_indices.append((a, b))

# Generate s1 indices: every combination of two candidate pairs
for a, b in considered_indices:
    for c, d in considered_indices:
        screened_indices_4d.append((a, b, c, d))

print('N', N)
print('I_max', I_max)
print('ERI.reshape(-1).shape', ERI.reshape(-1).shape)
print('len(considered_indices)', len(considered_indices))
print('len(screened_indices_4d)', len(screened_indices_4d))
print('len(true_nonzero_indices_4d)', len(true_nonzero_indices_4d))

screened_set = set(screened_indices_4d)  # O(1) membership tests
check_s1 = [(item in screened_set) for item in true_nonzero_indices_4d]
assert np.array(check_s1).all()
print('PASSED [(item in screened_indices_4d) for item in true_nonzero_indices_4d]!')
```

Output:

```
N 10
I_max 3.5419481332225047
ERI.reshape(-1).shape (10000,)
len(considered_indices) 50
len(screened_indices_4d) 2500
len(true_nonzero_indices_4d) 1468
PASSED [(item in screened_indices_4d) for item in true_nonzero_indices_4d]!
```

Note: this version and the one above may not PASS for every atom configuration, but they will likely be close enough. This also means the comparison between "screened" and "true" should not be all-or-nothing. A better test might be to compute the error against the real ERI (will do).
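Here is a sketch of that error-based test. It uses only NumPy array operations, so it should translate fairly directly to `jax.numpy` for task 2 above, though that is untested here. Two choices differ from the code above and are assumptions of this sketch. First, a pair is kept when `sqrt(|(ab|ab)| * I_max) >= tolerance`, the form that follows directly from the Schwarz bound. Second, the test reports the largest discarded `|(ab|cd)|` instead of a pass/fail membership check. `schwarz_pair_mask` and `screening_error` are hypothetical names, and the demo uses a synthetic tensor instead of a PySCF molecule.

```python
import numpy as np


def schwarz_pair_mask(eri, tol):
    """Boolean (N, N) mask of pairs (a, b) kept by Schwarz screening (hypothetical helper).

    Since |(ab|cd)| <= sqrt((ab|ab)) * sqrt((cd|cd)) <= sqrt((ab|ab) * I_max),
    a pair with sqrt(|(ab|ab)| * I_max) < tol only appears in integrals below tol.
    """
    n = eri.shape[0]
    diag = np.abs(eri.reshape(n * n, n * n).diagonal().reshape(n, n))  # |(ab|ab)|
    i_max = diag.max()
    return np.sqrt(diag * i_max) >= tol


def screening_error(eri, pair_mask):
    """Largest |(ab|cd)| among quartets dropped because (a, b) or (c, d) was screened out."""
    keep = pair_mask[:, :, None, None] & pair_mask[None, None, :, :]
    return float(np.max(np.abs(np.where(keep, 0.0, eri))))


# Synthetic ERI-like Gram tensor (demo assumption): pairs (0, b) nearly vanish
rng = np.random.default_rng(1)
B = rng.standard_normal((5, 6, 6))
B[:, 0, :] *= 1e-12
eri = np.einsum("kab,kcd->abcd", B, B)
mask = schwarz_pair_mask(eri, 1e-9)
print(int(mask.sum()), screening_error(eri, mask))
```

With a real molecule, one would call `screening_error(ERI, schwarz_pair_mask(ERI, tolerance))` on the unthresholded output of `mol.intor("int2e_sph", aosym="s1")`. A result below `tolerance` means the screening discarded nothing that matters at that threshold.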

mihaipgc commented 1 year ago

But it turns out we can do better if we build the ERI permutation symmetries into the screening strategy above:
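For reference, here are the symmetries being exploited, stated for real basis functions (as with the spherical GTOs behind `int2e_sph`), together with the packed indexing that the `get_i_j`/`c2ijkl` decoding below assumes. Each pair is stored once with its larger index first, and only quartets with `ij >= kl` are kept:

$$
(ij|kl) = (ji|kl) = (ij|lk) = (ji|lk) = (kl|ij) = (lk|ij) = (kl|ji) = (lk|ji)
$$

$$
\mathrm{ij} = \frac{i(i+1)}{2} + j \quad (i \ge j), \qquad
c = \frac{\mathrm{ij}(\mathrm{ij}+1)}{2} + \mathrm{kl} \quad (\mathrm{ij} \ge \mathrm{kl})
$$

$$
i = \left\lfloor \frac{\sqrt{8\,\mathrm{ij} + 1} - 1}{2} \right\rfloor, \qquad j = \mathrm{ij} - \frac{i(i+1)}{2}
$$

$$
N_\text{pair} = \frac{N(N+1)}{2}, \qquad N_{s8} = \frac{N_\text{pair}(N_\text{pair}+1)}{2}, \qquad N = 10:\; N_\text{pair} = 55,\; N_{s8} = 1540
$$

The last line matches the `ERI_s8.shape (1540,)` reported in the output below.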

```python
import numpy as np
import pyscf


def get_i_j(val, xnp=np, dtype=np.uint64):
    # Decode a packed lower-triangular index val = i*(i+1)/2 + j (with i >= j) into (i, j)
    i = (xnp.sqrt(1 + 8*val.astype(dtype)) - 1)//2  # no need for floor: integer division acts as floor
    j = (((val - i) - (i**2 - val))//2)  # simplifies to val - i*(i+1)/2
    return i, j


def c2ijkl(c):
    # A packed s8 index c encodes the pair indices (ij, kl); each of those encodes (i, j) and (k, l)
    ij, kl = get_i_j(c)
    i, j = get_i_j(ij)
    k, l = get_i_j(kl)
    return (int(i), int(j), int(k), int(l))


mol = pyscf.gto.Mole(atom=[["C", (0, 0, 0)], ["C", (10, 2, 3)]])
mol.build()
ERI = mol.intor("int2e_sph", aosym="s1")     # full (N, N, N, N) tensor
ERI_s8 = mol.intor("int2e_sph", aosym="s8")  # 1D array of the distinct integrals
N = mol.nao_nr()

tolerance = 1e-9

ERI[np.abs(ERI) < tolerance] = 0
true_nonzero_indices = np.nonzero(ERI.reshape(-1))[0]
true_nonzero_indices_4d = [np.unravel_index(c, (N, N, N, N)) for c in true_nonzero_indices]

ERI_s8[np.abs(ERI_s8) < tolerance] = 0
true_nonzero_indices_s8 = np.nonzero(ERI_s8.reshape(-1))[0]
true_nonzero_indices_s8_4d = [c2ijkl(c) for c in true_nonzero_indices_s8]

screened_indices_s8_4d = []

# Find the largest diagonal integral I_max = max_ab |(ab|ab)|
I_max = 0
for a in range(N):
    for b in range(N):
        abab = np.abs(ERI[a, b, a, b])
        if abab > I_max:
            I_max = abab

# Collect candidate pairs for s8: only a <= b, since (ab|..) = (ba|..)
considered_indices = []
for a in range(N):
    for b in range(a, N):
        abab = np.abs(ERI[a, b, a, b])
        if abab * I_max >= tolerance:
            considered_indices.append((a, b))

# Generate s8 indices as (d, c, b, a), so that d >= c and b >= a as in c2ijkl.
# b <= d approximates the s8 ordering ij >= kl; when b == d, both (ij|kl)
# and (kl|ij) forms of the same integral can be emitted.
for a, b in considered_indices:
    for c, d in considered_indices:
        if b <= d:
            screened_indices_s8_4d.append((d, c, b, a))

print('N', N)
print('I_max', I_max)
print('ERI.reshape(-1).shape', ERI.reshape(-1).shape)
print('ERI_s8.shape', ERI_s8.shape)
print('len(considered_indices)', len(considered_indices))
print('len(screened_indices_s8_4d)', len(screened_indices_s8_4d))
print('len(true_nonzero_indices_s8_4d)', len(true_nonzero_indices_s8_4d))

screened_s8_set = set(screened_indices_s8_4d)  # O(1) membership tests
check_s8 = [(item in screened_s8_set) for item in true_nonzero_indices_s8_4d]
assert np.array(check_s8).all()
print('PASSED [(item in screened_indices_4d) for item in true_nonzero_indices_s8_4d]!')
```

Output:

```
N 10
I_max 3.5419481332225047
ERI.reshape(-1).shape (10000,)
ERI_s8.shape (1540,)
len(considered_indices) 30
len(screened_indices_s8_4d) 505
len(true_nonzero_indices_s8_4d) 291
PASSED [(item in screened_indices_4d) for item in true_nonzero_indices_s8_4d]!
```

This directly computes a much shorter list (505 screened quartets instead of 2500) that is closer to the set of distinct nonzero ERI values we are aiming for (291). The note above on testing applies here as well.
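Here is a sketch of two possible refinements, under stated assumptions. The first decodes packed indices with exact integer square roots (`math.isqrt`) instead of floating-point `sqrt`, which avoids rounding trouble for very large packed indices. The second enumerates quartets directly in the canonical s8 order (`ij >= kl`, as decoded by `c2ijkl` above), so each distinct integral is listed once. With the `b <= d` test above, when `b == d` both `(bc|ba)` and `(ba|bc)` are emitted, and these are the same integral. For the 30 candidate pairs in the output above, canonical enumeration would give 30 * 31 / 2 = 465 quartets. `pair_index`, `unpack_pair`, `unpack_quartet` and `screen_s8` are hypothetical helpers, not pyscf-ipu functions.

```python
import math


def pair_index(i, j):
    """Packed lower-triangular index of (i, j) with i >= j."""
    return i * (i + 1) // 2 + j


def unpack_pair(p):
    """Exact inverse of pair_index, using integer square roots."""
    i = (math.isqrt(8 * p + 1) - 1) // 2
    return i, p - i * (i + 1) // 2


def unpack_quartet(c):
    """Decode a packed s8 index into (i, j, k, l), like c2ijkl above."""
    ij, kl = unpack_pair(c)
    return unpack_pair(ij) + unpack_pair(kl)


def screen_s8(considered_pairs):
    """Canonical s8 quartets from candidate pairs (a, b) with a <= b.

    Each pair is flipped to (b, a) so the first index is the larger one.
    Pairs are sorted by packed index, and each pair is combined only with
    itself and the pairs before it, so pair_index(i, j) >= pair_index(k, l)
    and every distinct integral appears exactly once.
    """
    canon = sorted({(b, a) for a, b in considered_pairs}, key=lambda p: pair_index(*p))
    quartets = []
    for x, (i, j) in enumerate(canon):
        for k, l in canon[: x + 1]:
            quartets.append((i, j, k, l))
    return quartets


N = 4
pairs = [(a, b) for a in range(N) for b in range(a, N)]  # keep every pair
quartets = screen_s8(pairs)
print(len(quartets))  # 55 = 10 * 11 / 2, the s8 size for N = 4
print(unpack_quartet(pair_index(pair_index(3, 1), pair_index(2, 2))))  # (3, 1, 2, 2)
```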