justinjfu opened 3 months ago
This is not an urgent issue to fix, because fixing it will change the bits produced by the PRNG and therefore break reproducibility. Just marking this as a fix we should bundle whenever we make a large breaking change to `unsafe_rbg` that would modify the bits anyway.
Description
The `unsafe_rbg` `fold_in` function is not sensitive to the order in which data is folded in. The underlying cause is that `unsafe_rbg` derives its key as `key ^ rbg(data)`, and XOR is both commutative and associative. So if a user folds in two values, the order does not change the result, since `key ^ rbg(1) ^ rbg(2) == key ^ rbg(2) ^ rbg(1)`.
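To make the mechanism concrete, here is a minimal pure-Python model. It is not JAX's implementation: it assumes 64-bit integer keys, and `toy_bits`, `fold_in_xor`, and `fold_in_hash` are hypothetical helpers. SHA-256 stands in for `rbg(data)` in the XOR scheme and, in the contrasting scheme, for a keyed function (roughly how threefry treats the key) where the current key is an input to the hash rather than being XORed in afterwards.

```python
import hashlib


def toy_bits(data: int) -> int:
    """Hypothetical stand-in for rbg(data): deterministic pseudo-random
    64-bit value derived from `data` alone (the key plays no part)."""
    digest = hashlib.sha256(data.to_bytes(8, "little")).digest()
    return int.from_bytes(digest[:8], "little")


def fold_in_xor(key: int, data: int) -> int:
    """Models the scheme described above: new_key = key ^ rbg(data)."""
    return key ^ toy_bits(data)


def fold_in_hash(key: int, data: int) -> int:
    """Order-sensitive alternative: the key is hashed together with the
    data, so each step depends on everything folded in before it."""
    payload = key.to_bytes(8, "little") + data.to_bytes(8, "little")
    digest = hashlib.sha256(payload).digest()
    return int.from_bytes(digest[:8], "little")


key = 0x1234_5678_9ABC_DEF0

# XOR scheme: fold in 1 then 2, versus 2 then 1.
xor_12 = fold_in_xor(fold_in_xor(key, 1), 2)
xor_21 = fold_in_xor(fold_in_xor(key, 2), 1)
print(xor_12 == xor_21)  # True: the two XOR terms can be reordered freely

# Keyed-hash scheme: same two fold-ins in both orders.
hash_12 = fold_in_hash(fold_in_hash(key, 1), 2)
hash_21 = fold_in_hash(fold_in_hash(key, 2), 1)
print(hash_12 == hash_21)  # False, barring a 64-bit collision
```

In the XOR model each `toy_bits(data)` term is computed without looking at the current key, so the terms commute; in the keyed-hash model each step's output depends on the key it received, so the fold-in order is baked into the result.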
Reproducing example:
Result:
Notice that in the `unsafe_rbg` case the two derived keys are the same. Note that `rbg` uses the key derivation logic from threefry, which does not have this issue.

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.29
jaxlib: 0.4.29
numpy: 1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (fc57f88f007497a4ead0ec8607ac66e1847b02d6)]
jax.devices (1 total, 1 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='8f069381872f0396-766a8253c8.borgtask.google.com', release='5.10.0-smp-1101.34.0.0', version='#1 [v5.10.0-1101.34.0.0] SMP @1712273364', machine='x86_64')