RobinKa / jaxga

Geometric Algebra package for JAX
MIT License

Memory layout question #2

Open EelcoHoogendoorn opened 2 years ago

EelcoHoogendoorn commented 2 years ago

Thanks for making this library; I wanted to do something similar a few months ago but other things got in the way. Awesome to see that the JAX ecosystem is maturing this fast!

One design question I had was whether to go with a struct-of-arrays or an array-of-structs memory layout. Unless I've misread, your design is the latter; that is, if I vmap over a multivector, the components of the multivector form the last axis, which in JAX is the contiguous stride==1 axis. If you think about how this gets vectorized on a GPU, that wouldn't be ideal: if every thread in a block works on a single element of the vmapped axis, which is the most straightforward parallelization, the threads in a warp are not performing contiguous memory accesses. Hence struct-of-arrays is generally preferred on the GPU. Also, if you dig into the deepmind/alphafold repo, you will see that they use a struct-of-arrays layout for their vector types and the like.
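To illustrate what I mean (made-up shapes, not your API):

import jax
import jax.numpy as jnp

# sketch of the two layouts for 1024 multivectors with 8 components each
aos = jnp.zeros((1024, 8))   # array-of-structs: components on the last, contiguous axis
soa = jnp.zeros((8, 1024))   # struct-of-arrays: the batch axis is contiguous per component

# vmapping a per-multivector function over either layout is just a matter of in_axes
f = lambda mv: mv.sum()          # stand-in for some per-multivector op
jax.vmap(f, in_axes=0)(aos)      # batch on axis 0
jax.vmap(f, in_axes=1)(soa)      # batch on axis 1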

Now this is all terribly premature optimization as far as the actual goals I'm trying to achieve; but I guess I'm trying to form a bit of a deeper understanding of JAX and TPUs on a low level. So with that in mind: was this a deliberate choice, or something that you have given any thought to?

EelcoHoogendoorn commented 2 years ago

Unless I've misread

Taking more than 30 seconds to dig into the source, it seems like I did. Looking at the source of mv_multiply, axis 0 is the mv index axis; so effectively it's a struct-of-arrays.

Either way, my question stands. Is that design decision related to an understanding of how it will map to the GPU/TPU? Or what were your considerations?

RobinKa commented 2 years ago

Hey, I benchmarked both on my GPU and this layout was much faster. I think around 10x for the geometric product, although I don't remember the exact number. (So overall it's not based on theory; I just tried both and looked at which did better.)

EelcoHoogendoorn commented 2 years ago

Right; that makes sense then!

Harping on about the same topic: as currently implemented, at least one batch axis appears mandatory; I can't just create a single vector and act on it, since it fails the internal broadcasting logic. That looks to me like a bug, not a feature, am I right?

Wouldn't it be more JAX-tonic to write the logic entirely batch-agnostic, and leave any batching to whatever vmaps might be applied? Or is there a technical reason this design is preferable?

Also: do you think it would be possible/desirable to have the multivector behave like the alphafold types? Rather than being one contiguous array, make the class more like an attrdict of e_{} arrays? That way the memory is more fragmented, which might seem like a bad thing, but I suspect it isn't in a typical JAX use case, where we rely on the vmapping for exposing parallelism anyway.

I'm worried that the JAX compiler isn't necessarily smart about all the .at[out_index].add-style mutations that are going on in the low-level ops. Are we sure the compiler isn't copying the entire contiguous array every time an instruction like that happens? There must surely be scenarios where you slice, say, a bivector out of a quaternion. In the case of a dict of tensors this is a no-op as far as the TPU/GPU is concerned; but with a dense format a copy is forced, no?

From the perspective of style I'm also a bit worried, although I'm not sure that I should be. I'm rather used to having my batch axes at the front.

RobinKa commented 2 years ago

Harping on about the same topic: as currently implemented, at least one batch axis appears mandatory; I can't just create a single vector and act on it, since it fails the internal broadcasting logic. That looks to me like a bug, not a feature, am I right? Wouldn't it be more JAX-tonic to write the logic entirely batch-agnostic, and leave any batching to whatever vmaps might be applied? Or is there a technical reason this design is preferable?

I haven't tried it just now, but if that's the case then yes, that's a bug; it should be batch-agnostic.

Also: do you think it would be possible/desirable to have the multivector behave like the alphafold types? Rather than being one contiguous array, make the class more like an attrdict of e_{} arrays? That way the memory is more fragmented, which might seem like a bad thing, but I suspect it isn't in a typical JAX use case, where we rely on the vmapping for exposing parallelism anyway.

I assumed doing it as one contiguous array would be more performant, but that would need to be benchmarked. I'm new to JAX so I don't know how it would behave.

I'm worried that the JAX compiler isn't necessarily smart about all the .at[out_index].add-style mutations that are going on in the low-level ops. Are we sure the compiler isn't copying the entire contiguous array every time an instruction like that happens? There must surely be scenarios where you slice, say, a bivector out of a quaternion. In the case of a dict of tensors this is a no-op as far as the TPU/GPU is concerned; but with a dense format a copy is forced, no?

I benchmarked a lot of different approaches already. I do believe that all the .at.add calls get unrolled in the end and don't result in N separate add calls. The only problem with this is that JIT takes forever for a large number of indices. In #1 I have another approach that uses segment_sum, although this is still not as good as it could be in theory.
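Roughly, the two patterns look like this (the index/sign table is made up for illustration and this is not the exact code in either place; the real tables come from the blade indices):

import jax.numpy as jnp
from jax.ops import segment_sum

terms_table = [  # (a component, b component, output component, sign)
    (0, 0, 0, 1.0),
    (1, 2, 3, 1.0),
    (2, 1, 3, -1.0),
]

def product_unrolled(a, b):
    # one .at[...].add per term, unrolled at trace time
    out = jnp.zeros(4)
    for ai, bi, oi, s in terms_table:
        out = out.at[oi].add(s * a[ai] * b[bi])
    return out

def product_segment_sum(a, b):
    # compute all terms in one go, then scatter-add them by output component
    ai, bi, oi, s = map(jnp.array, zip(*terms_table))
    terms = s * a[ai] * b[bi]
    return segment_sum(terms, oi, num_segments=4)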

From the perspective of style I'm also a bit worried, although I'm not sure that I should be. I'm rather used to having my batch axes at the front.

I agree, but the performance difference was so huge that this seems to be the way to go.

EelcoHoogendoorn commented 2 years ago

I haven't tried it just now, but if that's the case then yes, that's a bug; it should be batch-agnostic.

I have to manually fix the broadcasting sometimes to get things to work.

import jax.numpy as jnp
# MultiVector as exported by jaxga

# broadcasts as expected
MultiVector(jnp.ones(1), ((),)) * MultiVector(jnp.ones((4, 1)), ((1,),))
# fails
MultiVector(jnp.ones(1), ((),)) + MultiVector(jnp.ones((4, 1)), ((1,),))

I assumed doing it as one contiguous array would be more performant, but that would need to be benchmarked. I'm new to JAX so I don't know how it would behave.

Same; not speaking from deep experience. Though the alphafold codebase is quite clean, and their docstrings suggest they have given it a lot of thought.

I agree, but the performance difference was so huge that this seems to be the way to go.

Yeah, I agree; though I think an attrdict style would be no less performant; and in that case there wouldn't be an axis collecting all the basis elements, so it couldn't mess with your head broadcasting-wise either.

I suppose it should also help with compilation speed; I don't think there is anything to 'get clever about' if you are simply doing something like output.e_{foo} = sum(foo_terms), whereas being clever about gathering an arbitrary sequence of seemingly-mutable operations into a single immutable operation sounds like it's asking a lot of the compiler.

Note that I don't know if a literal attrdict is the right way to go; it could also be an internal dict from index-tuples to jax arrays. Something that'd play nice with jax_dataclasses/pytrees and the like would be a must, I suppose.
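Something like this is the kind of thing I have in mind (an entirely hypothetical sketch; none of these names are jaxga's actual API):

import jax
import jax.numpy as jnp
from dataclasses import dataclass

# keys are basis-blade index tuples, values are arbitrarily shaped jnp arrays
@jax.tree_util.register_pytree_node_class
@dataclass
class DictMV:
    blades: dict  # e.g. {(): scalar_values, (1, 2): e12_values, ...}

    def tree_flatten(self):
        keys = tuple(sorted(self.blades))
        return tuple(self.blades[k] for k in keys), keys

    @classmethod
    def tree_unflatten(cls, keys, values):
        return cls(dict(zip(keys, values)))

# a product then writes each output blade as a plain sum, with no .at[].add;
# this example is the even-subalgebra (scalar + e12) product in 2D
def even_2d_product(x: DictMV, y: DictMV) -> DictMV:
    a, b = x.blades, y.blades
    return DictMV({
        (): a[()] * b[()] - a[(1, 2)] * b[(1, 2)],
        (1, 2): a[()] * b[(1, 2)] + a[(1, 2)] * b[()],
    })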

EelcoHoogendoorn commented 2 years ago

Having studied the code of jaxga a bit more, and as discussed on Discord: I think registering MultiVector as a pytree would be an objective improvement, and that would allow deprecating all logic related to batching currently happening inside the mv class and its functions. Batch axes can then be vmapped in, either on the left or the right, which is much cleaner. Unless I'm missing something, there is no reason to couple the batching logic to the internals.
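Roughly this (a sketch only; the attribute names are assumed, the constructor is the (values, indices) pair the library already uses):

import jax

def register_multivector(MultiVector):
    # values are the traced leaves; the static basis-blade indices travel as
    # aux data, so vmap only ever sees the values array
    jax.tree_util.register_pytree_node(
        MultiVector,
        lambda mv: ((mv.values,), mv.indices),
        lambda indices, children: MultiVector(children[0], indices),
    )

# with that in place, batching is just vmap, with no batch handling in the class:
# jax.vmap(lambda a, b: a * b)(x, y)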

That is separate from the question of whether switching from a contiguous array of basis elements to a dict of basis elements would be an improvement. I still suspect it would lead to cleaner, faster-compiling, and potentially more performant code; but it's hard to be sure without taking the plunge...

EelcoHoogendoorn commented 2 years ago

Elaborating a bit more: the only places I've noticed where a contiguous array layout would directly benefit are _values_mv_dual, or a scalar multiplication of a multivector; sure, it's nice to dispatch that as a single call, as opposed to a for loop. But in most ops we've got the for loops anyway.

What we should hope for is for JAX to take an expression like values[out_idx] = sum(a[ai] * b[bi] * s for ai, bi, s in ab_idx[out_idx]) and fuse it into a single device kernel call, rather than actually doing repeated writes to global memory. And the compiler should realize that if we are looping over multiple out_idx, these are independent subgraphs, and the kernels can be dispatched in parallel. I think JAX/XLA will already take care of the latter; I doubt it does the former, but I think that's on the roadmap?

In any case... I fear that writing it as chained .at.add calls only makes things harder for the compiler; proving that an arbitrary chain of such calls can be reduced without side effects seems like a very hard problem to me, though perhaps this special case is covered efficiently? Dunno, not talking from a ton of experience here; actually looking at some of the instructions generated / doing some profiling would probably be more useful than me speculating about it... maybe I'll ask on the JAX GitHub what people think is wise in this regard.

EelcoHoogendoorn commented 2 years ago

Frustrating how little public documentation of XLA/JAX compilation seems to exist out there... that or my google-fu is just poor...

I suppose the ideal GPU kernel would be one that reads and writes each input/output only once: have one thread execute the entire mv op sequentially, making sure that we vmap to form a contiguous axis so each thread maps to contiguous memory accesses, and keeping all intermediates in thread-local memory. In the scheme I proposed above, each output component would be a separate kernel, doing fresh reads from global memory for each required input. I suppose there is nothing preventing the compiler from being smart enough to see that this is a multi-input-multi-output function that can be efficiently fused into a single kernel; but I strongly doubt that happens in practice. I suppose the question of array-vs-dict layout is mostly orthogonal to the question of whether the compiler will make that optimisation. Would be cool if JAX had some kind of Numba-like mini-language for writing your own fused kernels that would compile to all the relevant backends... one can dream.

Getting the best performance out of low-compute-intensity operations like this is generally all about memory bandwidth; at least that's the general wisdom on GPU. I'm assuming most of the same holds on TPU, though there isn't much public info about it out there. A GPU can execute many hundreds of multiply-adds for every read from global memory it can pull off.

If it's really all about the memory bandwidth (and I don't really doubt it, tbh), then encoding the products as dense products is probably the way to go, and with that, the array representation. In theory it still shouldn't matter whether we store our vectors in a dict or a contiguous array, but in practice we have a good shot at optimized matrix-multiply libraries just 'doing the right thing' out of the box, I think. If we just throw the whole dense Cayley table at it using a matrix multiply, then vmap over that... if you are working in any reasonable-dimension space where len(indices) < 100, that might just be the fastest option. Even on modern CPUs that logic goes quite far, actually...

EelcoHoogendoorn commented 2 years ago

It's of course not actually a straight-up matrix multiply we are after, but something of the form einsum('i,ij,j', a, signs, b)... and then vmapping that; so how well this is going to work really depends on how clever the JAX einsum people have been; it's a bit of a science in itself to get that performing well in the general case. But worth a try. I'm thinking about a design for a multivector where it's easy to swap out the value-backing store without adding too many lines of code.

EelcoHoogendoorn commented 2 years ago

Wait, I'm being an idiot; I'm missing the Cayley table there; so it'd be something like einsum('i,ijk,j->k', a, C, b), with C a 3D tensor with a single sign value on every row/column... so a very sparse operation indeed. Still... for a +++0 style PGA, and only acting on the input/output indices actually present in each vector, it shouldn't be too bad? Most ops are probably effectively something like a dual-quat-style multiply in 3D PGA...
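In code, what I mean is roughly this (C being a precomputed sign tensor; name made up, generated offline):

import jax
import jax.numpy as jnp

def dense_product(a, b, C):
    # a, b: [n] component vectors; C: [n, n, n] Cayley sign tensor
    return jnp.einsum('i,ijk,j->k', a, C, b)

# batch by vmapping over the multivector operands only
batched_product = jax.vmap(dense_product, in_axes=(0, 0, None))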

EelcoHoogendoorn commented 2 years ago

Just wrote a script to generate the 3-dim tensor C above for a simple quat:

[[[ 1  0  0  0]
  [ 0  1  0  0]
  [ 0  0  1  0]
  [ 0  0  0  1]]

 [[ 0  1  0  0]
  [-1  0  0  0]
  [ 0  0  0  1]
  [ 0  0 -1  0]]

 [[ 0  0  1  0]
  [ 0  0  0 -1]
  [-1  0  0  0]
  [ 0  1  0  0]]

 [[ 0  0  0  1]
  [ 0  0  1  0]
  [ 0 -1  0  0]
  [-1  0  0  0]]]

Looking good; not too terribly sparse.
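(For reference, this particular table can also be written down by hand; this is not the generator script itself, just a cross-check:)

import numpy as np

# rule: e_i * e_j = sign * e_k  ->  C[i, j, k] = sign, with basis order (1, i, j, k)
rules = {
    (0, 0): (0, 1), (0, 1): (1, 1), (0, 2): (2, 1), (0, 3): (3, 1),
    (1, 0): (1, 1), (1, 1): (0, -1), (1, 2): (3, 1), (1, 3): (2, -1),
    (2, 0): (2, 1), (2, 1): (3, -1), (2, 2): (0, -1), (2, 3): (1, 1),
    (3, 0): (3, 1), (3, 1): (2, 1), (3, 2): (1, -1), (3, 3): (0, -1),
}

C = np.zeros((4, 4, 4), dtype=int)
for (i, j), (k, sign) in rules.items():
    C[i, j, k] = sign

print(C)  # reproduces the tensor above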

But upon seeing this I figured I'm reinventing the wheel, because I've seen this before... and sure enough: https://github.com/deepmind/alphafold/blob/be37a41d6f83e4145bd4912cbe8bf6a24af80c29/alphafold/model/quat_affine.py#L58

So the alphafold guys seem to think that such a dense product is the way to go, for simple quats at least. Using my script I can generate these tables for any mv->mv product; but just because you can doesn't mean it's wise. Still, I suspect it will turn out to be surprisingly wise up to fairly high dimensions.

EelcoHoogendoorn commented 2 years ago

    # select the even-grade (motor) sub-table of the full +++0 Cayley tensor
    alg = Algebra((1, 1, 1, 0))
    E = alg.elements_by_grade
    even = E[0] + E[2] + E[4]                    # 1 + 6 + 1 = 8 even-grade elements
    i = [alg.elements.index(e) for e in even]
    print(alg.sparse_cayley[i][:, i][:, :, i])   # shape [8, 8, 8]

This should give the table for the general motor multiplication in 3d PGA. I'll spare you the wall of zeros... but the general pattern is quite simple: at most one nonzero per row, for about 10% nonzero entries overall.

Now I don't know exactly what TPUs are made of, but 8x8x8 mul-adds per 8 memory reads and 8 memory writes is a compute intensity of only 512/16 = 32 ops per memory access. Any GPU would still be positively bored with those numbers. So I think that still argues in favor of this dense approach.

EelcoHoogendoorn commented 2 years ago

Note it's also interesting how alphafold does the reduction; no einsum, just straight-up multiplying and summing: https://github.com/deepmind/alphafold/blob/be37a41d6f83e4145bd4912cbe8bf6a24af80c29/alphafold/model/quat_affine.py#L153 Either they are making a huge mistake there... it certainly would be one in plain numpy; but what's more likely, I think, is that JAX knows perfectly well what to do with expressions like these.
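Schematically, the two ways of writing the same contraction (shapes made up; C is again a [4, 4, 4] Cayley-style sign tensor, a and b are component vectors):

import jax.numpy as jnp

def product_broadcast(a, b, C):
    # broadcast-multiply-and-sum, in the style of the linked alphafold code
    return jnp.sum(C * a[:, None, None] * b[None, :, None], axis=(0, 1))

def product_einsum(a, b, C):
    # the same contraction written as an einsum
    return jnp.einsum('i,ijk,j->k', a, C, b)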

RobinKa commented 2 years ago

In my TensorFlow library https://github.com/RobinKa/tfga I chose the dense approach with the 3-tensor with lots of zeros (also see slides here https://tfgap.warlock.ai/#/6/1 about this approach). It was faster than using sparse matrices or using a lot of indexing on GPU, but I assume it is slower than the approach here with JITting. Also it doesn't scale well for higher dimensions.

EelcoHoogendoorn commented 2 years ago

Ah, cool to see you already looked at this before. It will indeed be interesting to see how the JITting stacks up; I'm pretty sure there will be a break-even point, but where it will be I do not dare say. The fact that the alphafold guys use it for quats suggests the dense approach has some merit in JAX too; ideally there'd be two subclasses of multivector, one with values: Dict[int, jnp.array] and the other with dense backing storage, I think.

I suppose I'm about to be able to do some benchmarking... will let you know if something interesting comes out of it.

EelcoHoogendoorn commented 2 years ago

Some preliminary findings:

Doing an AB~A motor sandwich in +++0 boils down to an einsum equivalent of 'i,j,k,ijkl->l', with a 4D tensor of shape [8,8,8,8] and only about 4% of its entries nonzero. I think that's one of the larger / more sparse operations you are likely to encounter 'in practice'.
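In code, the dense version being benchmarked is roughly this (S being a hypothetical precomputed [8, 8, 8, 8] sign tensor with the reversion signs folded in, generated offline):

import jax
import jax.numpy as jnp

def sandwich_dense(a, b, S):
    # motor sandwich a b ~a over the 8 even-grade components
    return jnp.einsum('i,j,k,ijkl->l', a, b, a, S)

batched = jax.vmap(sandwich_dense, in_axes=(0, 0, None))  # vmap over the batch of motors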

Yet on CPU, vmapped over 1024 elements, the dense implementation is 8x faster than the sparse implementation. CPU does not seem to care much whether you vmap on the left or the right. I don't have the time to test on GPU right now, but I suspect the vmapping does matter there, and that the dense/sparse performance gap is bigger too.

This is all on contiguous (vmapped) arrays, by the way; no lists or dicts of arrays. With numbers like these in favor of dense operations, anything but contiguous arrays is dead in the water, I'd say.

Note that it shouldn't be hard to dynamically switch to sparse execution for operations in higher dimensions, where sparsity would really get out of hand.

Pretty confident in the correctness and optimality of both implementations.

EelcoHoogendoorn commented 2 years ago

Doing a full-blown 16-component +++0 multivector sandwich is 'only' 3x slower in sparse mode. Not sure there is much of a use case for that, though; but we are getting close to break-even, at least on CPU.

EelcoHoogendoorn commented 2 years ago

So yeah; turns out all of the above is nonsense. With fewer bugs in my code, I do get a consistent advantage for the unrolled/sparse approach, at least on CPU. Interestingly, einsum is a lot faster than the broadcast sum (as per the alphafold codebase). Also, I see your point about compilation speed; unrolling the loops is like 10x slower in that respect.