JAXGA is a Geometric Algebra package built on top of JAX. It can handle high-dimensional algebras by storing only the non-zero basis blade coefficients. It makes use of JAX's just-in-time (JIT) compilation by first precomputing blade indices and signs and then JITting the function that does the actual calculations.
Install using pip: pip install jaxga
Requirements:
Unlike most other Geometric Algebra packages, JAXGA does not require an algebra to be specified in advance. It can be used either through the MultiVector class or through lower-level functions, which are useful, for example, when using JAX's jit or automatic differentiation.
The MultiVector class provides operator overloading and is constructed from an array of values and their corresponding basis blades. The basis blades are encoded as tuples: for example, the multivector 2 e_1 + 4 e_23 would have the values [2, 4] and the basis blade tuple ((1,), (2, 3)).
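To make the encoding concrete, here is a minimal sketch in plain Python that pairs each value with its blade tuple and renders the result as text. The helper `format_multivector` is hypothetical and not part of JAXGA, and the use of an empty tuple `()` for the scalar blade is an assumption made for this illustration.

```python
def format_multivector(values, blades):
    """Render parallel `values` and `blades` sequences as a readable sum.

    Hypothetical helper for illustration only; not part of JAXGA.
    """
    terms = []
    for value, blade in zip(values, blades):
        if blade:
            # A blade such as (2, 3) stands for the basis blade e_23.
            name = "e_" + "".join(str(i) for i in blade)
            terms.append(f"{value:g} {name}")
        else:
            # Assumption: the empty tuple () denotes the scalar part.
            terms.append(f"{value:g}")
    return " + ".join(terms)


values = [2.0, 4.0]
blades = ((1,), (2, 3))
print(format_multivector(values, blades))
```

The two sequences are positional: the i-th value is the coefficient of the i-th blade tuple, so only blades that actually appear need to be stored.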
MultiVector example
import jax.numpy as jnp
from jaxga.mv import MultiVector
a = MultiVector(
    values=2 * jnp.ones([1], dtype=jnp.float32),
    indices=((1,),)
)
# Alternative: 2 * MultiVector.e(1)
b = MultiVector(
    values=4 * jnp.ones([1], dtype=jnp.float32),
    indices=((2, 3),)
)
# Alternative: 4 * MultiVector.e(2, 3)
c = a * b
print(c)
Output: Multivector(8.0 e_{1, 2, 3})
The lower-level functions also deal with values and blades. Functions are provided that take the blades and return a function that does the actual calculation. The returned function is JITted and can also be automatically differentiated with JAX. Furthermore, some operations, like the geometric product, take a signature function that takes a basis vector index and returns its square.
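As an illustration of what a signature function does, the following plain-Python sketch multiplies two basis blades under a given signature. It is not JAXGA's implementation: `blade_product`, `positive_signature_sketch`, and `spacetime_signature_sketch` are hypothetical names, and the sketch assumes an orthogonal basis with each blade given as a sorted tuple of distinct indices.

```python
def positive_signature_sketch(index):
    # Hypothetical stand-in: every basis vector squares to +1.
    return 1.0


def spacetime_signature_sketch(index):
    # Hypothetical example: e_0 squares to -1, all other basis vectors to +1.
    return -1.0 if index == 0 else 1.0


def blade_product(a, b, signature):
    """Return (sign, blade) for the geometric product of basis blades a and b."""
    idx = list(a) + list(b)
    sign = 1.0
    # Bubble sort the concatenated indices. Swapping two distinct,
    # orthogonal basis vectors flips the sign of the product.
    swapped = True
    while swapped:
        swapped = False
        for k in range(len(idx) - 1):
            if idx[k] > idx[k + 1]:
                idx[k], idx[k + 1] = idx[k + 1], idx[k]
                sign = -sign
                swapped = True
    # Equal indices are now adjacent; each pair e_i e_i contracts to the
    # scalar signature(i).
    out = []
    for i in idx:
        if out and out[-1] == i:
            out.pop()
            sign *= signature(i)
        else:
            out.append(i)
    return sign, tuple(out)


print(blade_product((1,), (2, 3), positive_signature_sketch))
print(blade_product((2,), (1,), positive_signature_sketch))
print(blade_product((0,), (0,), spacetime_signature_sketch))
```

Because the result depends only on the blade tuples and the signature, not on the coefficient values, this kind of bookkeeping can be done once ahead of the numerical work.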
Lower-level function example
import jax.numpy as jnp
from jaxga.signatures import positive_signature
from jaxga.ops.multiply import get_mv_multiply
a_values = 2 * jnp.ones([1], dtype=jnp.float32)
a_indices = ((1,),)
b_values = 4 * jnp.ones([1], dtype=jnp.float32)
b_indices = ((2, 3),)
mv_multiply, c_indices = get_mv_multiply(a_indices, b_indices, positive_signature)
c_values = mv_multiply(a_values, b_values)
print("C indices:", c_indices, "C values:", c_values)
Output: C indices: ((1, 2, 3),) C values: [8.]
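The pattern of a function that takes blades and returns the function doing the calculation can be sketched in plain Python as follows. This is an illustration under stated assumptions, not JAXGA's code: `get_multiply_sketch` and its helpers are hypothetical, Python lists stand in for JAX arrays, and `functools.lru_cache` is only one possible way to get the caching by inputs that the README attributes to get_mv_multiply.

```python
from functools import lru_cache


def positive_signature_sketch(index):
    # Hypothetical stand-in: every basis vector squares to +1.
    return 1.0


def _blade_product(a, b, signature):
    # Sign and resulting blade for the product of two basis blades
    # (sorted tuples of distinct indices, orthogonal basis assumed).
    idx, sign = list(a) + list(b), 1.0
    swapped = True
    while swapped:
        swapped = False
        for k in range(len(idx) - 1):
            if idx[k] > idx[k + 1]:
                idx[k], idx[k + 1] = idx[k + 1], idx[k]
                sign, swapped = -sign, True
    out = []
    for i in idx:
        if out and out[-1] == i:
            out.pop()
            sign *= signature(i)
        else:
            out.append(i)
    return sign, tuple(out)


@lru_cache(maxsize=None)  # tuples and functions are hashable, so they can be cache keys
def get_multiply_sketch(a_blades, b_blades, signature):
    # Step 1: precompute which input pairs feed which output blade, and with what sign.
    products = [
        (i, j) + _blade_product(a, b, signature)
        for i, a in enumerate(a_blades)
        for j, b in enumerate(b_blades)
    ]
    out_blades = tuple(sorted({blade for _, _, _, blade in products},
                              key=lambda blade: (len(blade), blade)))
    position = {blade: k for k, blade in enumerate(out_blades)}
    terms = [(i, j, position[blade], sign) for i, j, sign, blade in products]

    # Step 2: return a function that only does arithmetic on the values.
    def multiply(a_values, b_values):
        out = [0.0] * len(out_blades)
        for i, j, k, sign in terms:
            out[k] += sign * a_values[i] * b_values[j]
        return out

    return multiply, out_blades


multiply, c_blades = get_multiply_sketch(((1,),), ((2, 3),), positive_signature_sketch)
print("C blades:", c_blades, "C values:", multiply([2.0], [4.0]))
```

Splitting the work this way keeps the returned function free of blade bookkeeping, so it contains only multiply-and-add operations on the values.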
get_mv_multiply and similar functions cache their result by their inputs.
JAXGA stores only the non-zero basis blade coefficients. TFGA and Clifford, on the other hand, store all GA elements as full multivectors, including all zeros. As a result, JAXGA performs better than these for high-dimensional algebras.
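A rough way to see the storage difference is to count coefficients. An n-dimensional algebra has 2^n basis blades, so a full multivector holds 2^n coefficients, whereas a vector has n non-zero coefficients and the geometric product of two vectors has at most one scalar plus n(n-1)/2 bivector coefficients. The function names below are hypothetical and the sketch only counts coefficients; it does not measure runtime.

```python
from math import comb


def dense_size(n):
    # A full multivector has one coefficient per basis blade: 2**n of them.
    return 2 ** n


def vector_size(n):
    # A vector only has coefficients on the n grade-1 blades e_1 .. e_n.
    return n


def vector_product_size(n):
    # The product of two vectors has a scalar part plus one coefficient
    # per bivector e_ij with i < j.
    return 1 + comb(n, 2)


for n in (3, 9, 16):
    print(f"n={n}: dense={dense_size(n)}, vector={vector_size(n)}, "
          f"vector*vector={vector_product_size(n)}")
```

The dense count grows exponentially with the dimension, while the sparse counts for this operation grow at most quadratically.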
Below is a benchmark of the geometric product of two vectors with increasing dimensionality from 1 to 9. 100 vectors are multiplied at a time.
[Benchmark plots, one per library: JAXGA (CPU), TFGA (CPU), Clifford]
Below is a benchmark for higher dimensions that TFGA and Clifford could not handle. Note that the X axis isn't sorted naturally.