alxmrs opened 1 month ago
Can Cubed’s spec model be extended here?
Cubed has a simple model of upper bounds for memory usage, derived from knowledge about the different operations in the array API. So if there's a way of modelling the memory usage of gradient operations, then this should be possible.
I’ve been exploring the jax docs for an answer and I have two ideas so far.
Option one: we can take advantage of the existing jit (or grad?) mechanics to extract array shape information ahead of time (via tracers).
I suspect that grad will produce a trace that carries array shape and dtype information. That should be enough to derive memory bounds for Cubed.
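As a rough sketch of option one: `jax.eval_shape` evaluates a function abstractly (using tracers, so no real arrays are allocated) and returns the shapes and dtypes of its outputs. Composing it with `jax.grad` gives the gradient's shape ahead of time, from which a memory bound is easy to compute. The function `f` below is just a toy example, not anything from Cubed:

```python
import numpy as np
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar-valued function so jax.grad applies directly.
    return jnp.sum(jnp.tanh(x) ** 2)

# Describe the input abstractly: shape + dtype, no data.
x = jax.ShapeDtypeStruct((1000, 1000), jnp.float32)

# Trace the gradient computation without running it.
grad_shape = jax.eval_shape(jax.grad(f), x)

# The gradient w.r.t. x has x's shape and dtype, so an upper
# bound on its memory footprint is just the array's byte size.
nbytes = int(np.prod(grad_shape.shape)) * grad_shape.dtype.itemsize
print(grad_shape.shape, nbytes)  # (1000, 1000) 4000000
```

This only bounds the gradient output itself, not the intermediates the backward pass materializes, so it's a lower starting point rather than a full memory model.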
Option two: we could create a memory profiling tracer.
https://jax.readthedocs.io/en/latest/autodidax.html
Tracers are JAX-style visitors. I think if we created a generic memory tracer, we could probably use it on both grad and non-grad JAX programs.
https://github.com/google/jax/blob/694c14bbe6e365f543c7dc67114c8c5e67b5c2df/jax/_src/core.py#L512
Maybe it would be implemented as an AbstractTracer, though grad is concrete. Hmm…
https://jax.readthedocs.io/en/latest/faq.html#different-kinds-of-jax-values
https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array
If you want to do this on top of JAX, I think the easiest approach is probably to build a custom interpreter that implements the array API on top of JAXprs: https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
JAX primitives are typically thin wrappers around a minimal set of XLA operations, so hopefully this would be relatively straightforward. JAXprs are quite explicit about array shapes, so memory usage should be fairly transparent.
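A minimal sketch of walking a jaxpr for memory accounting (not a full interpreter): `jax.make_jaxpr` stages the gradient computation out without executing it, and every equation in the resulting jaxpr records the abstract values (shape + dtype) of its outputs. Summing their byte sizes gives a crude, conservative upper bound; a real interpreter would account for buffer liveness and reuse. The function `f` is again a toy stand-in:

```python
import numpy as np
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

x = jax.ShapeDtypeStruct((1000, 1000), jnp.float32)

# Stage out the gradient computation to a jaxpr (no execution).
closed_jaxpr = jax.make_jaxpr(jax.grad(f))(x)

def nbytes(aval):
    # Bytes needed for one abstract array value.
    return int(np.prod(aval.shape)) * aval.dtype.itemsize

# Crude upper bound: total bytes of every intermediate the trace
# produces, as if all buffers were live at once.
peak_bound = sum(
    nbytes(v.aval)
    for eqn in closed_jaxpr.jaxpr.eqns
    for v in eqn.outvars
)
print(peak_bound)
```

This is exactly the kind of shape transparency the jaxpr representation gives you for free; mapping each primitive onto the array-API operations Cubed already models would be the remaining work.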
Thanks @shoyer! That's very useful.
It would be awesome if the backing array implementation supported auto-differentiation, so that we could access some `grad` method from Cubed. It looks like a bunch of stakeholder libraries have this functionality:
https://data-apis.org/array-api/latest/purpose_and_scope.html#stakeholders
Though, differentiable programming may be out of scope for Cubed. @TomNicholas @tomwhite @rbavery any thoughts here?
I have a pipe dream of turning Cubed into an ML framework, and I think this would play an important part.
I haven’t thought of all the implications, but a potential sharp edge that @shoyer once pointed out to me: there will probably be significant memory differences between an op graph and its gradient. Can Cubed’s spec model be extended here?