RobinKa / jaxga

Geometric Algebra package for JAX
MIT License
clifford-algebra geometric-algebra jax machine-learning numerical-computation numpy python

JAXGA - JAX Geometric Algebra



JAXGA is a Geometric Algebra package built on top of JAX. It can handle high-dimensional algebras because it stores only the non-zero basis blade coefficients. It makes use of JAX's just-in-time (JIT) compilation by first precomputing the blade indices and signs of an operation and then JIT-compiling the function that does the actual calculation.
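The precompute-then-JIT idea can be sketched in plain Python. The helpers blade_product and make_multiply below are hypothetical and are not JAXGA's implementation: plain lists stand in for JAX arrays, and blades are assumed to be sorted tuples of distinct indices. All sign and index bookkeeping happens once inside make_multiply; the returned closure only performs multiply-adds, which is the part one would hand to jax.jit.

```python
# Hypothetical sketch of "precompute blade indices and signs, then compute".
# Not JAXGA's internal code; plain Python lists stand in for JAX arrays.


def blade_product(a, b, signature):
    """Return (sign, blade) with e_a * e_b == sign * e_blade.

    a and b are sorted tuples of distinct basis vector indices;
    signature(i) is the square of basis vector e_i (e.g. 1, -1 or 0).
    """
    # Sorting the concatenation a + b takes one swap of two anticommuting
    # basis vectors for every pair (i in a, j in b) with i > j.
    swaps = sum(1 for i in a for j in b if i > j)
    sign = -1 if swaps % 2 else 1
    # Indices present in both blades meet and contract: e_i e_i == signature(i).
    for i in set(a) & set(b):
        sign *= signature(i)
    # The remaining indices (symmetric difference) form the result blade.
    return sign, tuple(sorted(set(a) ^ set(b)))


def make_multiply(a_indices, b_indices, signature):
    """Precompute all index/sign work once and return (closure, output blades)."""
    out_indices = []
    terms = []  # (position in a, position in b, position in output, sign)
    for ia, blade_a in enumerate(a_indices):
        for ib, blade_b in enumerate(b_indices):
            sign, blade = blade_product(blade_a, blade_b, signature)
            if sign == 0:
                continue  # a degenerate basis vector made this term vanish
            if blade not in out_indices:
                out_indices.append(blade)
            terms.append((ia, ib, out_indices.index(blade), sign))

    # Only this coefficient arithmetic would need to be JIT-compiled.
    def multiply(a_values, b_values):
        out = [0.0] * len(out_indices)
        for ia, ib, io, sign in terms:
            out[io] += sign * a_values[ia] * b_values[ib]
        return out

    return multiply, tuple(out_indices)


# (e_1 + 2 e_2) * (3 e_1 + 4 e_2), every basis vector squaring to +1
multiply, c_indices = make_multiply(((1,), (2,)), ((1,), (2,)), lambda i: 1)
print(c_indices, multiply([1.0, 2.0], [3.0, 4.0]))  # ((), (1, 2)) [11.0, -2.0]
```

Because the closure's structure depends only on the blade tuples and the signature, never on the coefficient values, the same closure can be reused for any values attached to those blades.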

Installation

Install using pip: pip install jaxga

Requirements:

Usage

Unlike most other Geometric Algebra packages, JAXGA does not require you to pre-specify an algebra. It can be used either through the MultiVector class or through lower-level functions; the latter are useful, for example, when combining JAXGA with JAX's jit or automatic differentiation.

The MultiVector class provides operator overloading and is constructed with an array of values and their corresponding basis blades. The basis blades are encoded as tuples, for example the multivector 2 e_1 + 4 e_23 would have the values [2, 4] and the basis blade tuple ((1,), (2, 3)).
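To illustrate what this sparse encoding saves, the sketch below expands the (values, indices) pair for 2 e_1 + 4 e_23 into a dense list of all 2^n coefficients. The bitmask layout (bit i-1 set for e_i, with indices starting at 1) and the helper names blade_to_bitmask and to_dense are assumptions made for illustration; they are not part of JAXGA and not necessarily how TFGA or clifford lay out their arrays.

```python
# Hypothetical helpers: compare JAXGA-style sparse storage with a dense layout.


def blade_to_bitmask(blade):
    """Encode a blade tuple such as (2, 3) as a bitmask (bit i-1 set for e_i)."""
    mask = 0
    for i in blade:
        mask |= 1 << (i - 1)  # assumes basis vector indices start at 1
    return mask


def to_dense(values, indices, n):
    """Expand sparse (values, indices) into all 2**n coefficients of an n-d algebra."""
    dense = [0.0] * (2 ** n)
    for value, blade in zip(values, indices):
        dense[blade_to_bitmask(blade)] += value
    return dense


# 2 e_1 + 4 e_23: the sparse form stores two coefficients regardless of n
values = [2.0, 4.0]
indices = ((1,), (2, 3))

dense = to_dense(values, indices, n=3)
print(len(values), len(dense))  # 2 8
print(dense)  # [0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0]
```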

MultiVector example

import jax.numpy as jnp
from jaxga.mv import MultiVector

a = MultiVector(
    values=2 * jnp.ones([1], dtype=jnp.float32),
    indices=((1,),)
)
# Alternative: 2 * MultiVector.e(1)

b = MultiVector(
    values=4 * jnp.ones([1], dtype=jnp.float32),  # one value per basis blade
    indices=((2, 3),)
)
# Alternative: 4 * MultiVector.e(2, 3)

c = a * b
print(c)

Output: Multivector(8.0 e_{1, 2, 3})

The lower-level functions also work with values and blades. They take the blades as input and return a function that performs the actual calculation, together with the blades of the result. The returned function is JITted and can be automatically differentiated with JAX. Some operations, like the geometric product, additionally take a signature function, which maps a basis vector index to the square of that basis vector.
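A signature function fully determines how basis blades square. The README imports positive_signature from jaxga.signatures; the two signature functions below are local stand-ins written only for this sketch (positive_signature is assumed to return 1 for every index, and pga_like_signature is hypothetical). The helper blade_square is also hypothetical and applies the rule e_A e_A = (-1)^(k(k-1)/2) times the product of the squares of the k basis vectors in A.

```python
from math import prod


def positive_signature(index):
    # Local stand-in: every basis vector squares to +1 (Euclidean algebra).
    return 1


def pga_like_signature(index):
    # Hypothetical: e_0 squares to 0 (degenerate), all other vectors to +1.
    return 0 if index == 0 else 1


def blade_square(blade, signature):
    """Square of the basis blade e_{blade} under the given signature.

    Reordering e_A e_A so that equal basis vectors meet takes
    k(k-1)/2 swaps of anticommuting vectors for a blade of k vectors.
    """
    k = len(blade)
    reorder_sign = -1 if (k * (k - 1) // 2) % 2 else 1
    return reorder_sign * prod(signature(i) for i in blade)


for blade in [(1,), (1, 2), (1, 2, 3), (0, 1)]:
    print(blade, blade_square(blade, positive_signature),
          blade_square(blade, pga_like_signature))
```

Swapping the signature function changes the algebra without changing any other code, which matches the README's point that no algebra has to be declared up front.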

Lower-level function example

import jax.numpy as jnp
from jaxga.signatures import positive_signature
from jaxga.ops.multiply import get_mv_multiply

a_values = 2 * jnp.ones([1], dtype=jnp.float32)
a_indices = ((1,),)

b_values = 4 * jnp.ones([1], dtype=jnp.float32)
b_indices = ((2, 3),)

mv_multiply, c_indices = get_mv_multiply(a_indices, b_indices, positive_signature)
c_values = mv_multiply(a_values, b_values)
print("C indices:", c_indices, "C values:", c_values)

Output: C indices: ((1, 2, 3),) C values: [8.]

Some notes

Benchmarks

N-d vector * N-d vector, batch size 100, N=range(1, 10), CPU

JAXGA stores only the non-zero basis blade coefficients. TFGA and Clifford, on the other hand, store every GA element as a full multivector, including all of its zero coefficients. As a result, JAXGA outperforms them in high-dimensional algebras.
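A back-of-the-envelope count shows why the gap grows with N. Assuming a dense library keeps all 2^N coefficients of every operand, as described above, while a sparse representation keeps only the blades that can be non-zero, the geometric product of two N-d vectors has at most one scalar plus C(N, 2) bivector components. The helper names are hypothetical, and this counts coefficients only; it says nothing about actual runtimes.

```python
from math import comb


def dense_size(n):
    # A full multivector in an n-dimensional algebra has one coefficient per blade.
    return 2 ** n


def vector_product_size(n):
    # Non-zero blades in the product of two n-d vectors: 1 scalar + C(n, 2) bivectors.
    return 1 + comb(n, 2)


for n in (1, 5, 9, 21, 46):
    print(f"N={n:2d}  vector: {n} vs {dense_size(n)}  "
          f"product: {vector_product_size(n)} vs {dense_size(n)}")
```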

Below is a benchmark of the geometric product of two vectors with increasing dimensionality from 1 to 9. 100 vectors are multiplied at a time.

[Figure: benchmark results, one panel each for JAXGA (CPU), tfga (CPU) and clifford]

N-d vector * N-d vector, batch size 100, N=range(1, 50, 5), CPU

Below is a benchmark for higher dimensions (N = 1, 6, ..., 46) that TFGA and Clifford could not handle. Note that the x-axis is not sorted in natural order.

[Figure: benchmark results for N = 1, 6, ..., 46]