zhangxiangxiao / xjax

Simple framework for neural networks using Jax
BSD 3-Clause "New" or "Revised" License

Use contextvar instead of global variable in xrand? #3

Open ryanai3 opened 2 years ago

ryanai3 commented 2 years ago

I've found an implementation like this to be more convenient for a few reasons:

  1. It's global, but also thread safe
  2. It lets the user set the RNG key inside a context manager, like so:
    # RNG_KEY has value A here
    with xrand.set_key_context(32):
        # Now RNG_KEY has value jrand.PRNGKey(32)
        # Do stuff with the key here, e.g.:
        # Initialize a model!
        module = Linear(32, 32, rng=xrand.next_key())
        # Sample something
        sampled = jrand.uniform(xrand.next_key(), (3, 14))
        call_some_other_function()
    # RNG_KEY has value A here.

    Example Implementation:

    
    from contextlib import contextmanager
    from contextvars import ContextVar
    import jax.random as jrand

    # Context-local RNG key; stays None until set_key or set_key_context is called.
    RNG_KEY = ContextVar("rng_key", default=None)

    def ensure_key(rng_key):
        # Accept either an integer seed or an existing PRNG key.
        if isinstance(rng_key, int):
            rng_key = jrand.PRNGKey(rng_key)
        return rng_key

    @contextmanager
    def set_key_context(rng_key):
        rng_key = ensure_key(rng_key)
        token = RNG_KEY.set(rng_key)
        try:
            yield
        finally:
            # Restore the previous key even if the body raises.
            RNG_KEY.reset(token)

    def set_key(new_key):
        RNG_KEY.set(ensure_key(new_key))

    def get_key():
        return RNG_KEY.get()

    def next_key():
        return split()

    def split(num=1):
        # Advance the stored key and return one sub-key (or `num` sub-keys).
        new_key, sub_key = jrand.split(get_key())
        set_key(new_key)
        if num > 1:
            sub_key = jrand.split(sub_key, num)
        return sub_key
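The scoping and isolation behaviour described above can be checked with the standard library alone. This is a minimal sketch, not xjax code: `SEED`, `seed_context`, `set_seed_in_thread`, and `demo` are hypothetical names, and a plain integer stands in for the JAX key so that the example runs without JAX installed.

```python
import threading
from contextlib import contextmanager
from contextvars import ContextVar

# Hypothetical stand-in for RNG_KEY: holds a plain int instead of a JAX key.
SEED = ContextVar("seed", default=0)


@contextmanager
def seed_context(value):
    """Temporarily set SEED, restoring the previous value on exit."""
    token = SEED.set(value)
    try:
        yield
    finally:
        SEED.reset(token)


def set_seed_in_thread(value):
    """Set SEED inside a worker thread and return what that thread then reads."""
    seen = []

    def work():
        SEED.set(value)
        seen.append(SEED.get())

    worker = threading.Thread(target=work)
    worker.start()
    worker.join()
    return seen[0]


def demo():
    results = {}
    with seed_context(32):
        results["outer"] = SEED.get()
        with seed_context(7):
            results["inner"] = SEED.get()
        # Leaving the inner block restores the outer value.
        results["after_inner"] = SEED.get()
        # A value set in another thread does not leak into this one.
        results["thread"] = set_seed_in_thread(5)
        results["after_thread"] = SEED.get()
        # The previous value is restored even when the body raises.
        try:
            with seed_context(99):
                raise RuntimeError("boom")
        except RuntimeError:
            pass
        results["after_error"] = SEED.get()
    results["after_outer"] = SEED.get()
    return results
```

Calling `demo()` returns the values observed at each step. Nested contexts unwind in order, and a `set` performed in a worker thread is not visible in the calling thread, which is the property that makes this pattern safer than a module-level global.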

zhangxiangxiao commented 2 years ago

This is also an awesome idea! I used context managers but never implemented one myself. Let me learn a bit more about it and get back with a plan to incorporate your suggested improvement.