google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`unsafe_rbg` `fold_in` function is insensitive to the order in which data is folded in. #21405

Open justinjfu opened 3 months ago

justinjfu commented 3 months ago

Description

The `unsafe_rbg` `fold_in` function is not sensitive to the order in which data is folded in.

The underlying cause is that `unsafe_rbg` derives the new key as `key ^ rbg(data)`, and XOR is commutative and associative. If a user folds in two values, the order therefore does not change the result, since `key ^ rbg(1) ^ rbg(2) == key ^ rbg(2) ^ rbg(1)`.
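To make the XOR argument concrete, here is a minimal pure-Python sketch. It is a toy model, not JAX's implementation: `toy_bits` is a hypothetical stand-in for `rbg(data)`, and `chained_fold_in` is just one illustrative way to make derivation depend on order (it is not how `rbg` or threefry derive keys).

```python
import hashlib


def toy_bits(data: int) -> int:
    """Hypothetical stand-in for rbg(data): 64 pseudo-random bits from data."""
    digest = hashlib.sha256(data.to_bytes(8, "little")).digest()
    return int.from_bytes(digest[:8], "little")


def xor_fold_in(key: int, data: int) -> int:
    # Same structure as described above: new_key = key ^ rbg(data).
    return key ^ toy_bits(data)


def chained_fold_in(key: int, data: int) -> int:
    # Illustrative order-sensitive alternative: the current key is itself
    # an input to the mixing function, so earlier fold-ins affect later ones.
    payload = key.to_bytes(8, "little") + data.to_bytes(8, "little")
    return int.from_bytes(hashlib.sha256(payload).digest()[:8], "little")


key = 42

xor_12 = xor_fold_in(xor_fold_in(key, 1), 2)
xor_21 = xor_fold_in(xor_fold_in(key, 2), 1)

chained_12 = chained_fold_in(chained_fold_in(key, 1), 2)
chained_21 = chained_fold_in(chained_fold_in(key, 2), 1)

print("XOR fold-in order-insensitive:", xor_12 == xor_21)
print("Chained fold-in order-insensitive:", chained_12 == chained_21)
```

In the XOR variant the two results are always equal, whatever `toy_bits` returns, because the fold-in data never interacts with the key beyond the final XOR. The chained variant mixes the key into each step, so the two orders are expected to produce different keys.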

Reproducing example:

```python
import jax

for impl in ['rbg', 'unsafe_rbg']:
  print('impl:', impl)
  key = jax.random.key(42, impl=impl)
  print('Original key:', jax.random.key_data(key))

  # Fold in 1, then 2.
  key_1 = jax.random.fold_in(key, 1)
  key_12 = jax.random.fold_in(key_1, 2)
  print('Foldin 1 -> 2:', jax.random.key_data(key_12))

  # Fold in 2, then 1.
  key_2 = jax.random.fold_in(key, 2)
  key_21 = jax.random.fold_in(key_2, 1)
  print('Foldin 2 -> 1:', jax.random.key_data(key_21))
```

Result:

```
impl: rbg
Original key: [ 0 42  0 42]
Foldin 1 -> 2: [257214496 567757975 257214496 567757975]
Foldin 2 -> 1: [2853785955  313133857 2853785955  313133857]
impl: unsafe_rbg
Original key: [ 0 42  0 42]
Foldin 1 -> 2: [2393909057 2743418786 1382566513 2711092147]
Foldin 2 -> 1: [2393909057 2743418786 1382566513 2711092147]
```

Notice that in the `unsafe_rbg` case the two derived keys are identical. `rbg`, by contrast, uses the key derivation logic from threefry, which does not have this issue.

System info (python version, jaxlib version, accelerator, etc.)

```
jax: 0.4.29
jaxlib: 0.4.29
numpy: 1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (fc57f88f007497a4ead0ec8607ac66e1847b02d6)]
jax.devices (1 total, 1 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='8f069381872f0396-766a8253c8.borgtask.google.com', release='5.10.0-smp-1101.34.0.0', version='#1 [v5.10.0-1101.34.0.0] SMP @1712273364', machine='x86_64')
```

justinjfu commented 3 months ago

This is not urgent to fix, because any fix will change the bits the PRNG produces and so break reproducibility. I'm marking it as a fix we should bundle with the next large breaking change to `unsafe_rbg` that would modify the bits anyway.