graphcore-research / pyscf-ipu

PySCF on IPU
https://github.com/graphcore-research/pyscf-ipu#readme
Apache License 2.0

Do everything except one operation in float64 #100

Open AlexanderMath opened 10 months ago

AlexanderMath commented 10 months ago

Transform the JAX graph so that everything runs in float64 except a set of user-specified operations. This may not be possible; we need to think about what it would look like as a JAX graph transformation.

hatemhelal commented 10 months ago

I had the idea to introduce a function decorator that will run an operation twice:

  1. a baseline run with all the floating point inputs promoted to fp64
  2. a second run with fp32

Then report the difference between the two results (possibly to stdout, or saved to an .npz file?). What I'm not sure about is how to dynamically annotate a graph to do this, but perhaps others have some insight into the JAX way to attempt that.

If that sounds like a useful step I could draft a PR with the decorator.
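For concreteness, here is a minimal sketch of what such a decorator could look like. It is not the code from any PR in this repository: the names `compare_fp32` and `_cast_floats` are made up for illustration, and it uses plain NumPy arrays so that it runs without extra configuration. The same pattern under JAX would additionally need `jax_enable_x64` switched on for the float64 baseline to really be float64.

```python
import functools

import numpy as np


def _cast_floats(x, dtype):
    """Cast floating-point arrays to `dtype`; leave everything else untouched."""
    if isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.floating):
        return x.astype(dtype)
    return x


def compare_fp32(fn):
    """Hypothetical decorator: run `fn` in float64 and in float32, report the gap."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        def run(dtype):
            cast_args = [_cast_floats(a, dtype) for a in args]
            cast_kwargs = {k: _cast_floats(v, dtype) for k, v in kwargs.items()}
            return np.asarray(fn(*cast_args, **cast_kwargs))

        baseline = run(np.float64)  # 1. baseline run, float inputs promoted to fp64
        reduced = run(np.float32)   # 2. second run, float inputs cast to fp32
        err = float(np.max(np.abs(baseline - reduced.astype(np.float64))))
        wrapper.errors.append(err)  # keep a record; could also be saved to .npz
        print(f"{fn.__name__}: max |fp64 - fp32| = {err:.3e}")
        return baseline             # callers keep getting the fp64 result

    wrapper.errors = []
    return wrapper


@compare_fp32
def contract(eri, dm):
    # Toy stand-in for an einsum(eri, dm)-style contraction.
    return np.einsum("ijkl,kl->ij", eri, dm)


rng = np.random.default_rng(0)
eri = rng.standard_normal((4, 4, 4, 4))
dm = rng.standard_normal((4, 4))
out = contract(eri, dm)
```

Returning the float64 result means the decorated function can be dropped into existing code without changing what downstream code sees; only the reported error is new.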

AlexanderMath commented 10 months ago

> I had the idea to introduce a function decorator that will run an operation twice:

Do you mean "operation=nanoDFT", or e.g. "operation=einsum(eri, dm)"? I was thinking that we'd run nanoDFT twice: first with everything in float64, then a second time where a single decorated operation is in float32. Is this also what you're considering?

hatemhelal commented 9 months ago

I put together #110, which contains just the function decorator idea. The problem I haven't solved is how to inject it into the desired place within a larger compute graph. I think for that we may need a syntax to say "decorate function foo" and a compiler pass that does a find and replace on foo in the compute graph.

Or maybe this isn't the right way to approach the problem in JAX?