EelcoHoogendoorn / numga

Geometric algebra in JAX and numpy

basic q: how do i construct a high dim multivector? #7

alok opened 1 month ago

alok commented 1 month ago

I specifically aspire to take the wedge product of two length-50000 vectors. I know that has 2.5e9 entries, but I've got the RAM and want to try it for neural network interpretability.

But even at 50 dimensions, it chokes. I added 5: np.uint64 to bitops.py:21 so it could fit the values.

Part of my issue is just constructing the data in the right format. The closest thing I found is JaxContext, and I haven't figured out its usage. An example would be appreciated.

code: https://gist.github.com/alok/cfbe4fb7d6e6d2deb288a2634b64aad7

EelcoHoogendoorn commented 1 month ago

Hi,

Numga isn't suited to GA dimensions that high. Individual operators are constructed lazily, but all basis blades of the algebra are reasoned about eagerly upon initialization of the algebra, so that would be O(2^50000) bytes right there.
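For a sense of scale, the blade count doubles with every basis vector you add:

# an n-dimensional geometric algebra has 2**n basis blades,
# and numga enumerates all of them up front
for n in (5, 26, 50):
    print(n, 2 ** n)
# 5 32
# 26 67108864
# 50 1125899906842624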

There are other packages that claim to push the 64-dimension boundary; numga isn't nearly that optimized for such use cases. From my own testing, constructing algebras and operators grinds to a halt somewhere short of 32 dimensions, and actually doing computations with any of those algebras stops being fun probably around half that number of dimensions again.

ga = numga.backend.jax.context.JaxContext('x+' * 33) 

That tries to create an algebra with 33 basis vectors all named x, which won't work.
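What the string form wants, as far as I can tell (treat the exact spelling here as my assumption, not documented usage), is distinct names with one sign each:

# three distinct, positively-squaring basis vectors
ga = numga.backend.jax.context.JaxContext('x+y+z+')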

algebra = numga.algebra.Algebra.from_pqr(26, 0, 0)
ga = numga.backend.jax.context.JaxContext(algebra) 

This notation goes up to 26 auto-named basis vectors... not sure if a hard limit crept in there unintentionally.

mv = ga.multivector(values=rand_1000)

If I'm reading you correctly (a thousand random 1-vectors), I guess this should be (for a 26-dim algebra, and given that the JaxContext stores the basis blades as the last axis of the underlying backing array):

mv = ga.multivector.vector(values=rand_1000_26)
bivectors = []
for i in range(33):
    for j in range(i + 1, 33):
        bivectors.append(mv[i].wedge(mv[j]))

No idea what the intention is here, frankly. Note that you can't index into the components of a multivector, any more than you can index into the components of a complex number; indexing is just for slicing subarrays of whole multivectors out of ndarrays of multivectors. If you want to access multivector components, look at the multivector.select / multivector.restrict syntax. But you don't want to access individual components to form a wedge product between them.
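To spell out what I think you are after, here is a minimal sketch in a small algebra, using only the constructors shown above; I'm assuming wedge maps elementwise over the leading batch axis, matching that backing-array layout:

import numpy as np
import numga.algebra
import numga.backend.jax.context

algebra = numga.algebra.Algebra.from_pqr(5, 0, 0)
ga = numga.backend.jax.context.JaxContext(algebra)

# a batch of 1000 random 1-vectors; basis blades live on the last axis
v = ga.multivector.vector(values=np.random.randn(1000, 5))

# wedge two half-batches of whole multivectors against each other,
# giving 500 bivectors; no indexing into individual components needed
b = v[:500].wedge(v[500:])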

EelcoHoogendoorn commented 1 month ago

What I would do as a general rule: start in a low-dimensional algebra (5 is a good number), get things working, and then work your way up to higher-dimensional algebras, and see how much appetite you actually have for it. Figuring out the basics while waiting for 26-dimensional algebras to finish initializing is a miserable way to spend your time, I'd say.

EelcoHoogendoorn commented 1 month ago

Another thought relating to ML: I don't think you'd want to equate neurons/features with multivector components. One might have an NN layer output [batch, 512] complex numbers, or [batch, 512] quaternions, or [batch, 512] fancier multivectors. But I'd start with modest-sized multivectors as each output neuron, and work your way up from there.
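As a rough sketch of that shape convention (the batched constructor call is an assumption carried over from the example above, not a documented pattern):

import jax
import numga.algebra
import numga.backend.jax.context

ga = numga.backend.jax.context.JaxContext(numga.algebra.Algebra.from_pqr(3, 0, 0))

batch, features = 32, 512
# a layer's raw output, reinterpreted as [batch, features] 1-vectors
raw = jax.random.normal(jax.random.PRNGKey(0), (batch, features, 3))
out = ga.multivector.vector(values=raw)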

Here is an example, by the way, of numga integrated into a JAX neural network: https://gist.github.com/EelcoHoogendoorn/6be31f076e1ea4d8d1ce197e0b0b3b63