e3nn / e3nn-jax

jax library for E3 Equivariant Neural Networks
Apache License 2.0

Please wrap `clebsch_gordan` inside `functools.cache` #72

Closed by olivier-peltre 4 months ago

olivier-peltre commented 4 months ago

Right now, Clebsch-Gordan (CG) coefficients are computed not once but three times on every tensor product (TP) call, and on every inhomogeneous `IrrepsArray.__mul__` call.

I don't fully understand where the factor of three comes from, though I guess it has to do with the flax module internals wrapping `elementwise_tensor_product`.

# test.py
import jax
import jax.numpy as np

import e3nn_jax as e3nn

# some dummy e3-array
a = e3nn.IrrepsArray.zeros("8x0e + 8x1o + 8x2e", (1,))
# some scalar array
n = a.irreps.num_irreps
b = e3nn.IrrepsArray.zeros(f"{n}x0e", (1,))

# __mul__ calls elementwise_tensor_product
for i in range(3):
    print(a * b)

Outputs:

I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
8x0e+8x1o+8x2e
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
8x0e+8x1o+8x2e
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
I'm computing Clebsch-Gordan coefficients!
8x0e+8x1o+8x2e
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
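
For illustration, here is a minimal sketch of what the requested caching would change. It does not use the e3nn-jax code: `clebsch_gordan_stub` and the `calls` counter are hypothetical stand-ins for an expensive, deterministic function of three integer degrees, and the returned values are placeholders rather than real coefficients. Note that `functools.cache` requires Python 3.9 or later and that all arguments must be hashable.

```python
import functools

calls = 0  # counts how many times the function body actually runs


@functools.cache
def clebsch_gordan_stub(l1: int, l2: int, l3: int) -> tuple:
    """Hypothetical stand-in for an expensive CG computation (not the e3nn-jax API)."""
    global calls
    calls += 1
    print("I'm computing Clebsch-Gordan coefficients!")
    # Placeholder result with the right number of entries, not real CG values.
    size = (2 * l1 + 1) * (2 * l2 + 1) * (2 * l3 + 1)
    return tuple(0.0 for _ in range(size))


# Three outer iterations, each requesting the same three (l1, l2, l3) triples.
for _ in range(3):
    for l in (0, 1, 2):
        clebsch_gordan_stub(l, 0, l)

print(calls)  # the body ran once per distinct triple
print(clebsch_gordan_stub.cache_info())
```

With the decorator in place, the message is printed once per distinct argument triple instead of once per call. The stub returns an immutable tuple on purpose: a cached function hands the same object to every caller, so if the real function returned a mutable array, callers would have to avoid modifying it in place.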

See #71. Thanks!

mariogeiger commented 4 months ago

Let's talk directly on the PR