google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Geometric algebra #242

Closed EelcoHoogendoorn closed 1 month ago

EelcoHoogendoorn commented 2 years ago

If you've ever wondered "shouldn't a library like this be using geometric algebra?", but it seemed like quite a bit of work to get that up and running, that should hopefully no longer be an issue: https://github.com/EelcoHoogendoorn/numga

I've basically made this library with many of the same goals as the brax project itself in mind; I don't know whether you will find it useful in any way, but it wouldn't be a coincidence if you did. If nothing else, it's an interesting topic to read up on in the evenings, I can promise that much!

erikfrey commented 2 years ago

This is very cool! I'm not too familiar with GA, but I would be super interested in anything that:

  • reduces our JIT compilation times
  • helps with generalized coordinates; it bugs me that the position q vector mixes quaternions and angles based on the joint type (we are releasing a v2 of brax soon that can use generalized coordinates)

Do you think numga might help with either of those?

EelcoHoogendoorn commented 2 years ago

Based on my experience, I'm pretty sure it'd help with the JIT times. It depends on a lot of particulars, but if you toggle this line you can get a sense of the compile time / runtime impact in a not-too-terribly-synthetic benchmark. I've often seen differences of up to 10x or even more, though the gap seems to have narrowed in more recent JAX releases. I haven't done any meaningful benchmarking on GPU/TPU yet to see what the runtime impact is there; but anything that helps a lot with the responsiveness of my edit-compile-run feedback loop is quite worthwhile to me.

On CPU, sparse execution of products seems pretty much hands down the fastest runtime option; on GPU/TPU it may well be that some more granular control over which products get executed how is desirable to optimize runtime performance (it should be easy to add that to the library). For compilation times I imagine dense products will always reign supreme; unless the JAX team decides to get clever and starts automatically unrolling small static tensors itself. If they ever start doing such a thing, I hope they put that functionality behind a flag...
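To make the sparse-vs-dense distinction concrete, here is a toy sketch (my own illustration, not numga code) of the same bilinear product executed both ways; this is essentially the toggle being described:

```python
import jax.numpy as jnp

def complex_mul_sparse(a, b):
    # 'Sparse' execution: only the nonzero terms are computed, but XLA has to
    # trace every scalar operation individually, which adds up at compile time.
    return jnp.stack([a[0] * b[0] - a[1] * b[1],
                      a[0] * b[1] + a[1] * b[0]])

# Structure tensor M such that (a * b)[i] == sum_jk M[i, j, k] * a[j] * b[k].
M = jnp.array([[[1., 0.], [0., -1.]],
               [[0., 1.], [1., 0.]]])

def complex_mul_dense(a, b):
    # 'Dense' execution: a single einsum, so far fewer operations to trace
    # (faster compiles); whether the multiplications by zero cost anything at
    # runtime depends on the backend and problem sizes.
    return jnp.einsum('ijk,j,k->i', M, a, b)
```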

I don't have hands-on experience with generalized coordinates myself... but my first hunch would be that it might be a great match. For instance, if you need to represent a rotation around a single axis, you can use a single bivector component for that in numga; you can readily exponentiate that into a quaternion (or its complex-number-like lower-subspace analogue), multiply it into your overall state, and have all these algebraic elements interact consistently, without having to pull your hair out over sign conventions or write ad-hoc classes for these things.

But that's just a bird's-eye impression; again, I don't have boots-on-the-ground experience with the pressing issues of generalized coordinates.

EelcoHoogendoorn commented 2 years ago

As an illustration of the above comment about single bivector components: GA rather readily invites writing code like this. The argument to ().exp() is a single-component bivector; the exponential returns something that is essentially just two components of a quaternion, or analogous to a complex number; and the subsequent multiplication with the other quaternion-like object is performed by a lazily generated sort-of-quaternion-multiplication specific to those arguments.

In a more classical setting you'd just plonk the small rotation into a full quaternion or rotation matrix, because there is no way you are going to write and maintain special-case multiplication logic for that by hand.
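As a concrete sketch of the kind of specialized product meant here (my own illustration in plain JAX, not numga code, assuming a (w, x, y, z) quaternion layout): a rotation about a single axis needs only a scalar plus one bivector component, and multiplying it into a full quaternion takes 8 multiplies instead of the 16 of a general quaternion product.

```python
import jax.numpy as jnp

def z_rotor(theta):
    # Rotation about the z axis as a two-component rotor: just the scalar and
    # one bivector component, analogous to a unit complex number.
    return jnp.array([jnp.cos(theta / 2), jnp.sin(theta / 2)])

def z_rotor_mul_quat(r, q):
    # Specialized product of a z-axis rotor (c, s) with a full quaternion
    # (w, x, y, z): exactly the special-case logic one would not bother to
    # write and maintain by hand, but which a code generator can emit for free.
    c, s = r
    w, x, y, z = q
    return jnp.array([c * w - s * z,
                      c * x - s * y,
                      c * y + s * x,
                      c * z + s * w])
```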

Not that the performance argument matters much in this context; but when you are constantly slicing, dicing, and recombining degrees of freedom in your inner loops, as I imagine one does with generalized coordinates, that might mesh together really well.

EelcoHoogendoorn commented 2 years ago

Note that numga in its default implementations tries to avoid all issues such as these, or similar issues such as the one mentioned in the docstring of quat_to_axis_angle.

That is, the default implementation of quat_to_axis_angle (which would be bivector = quat.motor_log() in numga) sidesteps trigonometry altogether and computes logarithms by bisection, through successive square roots of the motor (and similarly for exponentials; in GA you'd view something like quat_rot_axis as the exponential of a bivector). In my experience these simple 'brute force' conversion functions have worked very nicely as default implementations: you get very uniform numerical accuracy and uncompromised differentiability, and it's a tiny amount of branch-free, NaN-handling-free code that works in all dimensions.
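For illustration, here is a minimal trig-free sketch of that bisection idea for ordinary quaternions; this is my own reconstruction of the approach, not numga's actual motor_log implementation.

```python
import jax.numpy as jnp

def quat_sqrt(q):
    # Square root of a unit quaternion (w, x, y, z): average with the identity
    # and renormalize, which halves the rotation angle.
    r = q + jnp.array([1.0, 0.0, 0.0, 0.0])
    return r / jnp.linalg.norm(r)

def quat_log_bisect(q, n=12):
    # Log by bisection: after n square roots the rotation is tiny, so the
    # vector part is (to first order) the bivector log scaled down by 2**n.
    # Assumes w >= 0 (flip the sign of q first otherwise); rotations near
    # 2*pi need more care than this sketch gives them.
    for _ in range(n):
        q = quat_sqrt(q)
    return (2.0 ** n) * q[1:]   # ~ (theta / 2) * axis
```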

That being said, it's not exactly the fastest to compile or run, at least not on CPU in my use cases. Numga allows you to seamlessly override such functionality with optimized implementations tailored to a specific purpose; there is an example here. So in practice you'd probably still want to maintain your own trigonometry-based code; but it'd make sense to put such optimizations behind a flag, so you can do a sanity check and see whether the blowup you are trying to debug has anything to do with trigonometric freak-outs.

erwincoumans commented 2 years ago

Great news @erikfrey!

it bugs me that the position q vector mixes quaternions and angles based on the joint type.

Yes, the state of a prismatic or revolute joint is a single scalar (position for prismatic, angle for revolute), while the state of a spherical joint is represented as a quaternion. The number of degrees of freedom differs between those joint types. You could remove the spherical joint and replace it with a concatenation of 3 revolute joints, but then you may run into singularities.
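As a concrete illustration of that mixed layout (the ordering here is hypothetical, not necessarily brax's actual convention):

```python
import jax.numpy as jnp

# Hypothetical generalized position vector q for a chain with one prismatic,
# one revolute, and one spherical joint: each joint contributes whatever its
# parameterization needs.
q = jnp.concatenate([
    jnp.array([0.10]),                # prismatic: 1 scalar (slide distance)
    jnp.array([0.25]),                # revolute:  1 scalar (angle)
    jnp.array([1.0, 0.0, 0.0, 0.0]),  # spherical: 4 quaternion components
])
# q has 6 entries, while the velocity-level DOF count is 1 + 1 + 3 = 5;
# that mismatch is exactly the bookkeeping being discussed.
```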

@EelcoHoogendoorn do you support 1 DOF (prismatic, revolute) next to 3DOF 3d joint types (spherical, point-to-point) in https://github.com/EelcoHoogendoorn/numga/blob/main/numga/examples/physics/core.py? If not at the moment, how much work would it be (while keeping the code simple)?

The 6D spatial vector algebra used in Featherstone's articulated body algorithm (ABA) is closely related to the bivectors in geometric algebra, though. See for example page 47 (Section 8, Guide to the Literature) of this paper (https://arxiv.org/pdf/1101.4542.pdf):

The modern legacy of this work is varied. Some modern literature, such as
[Fea07], use spatial vectors to model rigid body motion; these are 6D vectors
equivalent to our bivectors, but developed within a linear algebra framework
reminiscent of [vM24].

EelcoHoogendoorn commented 2 years ago

@EelcoHoogendoorn do you support 1 DOF (prismatic, revolute) next to 3DOF 3d joint types (spherical, point-to-point) in https://github.com/EelcoHoogendoorn/numga/blob/main/numga/examples/physics/core.py? If not at the moment, how much work would it be (while keeping the code simple)?

The physics code in this repo is for demonstration purposes, so there is no roadmap to turn it into something more generic. Though in my non-public work I have worked with various generalizations of point-to-point constraints. In my experience (borne out by quick-and-dirty testing), any flat-to-flat constraint is just a single line of code difference: where there is now a line forque = point_l & point_r, you can replace that with forque = plane_l ^ plane_r, or forque = line_l.commutator(line_r); that will give you the least amount of 'forque' that maps one flat onto the other. And similarly there are simple expressions for the shortest line that projects a point onto a flat, etc.

As a general observation: the 'classic vector algebra' operations are a strict subset of GA operations, so anything you can do classically you can implement verbatim in a GA framework, without having to cast your multivectors back to classical vectors or anything like that. Though often there is a more idiomatic/elegant way of doing things.

ViktorM commented 1 year ago
  • but we are releasing a v2 of brax soon that can use generalized coordinates

Hi Erik. When is the release going to happen? Any chance before the end of the year?

EelcoHoogendoorn commented 1 year ago

A small update: the timings I see for compilation and runtime have always fluctuated a lot between code iterations, JAX versions, and whatnot. But now that I'm transitioning from local development to doing 'real' experiments at scale, I find that using 'dense' operations everywhere is still a massive compilation-time benefit (5-6x or so); and somewhat to my surprise it also seems to carry about a 2x runtime benefit.

Now that I think about it, I'm currently testing in 2D environments, where the resulting math operations tend to be denser than in 3D; so there might still be a point to unrolling in 3D, but the compilation differences will be all the more pronounced once I start testing in 3D.

I think the lowest-hanging fruit for testing what difference this makes on the environments and hardware the Brax team cares about would be to rewrite quat_mul / vec_quat_mul in dense product form; these should appear a bunch of times in the core compute graph, so toggling those alone should already have a substantial impact on compile-time and runtime performance, I think.
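To sketch what that could look like (my own illustration, assuming a (w, x, y, z) layout, not brax's existing quat_mul): the Hamilton product can be written as a single contraction against a constant 4x4x4 structure tensor instead of sixteen unrolled multiply-adds.

```python
import numpy as np
import jax.numpy as jnp

# Multiplication table for the quaternion basis (1, i, j, k): entry [j][k] is
# the (component index, sign) of e_j * e_k.
_TABLE = [
    [(0, +1), (1, +1), (2, +1), (3, +1)],
    [(1, +1), (0, -1), (3, +1), (2, -1)],
    [(2, +1), (3, -1), (0, -1), (1, +1)],
    [(3, +1), (2, +1), (1, -1), (0, -1)],
]

# Dense structure tensor M with (a * b)[i] == sum_jk M[i, j, k] a[j] b[k].
_M = np.zeros((4, 4, 4), dtype=np.float32)
for j in range(4):
    for k in range(4):
        i, s = _TABLE[j][k]
        _M[i, j, k] = s
_M = jnp.asarray(_M)

def quat_mul_dense(a, b):
    # One small dense contraction for XLA to trace, instead of sixteen scalar
    # multiply-adds; vec_quat_mul could be handled analogously.
    return jnp.einsum('ijk,j,k->i', _M, a, b)
```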

EelcoHoogendoorn commented 1 year ago

On a tangentially related note: while I am quite pleased with the compute performance I get from my physics simulations in JAX, using a brax-like paradigm, I do not get the impression I am anywhere near the theoretical optimum of my GPU.

I suppose we should expect the most trivial parallelism when every SM/thread processes the update of a single env instance sequentially; in that case we should get a kernel with insane compute intensity, only having to read/write a single env's body states at the end of a timestep. Yet I barely see a performance difference between GPUs with a 20x difference in compute; overall I get the impression I'm ending up with GPU code that is much more memory-bandwidth limited than compute limited. And if our vmapping mapped one env to one thread, then in the compute-limited case we would expect roughly constant performance as long as the number of envs is smaller than the number of SMs; but that's not what the brax paper shows, nor is it my finding: while there is substantial benefit to larger batches, it is hardly as if batch size does not matter below that limit.

What I'd love to see in the future is some kind of @jax.kernel decorator, acting as a compiler hint to keep all the computation within it in a single hardware thread (instead of doing silly things like launching into a different kernel when there is a 4x4 matmul in that code path, for instance). If all the logic of a single env were explicitly kept in a single hardware thread, vmapping over 10k envs should take about as long as running a single env on a modern GPU.

EelcoHoogendoorn commented 1 year ago

Not sure whether memory bandwidth is the limiting factor, or kernel-launch overhead, by the way. One thing I can imagine is that for one reason or another JAX is unable to fuse operations as much as you'd hope, and the limiting factor ends up being a ton of kernels that need to be launched and properly synced. Insofar as that is the case, I suppose renting some older 8-GPU machines on something like vast.ai would be the most cost-effective training option? The original brax paper only looks at multiple TPUs, it seems, and there the scaling appears quite substantially suboptimal; which either points to limits I do not understand, or also clashes with the theory above... definitely a lot of open questions in my mind as to how to optimize the overall approach.