Right now, CG coefficients are computed not once but three times at every TP call and at every inhomogeneous IrrepsArray.__mul__ call.
I don't fully understand the 3x part, though I guess it has to do with the flax module internals wrapping elementwise_tensor_product.
# test.py
import jax
import jax.numpy as np
import e3nn_jax as e3nn
# some dummy e3-array
a = e3nn.IrrepsArray.zeros("8x0e + 8x1o + 8x2e", (1,))
# some scalar array
n = a.irreps.num_irreps
b = e3nn.IrrepsArray.zeros(f"{n}x0e", (1,))
# __mul__ calls elementwise_tensor_product
for i in range(3):
    print(a * b)
Outputs:
See #71. Thanks!
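For illustration, here is a minimal sketch of the kind of memoization that would avoid recomputing the coefficients on every call. It does not use e3nn_jax and is not how the library does it: `cg_placeholder` and the `calls` counter are hypothetical names, and the returned values are dummies rather than real CG coefficients. The only assumption is that the coefficients depend solely on the triple `(l1, l2, l3)`, so they can be cached by that key.

```python
import functools

# Counts how often the expensive function body actually runs.
calls = {"n": 0}


@functools.lru_cache(maxsize=None)
def cg_placeholder(l1: int, l2: int, l3: int) -> tuple:
    """Hypothetical stand-in for an expensive CG computation.

    Returns a flat tuple of dummy zeros with
    (2*l1 + 1) * (2*l2 + 1) * (2*l3 + 1) entries, not real coefficients.
    A tuple is immutable, so sharing the cached value is safe.
    """
    calls["n"] += 1
    size = (2 * l1 + 1) * (2 * l2 + 1) * (2 * l3 + 1)
    return (0.0,) * size


# Three calls with the same key, mirroring the loop in test.py.
results = [cg_placeholder(1, 0, 1) for _ in range(3)]
print(calls["n"])
```

With the cache in place, the function body runs once for a given `(l1, l2, l3)` and later calls return the stored value.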