google-research / dex-lang

Research language for array processing in the Haskell/ML family
BSD 3-Clause "New" or "Revised" License

`Random` effect #465

Open danieldjohnson opened 3 years ago

danieldjohnson commented 3 years ago

Dex's random number generation currently works much like JAX's: randomness comes from explicit PRNG keys that you have to split by hand. That design is a great fit for a functional language, but manually splitting keys gets tedious.
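For contrast, the manual style looks roughly like this. It is a minimal Python sketch that uses a toy SHA-256-based key instead of Dex's or JAX's actual PRNG. The `split`, `uniform`, and `noisy_sum` helpers are hypothetical and not a real library API. The point is only that every consumer of randomness needs its own key, so the caller has to do the splitting.

```python
import hashlib
import struct

def split(key, n=2):
    """Derive n child keys by hashing (toy stand-in for a real key split)."""
    return [hashlib.sha256(key + struct.pack("<Q", i)).digest() for i in range(n)]

def uniform(key):
    """Map a key deterministically to a float in [0, 1)."""
    (bits,) = struct.unpack("<Q", hashlib.sha256(b"u" + key).digest()[:8])
    return bits / 2**64

def noisy_sum(key, x):
    # Two random draws need two keys, so we split explicitly.
    k1, k2 = split(key)
    return x + uniform(k1) + uniform(k2)

root = hashlib.sha256(b"seed 0").digest()
print(noisy_sum(root, 1.0))
```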

If Dex had a built-in Random effect, we could make this much more ergonomic while still compiling to the same key-splitting implementation. Suppose we exposed:

runRandom : Key -> (() -> {Random} a) -> a
getKey : () -> {Random} Key

Then we could desugar this into an efficient, purely functional key-splitting process:

Note that you can still "duplicate" a key: call getKey, then run the same action under two inner runRandom calls that both use that key. So I think this would be just as powerful as passing keys manually, but much more convenient.

danieldjohnson commented 3 years ago

Open question: are there other kinds of effects that behave like this? In other words, would it be better to implement a general Splittable a effect and then provide a Splittable Key instance?

srush commented 3 years ago

Relevant discussion as well: https://github.com/google-research/dex-lang/issues/401

(Btw, that response mentions a monadic implementation, but we don't have a Monad class in the prelude, and the current type system doesn't support one. This might be a nice target for https://github.com/google-research/dex-lang/issues/460.)

dougalm commented 3 years ago

Yeah, I really hope we can solve this with effects! As I mentioned in that comment on #401, my concern is parallelism. Under a for, we want to use ixkey-based splitting instead of threading the key through as state, as you say. But that means a for loop isn't equivalent to its unrolled version. Maybe that's fine, though?

Sasha, sorry for confusing things by bringing up monads. I wasn't suggesting we actually use them. I was just using that standard interface as a way to describe a possible PRNG API. Haskell's Monad type class gives you a clear set of functions to implement, and laws those functions should satisfy, whenever you define a new instance. The problem with Dex's effects is that they're all compiler built-ins, and there isn't a clear set of steps for defining a new one. (On the plus side, that does give us a lot of freedom.)
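The "standard interface" point can be made concrete with a Python sketch of a PRNG API phrased as a state monad over a toy key. `pure` and `bind` are the functions an instance must provide. The asserts spot-check the three monad laws on one sample key, which is a check on a single input rather than a proof. None of these names come from Dex. The sketch illustrates the interface being described and is not a proposal to add monads.

```python
import hashlib
import struct

def split(key):
    """Toy split into two child keys (stand-in for a real PRNG split)."""
    return [hashlib.sha256(key + struct.pack("<Q", i)).digest() for i in range(2)]

def uniform(key):
    """Map a key deterministically to a float in [0, 1)."""
    (bits,) = struct.unpack("<Q", hashlib.sha256(b"u" + key).digest()[:8])
    return bits / 2**64

# A Rand computation is a function key -> (value, next_key): a State monad.
def pure(x):
    """return/pure: produce x without touching the key."""
    return lambda key: (x, key)

def bind(m, f):
    """>>=: run m, then feed its value and the updated key into f."""
    def run(key):
        x, next_key = m(key)
        return f(x)(next_key)
    return run

def get_key(key):
    """Primitive action: split the state, keep one half, return the other."""
    keep, out = split(key)
    return out, keep

def uniform_m():
    return bind(get_key, lambda k: pure(uniform(k)))

def f(x):
    return bind(uniform_m(), lambda u: pure(x + u))

def g(y):
    return pure(2 * y)

# Spot-check the three monad laws on one sample key.
k0 = hashlib.sha256(b"seed 0").digest()
m = uniform_m()
assert bind(pure(1.0), f)(k0) == f(1.0)(k0)                              # left identity
assert bind(m, pure)(k0) == m(k0)                                         # right identity
assert bind(bind(m, f), g)(k0) == bind(m, lambda x: bind(f(x), g))(k0)   # associativity
```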