Problem:
The atomic-orbitals grid grid_AO can be sparsified aggressively with negligible effect on the DFT energy error, at least for the <32-atom cases we tested. Its sparsity also grows with the number of electrons, indicating further gains on larger molecules.
For C6H6, grid_size=45624 and N=66, so grid_AO holds 4 x 45624 x 66 = 12044736 float32 values, i.e. ~48 MB. A sparse (values, cols, rows) representation would cut that by 70-90%, to roughly 5-14 MB of values (plus index storage).
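The arithmetic above can be reproduced directly. This is a back-of-envelope sketch that, like the text, counts only the float32 values; a real (values, cols, rows) layout would add index storage on top:

```python
# Memory estimate for grid_AO on C6H6 (float32, values only).
grid_size, N = 45624, 66   # numbers from the text
n_components = 4           # AO values + x/y/z derivatives
dense_values = n_components * grid_size * N
assert dense_values == 12_044_736
dense_mb = dense_values * 4 / 1e6   # 4 bytes per float32
print(f"dense grid_AO: {dense_mb:.1f} MB")

# 70-90% sparsity keeps 30-10% of the values.
for sparsity in (0.70, 0.90):
    kept_mb = dense_values * (1 - sparsity) * 4 / 1e6
    print(f"{sparsity:.0%} sparse -> ~{kept_mb:.1f} MB of values")
```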
grid_AO is used solely in exchange_correlation(...), in multiplications with a few dense matrices, but the current implementation doesn't exploit its sparsity.
Solution:
Rewrite the code in experimental_pmap_nanoDFT.py so that grid_AO is sparsified.
Write a sparse x dense matrix-multiplication implementation in JAX and apply it in exchange_correlation(...).
Example of a sparse matmul implementation in JAX:
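A minimal sketch of a COO-style sparse x dense matmul, hand-rolled with jax.ops.segment_sum. The function name and arguments are illustrative, not from the nanoDFT code; in exchange_correlation(...) the sparse operand would be the precomputed (values, rows, cols) of grid_AO:

```python
import jax
import jax.numpy as jnp

def sparse_matmul(values, rows, cols, dense, n_rows):
    """Compute A @ dense where A is given in COO form.

    values[k] = A[rows[k], cols[k]]; zero entries of A are omitted.
    """
    # Each nonzero A[i, j] contributes values[k] * dense[j, :] to output row i.
    contrib = values[:, None] * dense[cols]          # (nnz, dense.shape[1])
    return jax.ops.segment_sum(contrib, rows, num_segments=n_rows)

# Sanity check against the dense product on a random ~80%-sparse matrix.
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (8, 6)) * (jax.random.uniform(key, (8, 6)) > 0.8)
rows, cols = jnp.nonzero(A)      # indices computed once, outside jit
values = A[rows, cols]
B = jax.random.normal(key, (6, 3))
assert jnp.allclose(sparse_matmul(values, rows, cols, B, A.shape[0]),
                    A @ B, atol=1e-5)
```

Note that jnp.nonzero produces data-dependent shapes, so the sparsity pattern should be extracted once ahead of time and the (values, rows, cols) arrays passed into the jitted DFT iteration with fixed nnz. JAX also ships a built-in alternative, jax.experimental.sparse.BCOO, which supports @ with dense matrices directly.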