cubed-dev / cubed

Bounded-memory serverless distributed N-dimensional array processing
https://cubed-dev.github.io/cubed/
Apache License 2.0

Optional autodiff support? #518

Open alxmrs opened 1 month ago

alxmrs commented 1 month ago

It would be awesome if the backing array implementation supported automatic differentiation, so that we could access some grad method from Cubed.

It looks like a bunch of stakeholder libraries have this functionality:

https://data-apis.org/array-api/latest/purpose_and_scope.html#stakeholders

Though, differentiable programming may be out of scope for Cubed. @TomNicholas @tomwhite @rbavery any thoughts here?

I have a pipe dream of turning Cubed into an ML framework, and I think this would play an important part.

I haven’t thought through all the implications, but here's a potential sharp edge that @shoyer once pointed out to me: there will probably be significant memory differences between an op graph and its gradient. Can Cubed’s spec model be extended here?

tomwhite commented 1 month ago

Can Cubed’s spec model be extended here?

Cubed has a simple model of upper bounds for memory usage, derived from knowledge about the different operations in the array API. So if there's a way of modelling the memory usage of gradient operations, then this should be possible.

alxmrs commented 1 month ago

So if there's a way of modelling the memory usage of gradient operations, then this should be possible

I’ve been exploring the JAX docs for an answer, and I have two ideas so far.

Option one: we can take advantage of the existing jit (or grad?) mechanics to extract array shape information ahead of time (via tracers).

https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables

I suspect that grad will produce a trace containing array shape and dtype information. This should provide enough to derive memory bounds for Cubed.
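As a rough sketch of option one: `jax.eval_shape` traces a function (including a `grad`-transformed one) without allocating any data, returning only shape/dtype information, which could feed into a memory bound. The `loss` function and the bound calculation below are illustrative, not anything from Cubed or JAX itself.

```python
import math

import jax
import jax.numpy as jnp


def loss(w):
    # Toy scalar loss: sum of squares.
    return jnp.sum(w ** 2)


# Abstract input: shape and dtype only, no array data is allocated.
w_spec = jax.ShapeDtypeStruct((1000, 1000), jnp.float32)

# eval_shape traces grad(loss) abstractly and returns the
# shape/dtype of the gradient output, without running anything.
grad_spec = jax.eval_shape(jax.grad(loss), w_spec)
print(grad_spec.shape, grad_spec.dtype)  # (1000, 1000) float32

# A simple bound on the gradient output's memory footprint,
# derived purely from the traced shape information.
nbytes = math.prod(grad_spec.shape) * grad_spec.dtype.itemsize
print(nbytes)  # 4000000
```

For `grad` of a scalar loss the gradient has the same shape as the input, so this only bounds the output; bounding intermediates would need the full trace (see option two).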

Option two: we could create a memory profiling tracer.

https://jax.readthedocs.io/en/latest/autodidax.html

Tracers are JAX-style visitors. I think if we created a generic memory tracer, we could probably use it on both grad and non-grad JAX programs.

https://github.com/google/jax/blob/694c14bbe6e365f543c7dc67114c8c5e67b5c2df/jax/_src/core.py#L512

Maybe it would be implemented as an AbstractTracer, though grad is concrete. Hmm…

https://jax.readthedocs.io/en/latest/faq.html#different-kinds-of-jax-values https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array
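In the spirit of option two, one could skip writing a custom tracer and instead walk the jaxpr that JAX's existing tracing machinery produces, summing the sizes of all intermediate values. `peak_bytes_upper_bound` below is a hypothetical helper (not a Cubed or JAX API), and summing every intermediate is a deliberately loose upper bound rather than a true peak.

```python
import math

import jax
import jax.numpy as jnp


def peak_bytes_upper_bound(fn, *args):
    """Hypothetical helper: bound memory usage by summing the sizes
    of every intermediate value in fn's jaxpr (a loose upper bound,
    since it ignores buffer reuse)."""
    closed = jax.make_jaxpr(fn)(*args)
    total = 0
    for eqn in closed.jaxpr.eqns:
        for var in eqn.outvars:
            aval = var.aval  # abstract value: shape + dtype, no data
            total += math.prod(aval.shape) * aval.dtype.itemsize
    return total


def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)


w = jnp.zeros((1024,), jnp.float32)
fwd = peak_bytes_upper_bound(loss, w)
bwd = peak_bytes_upper_bound(jax.grad(loss), w)
print(fwd, bwd)
```

The gradient jaxpr contains the forward equations (to compute residuals) plus the transposed ones, so its bound comes out at least as large — which is exactly the forward/gradient memory asymmetry mentioned above.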

shoyer commented 1 month ago

If you want to do this on top of JAX, I think the easiest approach is probably to build a custom interpreter that implements the array API on top of JAXprs: https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html

JAX primitives are typically thin wrappers around a minimal set of XLA operations, so hopefully this would be relatively straightforward. JAXprs are quite explicit about array shapes, so memory usage should be fairly transparent.
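A minimal sketch of the custom-interpreter approach, following the linked notebook: trace a function to a jaxpr, then evaluate each equation by dispatching its primitive to your own array backend. Here NumPy stands in for Cubed's array API, and `numpy_impls`/`interpret` are illustrative names covering only the two primitives this toy function uses.

```python
import numpy as np
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Literal  # newer JAX versions
except ImportError:
    from jax.core import Literal  # older JAX versions

# Map primitive names to backend implementations. NumPy is a
# stand-in here for calls into Cubed's array API.
numpy_impls = {
    "tanh": np.tanh,
    "mul": np.multiply,
}


def interpret(jaxpr, consts, *args):
    """Evaluate a jaxpr equation by equation against numpy_impls."""
    env = {}

    def read(var):
        return var.val if isinstance(var, Literal) else env[var]

    for var, const in zip(jaxpr.constvars, consts):
        env[var] = const
    for var, arg in zip(jaxpr.invars, args):
        env[var] = arg

    for eqn in jaxpr.eqns:
        impl = numpy_impls[eqn.primitive.name]
        out = impl(*[read(v) for v in eqn.invars])
        env[eqn.outvars[0]] = out  # single-output primitives only

    return [read(v) for v in jaxpr.outvars]


def f(x):
    return jnp.tanh(x) * x


x = np.ones(3, np.float32)
closed = jax.make_jaxpr(f)(x)
result, = interpret(closed.jaxpr, closed.consts, x)
print(result)
```

Because jaxprs are explicit about every intermediate's shape and dtype, the same loop could accumulate Cubed's projected-memory bounds as it goes, which is what makes this approach attractive for the spec model.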

tomwhite commented 1 month ago

Thanks @shoyer! That's very useful.