cubed-dev / cubed

Bounded-memory serverless distributed N-dimensional array processing
https://cubed-dev.github.io/cubed/
Apache License 2.0

Consider adopting a stateless PRNG API #509

Open alxmrs opened 1 month ago

alxmrs commented 1 month ago

While I'm not familiar with the Philox pseudo-random number generator (PRNG) in NumPy (it does look well suited to generation in a distributed setting), I think adopting a stateless PRNG API will be useful long-term for Cubed. In addition to working in a parallel/distributed setting, Cubed also has to consider how it can best perform computation with vectorization and hardware acceleration (#304, #490).
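For illustration, here is roughly how a Philox-style generator can give independent per-chunk streams from a single root seed (a sketch using NumPy's `SeedSequence`, not necessarily what Cubed does internally):

```python
from numpy.random import Generator, Philox, SeedSequence

root = SeedSequence(42)                     # one user-visible seed
chunk_seeds = root.spawn(4)                 # independent child sequences, one per chunk
chunks = [
    Generator(Philox(seed)).random((2, 2))  # each chunk drawn from its own stream
    for seed in chunk_seeds
]
```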

I'm quite persuaded by the design of JAX's PRNG system that statelessness (if not also splittability) is the right approach. I believe it will prove useful in the long term.

https://github.com/google/jax/blob/main/docs/jep/263-prng.md
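The core idea in that design, for reference, is that PRNG state is an explicit, splittable value passed around by the caller rather than hidden global state (a minimal sketch):

```python
import jax

key = jax.random.key(0)                 # PRNG state as an explicit value
key, subkey = jax.random.split(key)     # deterministically derive new keys
x = jax.random.normal(subkey, (3, 3))   # same key + shape always gives the same array
```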

FWIW, I believe any ML framework will have to have special cases for randomness, given the constraints of hardware (GPUs/TPUs).

https://pytorch-dev-podcast.simplecast.com/episodes/random-number-generators

tomwhite commented 1 month ago

> FWIW, I believe any ML framework will have to have special cases for randomness, given the constraints of hardware (GPUs/TPUs).

I agree. That's why random number generation is not a part of the Array API, and almost certainly won't be: https://github.com/data-apis/array-api/issues/431.

For Cubed I think this means that the random number functions are less fixed than the rest of the API, so I'd be open to changing them or adding new ones if we need to. The main use case is for generating test data, so they can be quite simple.

What do you think we need for JAX? Could we write an implementation of cubed.random.random that delegates to JAX (if the backend array API is JAX) - or do we need to have a different signature?
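For concreteness, something along these lines is what I have in mind (purely hypothetical; the function and argument names are illustrative, not part of Cubed's API):

```python
import jax

def random_chunk(chunk_shape, root_seed, block_id):
    """Hypothetical per-chunk generation that delegates to JAX."""
    # Fold each block index into the root key so every chunk gets an
    # independent, reproducible stream with no shared mutable state.
    key = jax.random.key(root_seed)
    for i in block_id:
        key = jax.random.fold_in(key, i)
    return jax.random.uniform(key, chunk_shape)

# e.g. the chunk at block index (0, 1) of a larger array:
chunk = random_chunk((2, 2), root_seed=42, block_id=(0, 1))
```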

alxmrs commented 1 month ago

I'm not totally sure what's needed for JAX. For now (running on a single machine with a single device), random is working well enough, given we can convert the arrays. However, I suspect that when the hardware arrangement changes (e.g. multiple GPUs per machine), things could go wrong.

The Array API link points out something interesting to me: namely, that PyTorch uses the same RNG that you're using here. That leaves me a bit hopeful that the problem I'm anticipating could just work itself out. I'd bet, though, that anytime random is used on JAX arrays it will have to be functionalized, and thus require a different signature.
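To illustrate the signature difference I mean (a sketch, not a proposal for Cubed's API):

```python
# Stateful (NumPy-style): the generator object carries hidden state.
import numpy as np
rng = np.random.default_rng(0)
a = rng.random((2, 2))               # advances rng's internal state

# Functional (JAX-style): the key must be passed, and split, explicitly.
import jax
key = jax.random.key(0)
b = jax.random.uniform(key, (2, 2))  # pure function of key and shape
```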

jakirkham commented 3 weeks ago

If you haven't already, I would encourage reading the proposed SPEC 7: https://github.com/scientific-python/specs/pull/180