google/flax
Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
5.78k stars, 610 forks
[nnx] vectorize vmap split counts #3989 (Closed)
cgarciae closed this 2 weeks ago.
cgarciae commented 3 weeks ago:
What does this PR do?
- Vectorizes the split `RngCount` state in `vmap`; this avoids issues that arise when rows perform a variable number of splits (e.g. when using `cond`).
- Removes `RngKeyBackup` and `RngStream.key_backups`.
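Why a per-row count matters can be sketched in plain JAX. This is not the `flax.nnx` implementation: it assumes a hypothetical stream made of a base key plus an integer counter, where each split folds the counter into the key and increments it (`next_key` and `row_fn` are illustrative names). With a vectorized counter, rows that take different `cond` branches simply end up with different counts; with a single shared counter, `vmap` cannot return one unbatched count and raises an error.

```python
import jax
import jax.numpy as jnp

# Hypothetical model of an RNG stream: a base key plus a split counter.
# This is NOT the flax.nnx API; it only illustrates why the counter needs
# its own batch axis under vmap.


def next_key(key, count):
    """Derive a fresh key from (key, count) and advance the counter."""
    return jax.random.fold_in(key, count), count + 1


def row_fn(key, count, x):
    """Draw noise for one row; positive rows split twice, others once."""

    def two_splits(count):
        k1, count = next_key(key, count)
        k2, count = next_key(key, count)
        return jax.random.normal(k1) + jax.random.normal(k2), count

    def one_split(count):
        k1, count = next_key(key, count)
        return jax.random.normal(k1), count

    # Under vmap with a per-row predicate, lax.cond evaluates both branches
    # and selects the result row by row, so the returned count is per-row.
    return jax.lax.cond(x > 0, two_splits, one_split, count)


batch = 4
keys = jax.random.split(jax.random.key(0), batch)
xs = jnp.array([1.0, -1.0, 2.0, -2.0])

# Vectorized counter: one count per row, so divergent split counts are fine.
counts = jnp.zeros((batch,), dtype=jnp.uint32)
ys, new_counts = jax.vmap(row_fn)(keys, counts, xs)
print(new_counts)  # [2 1 2 1]

# Shared counter: asking vmap for a single, unbatched count fails because
# rows disagree on how many splits they performed.
try:
    jax.vmap(row_fn, in_axes=(0, None, 0), out_axes=(0, None))(
        keys, jnp.uint32(0), xs
    )
    shared_count_ok = True
except ValueError:
    shared_count_ok = False
print(shared_count_ok)  # False
```

Folding a counter into a fixed key keeps each draw a pure function of `(key, count)`, so giving the counter a batch axis is enough to keep rows independent; whether NNX uses exactly this scheme internally is an assumption here, not something stated in this PR.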