google / objax

Apache License 2.0

If the user doesn't add the random generator to the VarCollection of jitted code, the random generator always returns the same number #208

Closed: AlexeyKurakin closed this issue 3 years ago

AlexeyKurakin commented 3 years ago

Here is a code example that demonstrates the issue:

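import objax

def function():
    # Draws one sample from objax's default random generator.
    d = objax.random.uniform((1,))
    return d

# The VarCollection is empty, so Jit does not track the generator's key.
function_jit = objax.Jit(function, objax.VarCollection())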
for _ in range(3):
    print(function())
print('-------------')
for _ in range(3):
    print(function_jit())

# [0.10536897]
# [0.2787192]
# [0.6866876]
# -------------
# [0.58595073]
# [0.58595073]
# [0.58595073]
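A plausible explanation, in plain JAX terms: a jitted function is traced once, and any value it reads that is not passed in as an argument is baked into the compiled program as a constant. With an empty VarCollection the generator's key is such a value. The sketch below reproduces the effect with `jax.jit` and a global key; the names `key` and `draw` are illustrative, not part of objax.

```python
import jax

# Hypothetical global PRNG key, standing in for the generator's state.
key = jax.random.PRNGKey(0)

def draw():
    # Reads the global key instead of receiving it as an argument.
    return jax.random.uniform(key, (1,))

draw_jit = jax.jit(draw)
first = draw_jit()  # traced here: the current value of `key` becomes a constant

key = jax.random.PRNGKey(1)  # rebinding the global does not retrace draw_jit
second = draw_jit()  # same value as `first`

# The eager version looks up the global at call time, so it sees the new key.
eager = draw()
```

Here `first` and `second` are identical, while `eager` differs, mirroring the repeated `[0.58595073]` above.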

The correct fix is to pass the variables of DEFAULT_GENERATOR in the VarCollection given to Jit:

import objax

def function():
    d = objax.random.uniform((1,))
    return d

# Pass the default generator's variables so Jit carries its key between calls.
function_jit = objax.Jit(function, objax.random.DEFAULT_GENERATOR.vars())
for _ in range(3):
    print(function_jit())

# [0.58595073]
# [0.25828767]
# [0.9098333]
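Passing the generator's variables lets Jit treat the key as an input and an output of the compiled function. In plain JAX this roughly corresponds to threading the key explicitly, as in the sketch below; `draw` and `samples` are illustrative names, and this is an analogy rather than objax's internal implementation.

```python
import jax

@jax.jit
def draw(key):
    # Split inside the compiled function and return the advanced key as an output.
    key, subkey = jax.random.split(key)
    return key, jax.random.uniform(subkey, (1,))

key = jax.random.PRNGKey(0)
samples = []
for _ in range(3):
    # Feed the updated key back in, so each call sees fresh state.
    key, value = draw(key)
    samples.append(value)
```

Because the key flows in and out of every call, the three samples differ even though `draw` is compiled only once.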

The problem is that it is not obvious to users that the random generator's variables must be passed to Jit. There are several possible fixes:

  1. Improve the documentation. This still relies on users reading the documentation carefully.
  2. Always add the DEFAULT_GENERATOR variables to Jit, Parallel, and Vectorize. This would add code complexity and possibly some (very small) performance overhead, and it would only work for the default random generator.
  3. Improve Function.auto_vars to automatically trace the usage of random generators, and recommend that users always use auto_vars. This is probably the best long-term solution, but it may take time to implement.
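Option 2 can be sketched in plain JAX. Everything here is hypothetical: `Generator`, `DEFAULT`, and `jit_with_default_generator` are illustrative stand-ins, not objax's API. The wrapper always passes the default key into the compiled function as an input and writes the updated key back after each call, so the user never has to list it. The sketch also shows the drawback noted above: only `DEFAULT` is threaded, so a second `Generator` used inside `function` would still behave as in the original bug report.

```python
import jax

class Generator:
    """Hypothetical stateful generator, loosely modeled on the issue's description."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def uniform(self, shape):
        # Advance the stored key and sample with a fresh subkey.
        self.key, subkey = jax.random.split(self.key)
        return jax.random.uniform(subkey, shape)

DEFAULT = Generator(0)  # hypothetical stand-in for DEFAULT_GENERATOR

def jit_with_default_generator(fn):
    """Hypothetical wrapper: always threads DEFAULT.key through jax.jit."""

    def pure(key, *args):
        saved = DEFAULT.key
        DEFAULT.key = key  # let fn read the traced key
        try:
            out = fn(*args)
            new_key = DEFAULT.key  # key after fn advanced it
        finally:
            DEFAULT.key = saved  # never leave a tracer in the global state
        return new_key, out

    compiled = jax.jit(pure)

    def wrapper(*args):
        # Store the concrete updated key returned by the compiled function.
        DEFAULT.key, out = compiled(DEFAULT.key, *args)
        return out

    return wrapper

def function():
    return DEFAULT.uniform((1,))

function_jit = jit_with_default_generator(function)
```

Successive calls to `function_jit()` return different values without any VarCollection, and calling `function()` eagerly afterwards still works because no tracer is left in `DEFAULT.key`.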