lab-cosmo / sphericart

Multi-language library for the calculation of spherical harmonics in Cartesian coordinates
https://sphericart.readthedocs.io/en/latest/
MIT License

Complex SH? #125

Open Chutlhu opened 3 months ago

Chutlhu commented 3 months ago

Dear all, thank you very much for your amazing library. It works very well, and my collaborators are enthusiastically using it. We would like to ask whether you are planning to also implement complex spherical harmonics.

cortner commented 3 months ago

It's a rather trivial transformation that should be easy enough for a "user" to add. In the Julia code, I implemented it here: https://github.com/ACEsuit/Polynomials4ML.jl/blob/main/src/sphericart.jl
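For reference, the transformation can be written out explicitly. Taking R_{l,m} as sphericart's real harmonics and Y_l^m as the complex ones, the conversion implemented by the loop in the next comment (which that comment reports to match SciPy) reads, for m > 0, as below. This assumes sphericart's real-harmonic sign convention; the last identity is the usual conjugation symmetry of complex harmonics with the Condon-Shortley phase.

```latex
\begin{aligned}
Y_l^{0}  &= R_{l,0}, \\
Y_l^{m}  &= \frac{(-1)^m}{\sqrt{2}}\left(R_{l,m} + i\,R_{l,-m}\right), \\
Y_l^{-m} &= \frac{1}{\sqrt{2}}\left(R_{l,m} - i\,R_{l,-m}\right) = (-1)^m\,\overline{Y_l^{m}}.
\end{aligned}
```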

Chutlhu commented 3 months ago

Thank you! I didn't know that. I implemented your solution with a naive for loop, and it gives the same results as the SciPy implementation.

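import numpy as np
import sphericart as sph

def lm2idx(l, m):
    # Flat column index of (l, m) in sphericart's output: m runs from -l to l for each l
    assert np.abs(m) <= l
    assert l >= 0
    return m + l + l*l

L = 3
# Example input: random points projected onto the unit sphere. There, the
# unnormalized (normalized=False) harmonics r^l Y_lm coincide with Y_lm.
x_cart = np.random.normal(size=(10, 3))
x_cart /= np.linalg.norm(x_cart, axis=1, keepdims=True)

sh = sph.SphericalHarmonics(l_max=L, normalized=False)
H_sphericart = sh.compute(x_cart)
# m = 0 components are already the complex ones; only m != 0 need converting
Y = H_sphericart.astype(complex)
for l in range(0, L+1):
    for m in range(1, l+1):
        i_lm_neg = lm2idx(l, -m)
        i_lm_pos = lm2idx(l,  m)
        Ylm_pos = H_sphericart[:, i_lm_pos]
        Ylm_neg = H_sphericart[:, i_lm_neg]
        Y[:, i_lm_pos] = (-1)**m * (Ylm_pos + 1j * Ylm_neg) / np.sqrt(2)
        Y[:, i_lm_neg] =           (Ylm_pos - 1j * Ylm_neg) / np.sqrt(2)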

Now I need to think about a smarter implementation.

Thank you so much!
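One possible "smarter" formulation, sketched here in NumPy as an illustration rather than anything sphericart provides: since the conversion is linear, it can be stored once as a complex matrix `T` and applied to a whole batch with a single matrix product. `real_to_complex_matrix` is a hypothetical helper, and a random array stands in for `sh.compute(x_cart)` so that the example runs without sphericart.

```python
import numpy as np

def lm2idx(l, m):
    # Flat column index of (l, m) when m runs from -l to l within each l
    return l * l + l + m

def real_to_complex_matrix(l_max):
    # Hypothetical helper: complex matrix T such that Y = R @ T.T,
    # using the same formulas as the loop above
    n = (l_max + 1) ** 2
    s = 1.0 / np.sqrt(2.0)
    T = np.zeros((n, n), dtype=complex)
    for l in range(l_max + 1):
        T[lm2idx(l, 0), lm2idx(l, 0)] = 1.0          # m = 0 is unchanged
        for m in range(1, l + 1):
            p, q = lm2idx(l, m), lm2idx(l, -m)
            T[p, p] = (-1) ** m * s                  # Y_l^m  = (-1)^m (R_lm + i R_l,-m) / sqrt(2)
            T[p, q] = (-1) ** m * 1j * s
            T[q, p] = s                              # Y_l^-m = (R_lm - i R_l,-m) / sqrt(2)
            T[q, q] = -1j * s
    return T

l_max = 3
T = real_to_complex_matrix(l_max)   # built once, reused for every batch

# Stand-in for sh.compute(x_cart): 5 points, (l_max + 1)**2 real harmonics each
R = np.random.default_rng(0).normal(size=(5, (l_max + 1) ** 2))
Y = R @ T.T                         # complex harmonics, shape (5, 16)
```

The dense matrix has `(l_max + 1)**4` entries, almost all zero, so for large `l_max` a gather-and-mask approach scales better; the advantage of the matmul form is that it is trivially differentiable in any array framework.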

cortner commented 3 months ago

@ceriottm -- in general the question is valid. Should complex SH be supported as well? I don't have a strong view, since I will always work with my wrapper package, which provides some additional convenience functionality, but I wanted to ask. Feel free to close this issue if your group has no plans to support complex SH in the foreseeable future.

ceriottm commented 3 months ago

When we started this, we meant it to be somewhat "opinionated": for a while I was thinking of not even providing a normalized version and just computing solid harmonics. It is true that this is a relatively trivial "wrapper", so if @Chutlhu is interested in contributing wrappers for all the languages and frontends we support, that would be welcome. I'd also like to hear from @frostedoyster and @Luthaf, though.

Chutlhu commented 3 months ago

I could try, but I don't know if I am skilled enough to do this. I worked on a JAX wrapper that is differentiable (w.r.t. the input coordinates). It requires pre-computing some indices and masks ahead of time:

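# SPHERICART CPLX - DIFFERENTIABLE
import numpy as np

def lm2idx(l, m):
    # Vectorized flat column index of (l, m): m runs from -l to l for each l
    assert np.all(np.abs(m) <= l)
    assert np.all(l >= 0)
    return m + l + (l*l)

def sh_r2c_vect(H, i_pos, i_neg, S, Z, P, N):
    # Only gathers and elementwise arithmetic on H, with the index arrays and
    # masks as constants, so JAX can differentiate through the conversion
    Yp = H[:, i_pos]
    Yn = H[:, i_neg]
    Y_lm = H * Z[None, :] \
        + S[None, :] * (Yp + 1j * Yn) * P[None, :] / np.sqrt(2) \
        +              (Yp - 1j * Yn) * N[None, :] / np.sqrt(2)
    return Y_lm

L = 3
# degree l and order m of each output column
sph_l = np.array([l for l in range(L+1) for m in range(-l, l+1)])
sph_m = np.array([m for l in range(L+1) for m in range(-l, l+1)])

# pre-compute the column indices of the +|m| and -|m| components
i_pos = lm2idx(sph_l, np.abs(sph_m))
i_neg = lm2idx(sph_l, -np.abs(sph_m))
# pre-compute the sign (-1)^|m| and the masks
S = (-1)**np.abs(sph_m)
Z = sph_m == 0
P = sph_m > 0
N = sph_m < 0

# For gradients w.r.t. x_cart, H must be a JAX array computed from x_cart
# (sh.compute from the NumPy frontend returns a plain NumPy array)
Y_lm_diff = sh_r2c_vect(sh.compute(x_cart), i_pos, i_neg, S, Z, P, N)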

More details are in this Jupyter notebook: https://colab.research.google.com/drive/1eo09CeWiakTtFFlYHwSfDu8FvJEdqO1z?usp=sharing

A PyTorch implementation should be quite straightforward.
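A quick way to sanity-check the mask-based conversion without SciPy is the l = 1 case, where the complex harmonics have a simple closed form. The sketch below re-implements the conversion in plain NumPy and feeds it hand-written real harmonics for l <= 1. `real_sh_l1` is a hypothetical stand-in for sphericart, and it assumes sphericart's l = 1 real harmonics are sqrt(3/(4 pi)) * (y, z, x) in the order m = -1, 0, 1.

```python
import numpy as np

L = 1
# Degree l and order m of every output column (m runs from -l to l within each l)
sph_l = np.array([l for l in range(L + 1) for m in range(-l, l + 1)])
sph_m = np.array([m for l in range(L + 1) for m in range(-l, l + 1)])

# Pre-computed once: source columns for +|m| and -|m|, the sign (-1)^|m|, and masks
i_pos = sph_l * sph_l + sph_l + np.abs(sph_m)
i_neg = sph_l * sph_l + sph_l - np.abs(sph_m)
S = (-1) ** np.abs(sph_m)
Z, P, N = sph_m == 0, sph_m > 0, sph_m < 0

def r2c(H):
    # Mask-based real-to-complex conversion, vectorized over points and columns
    Yp, Yn = H[:, i_pos], H[:, i_neg]
    return H * Z + S * (Yp + 1j * Yn) * P / np.sqrt(2) + (Yp - 1j * Yn) * N / np.sqrt(2)

def real_sh_l1(xyz):
    # Hypothetical stand-in for sphericart on unit vectors, l <= 1, columns
    # (0,0), (1,-1), (1,0), (1,1); assumes the l = 1 functions are c1 * (y, z, x)
    x, y, z = xyz.T
    c0 = 0.5 / np.sqrt(np.pi)
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    return np.stack([np.full_like(x, c0), c1 * y, c1 * z, c1 * x], axis=1)

xyz = np.random.default_rng(1).normal(size=(4, 3))
xyz /= np.linalg.norm(xyz, axis=1, keepdims=True)   # project onto the unit sphere
Y = r2c(real_sh_l1(xyz))

# Textbook complex harmonics for l = 1 (with the Condon-Shortley phase)
x, y, z = xyz.T
k = np.sqrt(3.0 / (8.0 * np.pi))
Y_1_p1 = -k * (x + 1j * y)   # Y_1^{+1}
Y_1_m1 = k * (x - 1j * y)    # Y_1^{-1}
```

If the assumed sign convention holds, column 3 of `Y` matches `Y_1_p1` and column 1 matches `Y_1_m1`, while the m = 0 columns pass through unchanged.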

Luthaf commented 3 months ago

I'm happy to provide complex spherical harmonics, but, following the spirit of #98, we should make sure their API is consistent with the rest of the package.

I'm thinking something like this:

# Torch/Numpy
calculator = ComplexSphericalHarmonics(lmax=10)
sph = calculator(R)

# JAX
calculator = ComplexSphericalHarmonics(lmax=10)
# one of
sph = compute_complex_spherical_harmonics(calculator, R)
sph = calculator(R)

# Julia
basis = ComplexSphericalHarmonics(10)

# one of
sph = basis(R)
sph = compute(basis, R)

@Chutlhu, do you think the extra masks and indices can be hidden in a "calculator" class?

cortner commented 3 months ago

If you decide to do this, then I'll implement the Julia version.

Chutlhu commented 3 months ago

Dear @Luthaf,

> @Chutlhu, do you think the extra masks and indices can be hidden in a "calculator" class?

Yes, sure. I updated the Jupyter notebook with a NumPy class that uses the API you suggested. However, gradients are not implemented yet. The PyTorch version should be quite similar, but I need more time to set up a testing environment to check that backpropagation works correctly.

Regarding JAX, sphericart currently exposes it as a function rather than a class. Will this change in future versions?
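A minimal sketch of how the indices and masks could be hidden in a calculator class with the proposed `calculator(R)` call style. This is not sphericart's API: the class name follows the proposal above, and the `real_sh` constructor argument (any callable returning real harmonics in sphericart's column layout) is an assumption made so that the conversion stays independent of the backend.

```python
import numpy as np

class ComplexSphericalHarmonics:
    """Hypothetical NumPy calculator: real harmonics in, complex harmonics out.

    real_sh: any callable mapping an (n, 3) array to the (n, (l_max + 1)**2)
    real harmonics in sphericart's column order (m from -l to l within each l).
    """

    def __init__(self, l_max, real_sh):
        self.l_max = l_max
        self._real_sh = real_sh
        ls = np.array([l for l in range(l_max + 1) for m in range(-l, l + 1)])
        ms = np.array([m for l in range(l_max + 1) for m in range(-l, l + 1)])
        # Indices and masks are computed once here and hidden from the caller
        self._i_pos = ls * ls + ls + np.abs(ms)
        self._i_neg = ls * ls + ls - np.abs(ms)
        self._sign = (-1) ** np.abs(ms)
        self._zero, self._pos, self._neg = ms == 0, ms > 0, ms < 0

    def __call__(self, R):
        H = self._real_sh(R)
        Yp, Yn = H[:, self._i_pos], H[:, self._i_neg]
        s = 1.0 / np.sqrt(2.0)
        return (H * self._zero
                + self._sign * (Yp + 1j * Yn) * self._pos * s
                + (Yp - 1j * Yn) * self._neg * s)
```

With `sh` as in the first snippet of the thread, this would be called as `ComplexSphericalHarmonics(3, sh.compute)(x_cart)`. Passing the real-harmonics function in keeps the same class usable for any frontend whose arrays support NumPy-style indexing.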

Luthaf commented 3 months ago

> Will this change in future implementations?

Yes, this is the idea behind #98. We don't have a lot of time to work on this at the moment, but it should happen at some point!

ceriottm commented 2 months ago

We (well, @frostedoyster) made an example of such wrappers, which we just merged in #130. If someone cares enough to also implement forward derivatives and do the same for all frontends, it'll be a quick merge. Meanwhile, as soon as I have a minute, I'll open a PR to implement #98, and maybe we can then make a 0.5 release (1.0 is reserved for when we have pre-built wheels...).
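On the forward derivatives: because the real-to-complex conversion is a fixed linear combination of columns, no new math is needed. Applying the same map along the harmonic axis of the real gradients gives the gradients of the complex harmonics. The NumPy sketch below checks this with finite differences on a stand-in function; `r2c_last_axis`, `fake_real_sh` and `fd_gradient` are hypothetical helpers, and the `(n, 3, n_harmonics)` gradient layout is an assumption about how the gradients are stored.

```python
import numpy as np

def r2c_last_axis(A, l_max):
    # Apply the real-to-complex map along the last axis (values or gradients)
    out = A.astype(complex)
    s = 1.0 / np.sqrt(2.0)
    for l in range(l_max + 1):
        for m in range(1, l + 1):
            p, q = l * l + l + m, l * l + l - m
            Rp, Rn = A[..., p], A[..., q]
            out[..., p] = (-1) ** m * (Rp + 1j * Rn) * s
            out[..., q] = (Rp - 1j * Rn) * s
    return out

def fake_real_sh(xyz):
    # Stand-in for real harmonics with l_max = 1: any smooth (n, 3) -> (n, 4) map
    x, y, z = xyz.T
    return np.stack([np.ones_like(x), np.sin(y), z * x, np.exp(x)], axis=1)

def fd_gradient(f, xyz, h=1e-6):
    # Central finite differences, stacked with shape (n, 3, n_harmonics)
    grads = []
    for a in range(3):
        d = np.zeros_like(xyz)
        d[:, a] = h
        grads.append((f(xyz + d) - f(xyz - d)) / (2 * h))
    return np.stack(grads, axis=1)

xyz = np.random.default_rng(2).normal(size=(5, 3))
dR = fd_gradient(fake_real_sh, xyz)           # real gradients, shape (5, 3, 4)
dY_mapped = r2c_last_axis(dR, 1)              # conversion applied to the gradients
dY_direct = fd_gradient(lambda r: r2c_last_axis(fake_real_sh(r), 1), xyz)
```

`dY_mapped` and `dY_direct` agree up to rounding, which is why a frontend that already returns real gradients could reuse its value-conversion code for them.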