I've found an implementation like this to be more convenient for a few reasons:
- It's global, but also thread-safe: a `ContextVar` value set in one thread is not visible to other threads.
- It lets the user set the RNG key in a context manager, like so:
```python
# RNG_KEY has value A here
with xrand.set_key_context(32):
    # Now RNG_KEY has value jrand.PRNGKey(32)
    # Do stuff with the key here, e.g.:
    # Initialize a model!
    module = Linear(32, 32, rng=xrand.next_key())
    # Sample something
    sampled = jrand.uniform(xrand.next_key(), (3, 14))
    call_some_other_function()
# RNG_KEY has value A here again.
```
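Both properties listed above come from Python's `contextvars` module rather than from JAX. A minimal stdlib-only sketch, assuming a plain integer stands in for the PRNG key (the names `CURRENT_SEED`, `seed_context` and `set_in_other_thread` are hypothetical), shows that nested contexts restore the previous value on exit and that a value set in another thread does not leak into the caller:

```python
import threading
from contextlib import contextmanager
from contextvars import ContextVar

# Hypothetical stand-in for RNG_KEY: holds a plain int instead of a JAX PRNG key.
CURRENT_SEED = ContextVar("current_seed", default=None)


@contextmanager
def seed_context(seed):
    # set() returns a token that remembers the value active before this call.
    token = CURRENT_SEED.set(seed)
    try:
        yield
    finally:
        # reset() puts back whatever value was active before set().
        CURRENT_SEED.reset(token)


def set_in_other_thread(seed):
    # The worker thread runs in its own context, so its set() does not
    # change the value seen by the calling thread.
    seen = []

    def worker():
        CURRENT_SEED.set(seed)
        seen.append(CURRENT_SEED.get())

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()
    return seen[0]


log = []
with seed_context(32):
    log.append(CURRENT_SEED.get())        # value set by the outer block
    with seed_context(7):
        log.append(CURRENT_SEED.get())    # value set by the nested block
    log.append(CURRENT_SEED.get())        # outer value restored
    log.append(set_in_other_thread(99))   # the worker sees its own value
    log.append(CURRENT_SEED.get())        # unchanged by the worker
log.append(CURRENT_SEED.get())            # default restored outside all blocks

print(log)  # [32, 7, 32, 99, 32, None]
```

The same mechanism is what would let a function called inside the `with` block, such as `call_some_other_function()` above, call `xrand.next_key()` itself instead of receiving a key as an argument.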
Example Implementation:
```python
from contextlib import contextmanager
from contextvars import ContextVar

import jax.random as jrand

# The "current" key for this context; None until a key has been set.
RNG_KEY = ContextVar("rng_key", default=None)


def ensure_key(rng_key):
    # Accept either an integer seed or an existing PRNG key.
    if isinstance(rng_key, int):
        rng_key = jrand.PRNGKey(rng_key)
    return rng_key


@contextmanager
def set_key_context(rng_key):
    rng_key = ensure_key(rng_key)
    token = RNG_KEY.set(rng_key)
    try:
        yield
    finally:
        # Restore the previous key, even if the body of the with-block raises.
        RNG_KEY.reset(token)


def set_key(new_key):
    RNG_KEY.set(ensure_key(new_key))


def get_key():
    # Returns None if no key has been set in the current context.
    return RNG_KEY.get()


def next_key():
    # Return a single fresh subkey and advance the stored key.
    return split()


def split(num=1):
    # Split the current key, keep one half as the new state, and hand out
    # the other half (or `num` keys derived from it).
    new_key, sub_key = jrand.split(get_key())
    set_key(new_key)
    if num > 1:
        sub_key = jrand.split(sub_key, num)
    return sub_key
```

This is also an awesome idea! I've used context managers but never implemented one myself. Let me learn a bit more about them and get back with a plan to incorporate your suggested improvement.
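One detail worth checking in `set_key_context` is what happens when the body of the `with` block raises. With `@contextmanager`, statements after a bare `yield` are skipped in that case, so the key set by the block would stay active afterwards; wrapping the `yield` in `try`/`finally` avoids that. A small stdlib-only comparison, again assuming a plain integer in place of the PRNG key (`KEY`, `reset_without_finally`, `reset_with_finally` and `value_after_error` are hypothetical names):

```python
from contextlib import contextmanager
from contextvars import ContextVar

# Hypothetical stand-in for RNG_KEY: holds a plain int instead of a JAX PRNG key.
KEY = ContextVar("key", default=None)


@contextmanager
def reset_without_finally(value):
    token = KEY.set(value)
    yield
    KEY.reset(token)  # never reached if the with-body raises


@contextmanager
def reset_with_finally(value):
    token = KEY.set(value)
    try:
        yield
    finally:
        KEY.reset(token)  # runs whether or not the with-body raises


def value_after_error(ctx):
    # Enter the context, raise inside it, and report the key left behind.
    try:
        with ctx(32):
            raise ValueError("boom")
    except ValueError:
        pass
    return KEY.get()


restored = value_after_error(reset_with_finally)    # previous value (None) restored
leaked = value_after_error(reset_without_finally)   # 32 leaks out of the block
print(restored, leaked)  # None 32
```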