fujiisoup / py3nj

Wigner's 3J, 6J, 9J symbols for python
https://py3nj.readthedocs.io/
Apache License 2.0

Integrating JAX in py3nj? #16

Open srijaniiserprinceton opened 3 years ago

srijaniiserprinceton commented 3 years ago

Have you thought about using jax.numpy instead of numpy? I am using py3nj for my research and am now porting the code to JAX. Do you think the code, as written, would benefit from being jitted with jax.jit once all instances of np are converted to jnp?

fujiisoup commented 3 years ago

Hi @srijaniiserprinceton

I've never thought about JAX and have no experience with it. But actually it is something I should study.

What is the benefit of using py3nj in JAX? Use of GPU and autograd? In this respect, I'm not sure it would actually benefit py3nj. autograd is hopeless as this is an integer problem. Using this on GPU may also be difficult, as this is originally a Fortran implementation.

py3nj just wraps the original Fortran implementation. See this file https://github.com/fujiisoup/py3nj/blob/master/fortran/_wigner.pyf

Does JAX disallow calling a native numpy function? If it allows that, we may be able to do the same thing in JAX.

srijaniiserprinceton commented 3 years ago

I see. For some reason I thought you only used native numpy in wigner.py. If it is in Fortran then I guess it won't be as straightforward. Otherwise, I was going to suggest changing import numpy as np to import jax.numpy as np. From my brief experience with JAX, I think it gives a near-C speedup because it compiles the functions (which need to be pure functions), after which they no longer go through the Python interpreter (which is usually slower than C).

The main place I was trying to use this is the call to py3nj.wigner3j. I was trying to do wig_jax = jax.jit(py3nj.wigner3j), but this failed since the package used within py3nj is numpy and not jax.numpy.
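For illustration, here is a minimal sketch of one way a NumPy-only routine can still be called from jitted code, using jax.pure_callback. The function host_only_fn is a hypothetical stand-in for a NumPy/Fortran-backed routine such as py3nj.wigner3j, not py3nj itself, and the float32 output type is an assumption made for the example. Inside jax.jit the arguments are abstract tracers, which a NumPy-based function generally cannot convert to concrete arrays; that is most likely why the direct jax.jit call fails.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_only_fn(two_j):
    # Hypothetical stand-in for a NumPy/Fortran-backed routine.
    # It needs concrete values, so it cannot be traced by jax.jit directly.
    two_j = np.asarray(two_j)
    return (1.0 / np.sqrt(two_j + 1.0)).astype(np.float32)


@jax.jit
def jitted(two_j):
    # JAX cannot infer the output of a host function, so the result
    # shape and dtype are declared up front.
    out_spec = jax.ShapeDtypeStruct(two_j.shape, jnp.float32)
    return jax.pure_callback(host_only_fn, out_spec, two_j)


values = jitted(jnp.array([0, 2, 4], dtype=jnp.int32))
```

Note that the callback still runs on the host in ordinary Python/NumPy, so this only makes the surrounding code jittable; it does not speed up the wrapped routine or move it to the GPU.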

In any case, I just wanted to see what you think about the possibility. Maybe I will write a JAX version of py3nj if computing the Wigner 3j symbols becomes a real bottleneck in my code.

srijaniiserprinceton commented 3 years ago

Closing the issue :)

srijaniiserprinceton commented 2 years ago

Hi, reopening this thread: We are using py3nj a lot in our code. But now, to speed up the code, we are using JAX. To achieve our speedup, we need py3nj (and all the other functions) to be in jax.numpy (or NumPy compatible). Since py3nj uses Fortran under the hood, it is becoming an obstacle to making the code just-in-time compatible with JAX. Could you please point me to the algorithm, or to the part of the Fortran code where the computation is actually carried out? I am thinking of coding up the counterpart using jax.numpy so that it can be just-in-time compiled.

fujiisoup commented 2 years ago

Hi @srijaniiserprinceton

We are using py3nj a lot in our code

It's really nice to know:) Thanks.

For example, this line https://github.com/fujiisoup/py3nj/blob/15f179ecc21033022b05e27c681a1512e2b0e604/fortran/drc.f90#L33 is the high-level subroutine and this part https://github.com/fujiisoup/py3nj/blob/15f179ecc21033022b05e27c681a1512e2b0e604/fortran/drc.f90#L93 computes the actual 3j symbols.
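When porting, it can help to have an independent reference to compare against. The sketch below is not the algorithm in the linked Fortran files; it evaluates Racah's closed-form sum with exact integer and rational arithmetic, which is slow but avoids rounding in the alternating sum. The function name wigner3j_racah is made up for this example, and the doubled-integer argument convention (two_j1 = 2*j1, and so on) is assumed here to mirror py3nj's.

```python
from fractions import Fraction
from math import factorial, sqrt


def wigner3j_racah(two_j1, two_j2, two_j3, two_m1, two_m2, two_m3):
    """Wigner 3j symbol via Racah's closed-form sum; all arguments doubled."""
    # Selection rules: any violation gives exactly zero.
    if two_m1 + two_m2 + two_m3 != 0 or (two_j1 + two_j2 + two_j3) % 2:
        return 0.0
    if not abs(two_j1 - two_j2) <= two_j3 <= two_j1 + two_j2:
        return 0.0
    pairs = ((two_j1, two_m1), (two_j2, two_m2), (two_j3, two_m3))
    if any(abs(tm) > tj or (tj + tm) % 2 for tj, tm in pairs):
        return 0.0

    a = (two_j1 + two_j2 - two_j3) // 2   # j1 + j2 - j3
    b = (two_j1 - two_j2 + two_j3) // 2   # j1 - j2 + j3
    c = (-two_j1 + two_j2 + two_j3) // 2  # -j1 + j2 + j3
    j1_minus_m1 = (two_j1 - two_m1) // 2
    j2_plus_m2 = (two_j2 + two_m2) // 2
    d = (two_j3 - two_j2 + two_m1) // 2   # j3 - j2 + m1, may be negative
    e = (two_j3 - two_j1 - two_m2) // 2   # j3 - j1 - m2, may be negative

    # Squared prefactor: triangle coefficient times the (j +/- m)! factors.
    numerator = factorial(a) * factorial(b) * factorial(c)
    for tj, tm in pairs:
        numerator *= factorial((tj + tm) // 2) * factorial((tj - tm) // 2)
    prefactor_sq = Fraction(
        numerator, factorial((two_j1 + two_j2 + two_j3) // 2 + 1)
    )

    # Alternating sum, kept exact so large terms cancel without rounding.
    total = Fraction(0)
    for k in range(max(0, -d, -e), min(a, j1_minus_m1, j2_plus_m2) + 1):
        denominator = (factorial(k) * factorial(a - k)
                       * factorial(j1_minus_m1 - k) * factorial(j2_plus_m2 - k)
                       * factorial(d + k) * factorial(e + k))
        total += Fraction(-1 if k % 2 else 1, denominator)

    # Overall phase (-1)^(j1 - j2 - m3) combined with the sign of the sum.
    phase = -1 if ((two_j1 - two_j2 - two_m3) // 2) % 2 else 1
    sign = phase * (1 if total >= 0 else -1)
    return sign * sqrt(float(prefactor_sq * total * total))


# (1 1 0; 0 0 0), which should equal -1/sqrt(3)
value = wigner3j_racah(2, 2, 0, 0, 0, 0)
```

Because it relies on Python's arbitrary-precision integers, this function cannot itself be traced by jax.jit; it is only meant as a check for a few values from a jax.numpy port.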

I assume JAX also provides some mechanism to call external Fortran code, but I have no idea, sorry.

srijaniiserprinceton commented 2 years ago

Thanks! Also, up to what angular degree \ell do you expect the Wigner symbols to be accurate? We go up to around \ell=300.

fujiisoup commented 2 years ago

Here it says https://github.com/fujiisoup/py3nj/blob/15f179ecc21033022b05e27c681a1512e2b0e604/fortran/drc3jj.f#L85-L88 and thus I assume it is accurate even for large angular degree. Maybe you can take a look at the referenced information.
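One independent sanity check at large \ell is the orthogonality relation: for fixed j3 and m3, (2*j3 + 1) times the sum of the squared 3j symbols over m1 (with m2 = -m3 - m1) should equal 1. The sketch below is a necessary check only, not a proof of accuracy. The names orthogonality_defect and j3_zero_closed_form are made up for this example, and the callable w3j is assumed to take six doubled integers, as py3nj.wigner3j is understood to do.

```python
import math


def orthogonality_defect(w3j, two_j1, two_j2, two_j3, two_m3):
    """Return |1 - (2*j3 + 1) * sum over m1 of (3j symbol)^2|."""
    total = 0.0
    # m1 runs over -j1..j1 in unit steps, i.e. steps of 2 in doubled units.
    for two_m1 in range(-two_j1, two_j1 + 1, 2):
        two_m2 = -two_m3 - two_m1  # fixed by m1 + m2 + m3 = 0
        if abs(two_m2) > two_j2:
            continue
        total += float(w3j(two_j1, two_j2, two_j3, two_m1, two_m2, two_m3)) ** 2
    return abs(1.0 - (two_j3 + 1) * total)


def j3_zero_closed_form(two_j1, two_j2, two_j3, two_m1, two_m2, two_m3):
    """Known special case (j j 0; m -m 0) = (-1)^(j - m) / sqrt(2j + 1)."""
    if two_j3 or two_m3 or two_j1 != two_j2 or two_m1 != -two_m2:
        return 0.0
    sign = -1 if ((two_j1 - two_m1) // 2) % 2 else 1
    return sign / math.sqrt(two_j1 + 1)


# Demonstration with the closed form at j1 = j2 = 300 (doubled: 600).
defect = orthogonality_defect(j3_zero_closed_form, 600, 600, 0, 0)
```

To test the library itself, one would pass its 3j function in place of j3_zero_closed_form, with values such as two_j1 = two_j2 = 600 and a nonzero two_j3; a defect far above floating-point rounding would point to a loss of accuracy.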

srijaniiserprinceton commented 2 years ago

Thanks!

fujiisoup commented 2 years ago

Let's keep this issue open. I would appreciate it if you could post any updates here :+1: