samuela opened 4 years ago
You can already use `random_key()` within `@parametrized`:
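```python
from jax import numpy as np, random
from jaxnet import parametrized, random_key  # random_key() is provided by JAXnet
def Dropout(rate):
    @parametrized
    def dropout(inputs):
        keep_rate = 1 - rate
        # random_key() hands out a fresh PRNG key inside @parametrized
        keep = random.bernoulli(random_key(), keep_rate, inputs.shape)
        return np.where(keep, inputs / keep_rate, 0)
    return dropout
```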
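For comparison, here is a minimal sketch of the same inverted dropout in plain JAX, without JAXnet, where the PRNG key has to be passed in and split by hand. The signature `dropout(key, inputs, rate)` is our own choice for illustration. It also shows the key-reuse pitfall: two calls with the same key produce an identical mask.

```python
import jax
import jax.numpy as jnp


def dropout(key, inputs, rate):
    """Inverted dropout in plain JAX: the PRNG key is an explicit argument."""
    keep_rate = 1.0 - rate
    # Each unit is kept independently with probability keep_rate.
    keep = jax.random.bernoulli(key, keep_rate, inputs.shape)
    # Rescale kept units so the expected activation is unchanged.
    return jnp.where(keep, inputs / keep_rate, 0.0)


key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))

# Pitfall: reusing the same key yields exactly the same dropout mask.
a = dropout(key, x, 0.5)
b = dropout(key, x, 0.5)

# Correct: split the key and give each call its own fresh subkey.
key, k1, k2 = jax.random.split(key, 3)
c = dropout(k1, x, 0.5)
d = dropout(k2, x, 0.5)
```

Forgetting the `jax.random.split` step fails silently: the code still runs, it just stops being random across calls.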
An independent seed transform, as you describe, would make sense; if I find time, I will factor it out.
Neat, I was not aware of that! Yeah, I think having a separate transform would be great.
Handling parameters in JAX can get annoying, but what concerns me even more is handling PRNG keys. JAX has done a lot of great work to build a very strong PRNG system, but unfortunately splitting and managing random keys can be messy and especially error-prone: it's alarmingly easy to accidentally reuse a PRNG key. It would be great to have a system analogous to `@parametrized` and `parameter()`, but for random keys and seeds. I envision an API providing something like `@random` and `rng()`:

And then ~ magic ~ happens, after which point we get a function like: