Here is a code example that demonstrates the issue:
```python
import objax

def function():
    d = objax.random.uniform((1,))
    return d

function_jit = objax.Jit(function, objax.VarCollection())

for _ in range(3):
    print(function())
print('-------------')
for _ in range(3):
    print(function_jit())

# Output:
# [0.10536897]
# [0.2787192]
# [0.6866876]
# -------------
# [0.58595073]
# [0.58595073]
# [0.58595073]
```
The correct fix is to add the `DEFAULT_GENERATOR` variables to the `VarCollection` passed to `Jit`:
```python
import objax

def function():
    d = objax.random.uniform((1,))
    return d

function_jit = objax.Jit(function, objax.random.DEFAULT_GENERATOR.vars())

for _ in range(3):
    print(function_jit())

# Output:
# [0.58595073]
# [0.25828767]
# [0.9098333]
```
The problem is that it is not obvious to users that the random generator's variables must be passed to `Jit`.
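To see why the jitted version keeps returning the same number, here is a toy, pure-Python model of a trace-once compiler. It is only an analogy, assuming that `Jit` treats any variable outside its `VarCollection` as a constant captured at trace time and drops updates to it. `Generator`, `uniform` and `toy_jit` are hypothetical names, not Objax APIs, and the linear congruential step stands in for JAX's PRNG.

```python
class Generator:
    """Toy stand-in for a random generator whose only state is an integer key."""

    def __init__(self, seed):
        self.key = seed


def uniform(key):
    """Pure step: return (value in [0, 1), next key) via a linear congruential update."""
    next_key = (1103515245 * key + 12345) % 2**31
    return next_key / 2**31, next_key


def toy_jit(gen, track_generator):
    """Hypothetical trace-once compiler.

    Untracked: the key is read once at "trace" time and baked in as a constant;
    the updated key is discarded, so every call replays the same draw.
    Tracked: the live key is passed in and the new key is written back.
    """
    baked_key = gen.key  # captured once, when the function is "compiled"

    def compiled():
        if track_generator:
            value, gen.key = uniform(gen.key)  # read live state, write it back
        else:
            value, _ = uniform(baked_key)  # constant input, update thrown away
        return value

    return compiled


untracked = toy_jit(Generator(seed=42), track_generator=False)
tracked = toy_jit(Generator(seed=42), track_generator=True)
print([untracked() for _ in range(3)])  # same value three times
print([tracked() for _ in range(3)])    # three different values
```

As in the Objax output above, both variants agree on the first draw; they only diverge once the untracked version fails to carry the updated key into the next call.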
There are several possible fixes:

1. Improve the documentation. This still relies on users reading the documentation carefully.
2. Always add the `DEFAULT_GENERATOR` variables in `Jit`, `Parallel` and `Vectorize`. This would add code complexity and possibly some (very small) performance overhead, and it would only work for the default random generator.
3. Improve `Function.auto_vars` to automatically trace usage of random generators, and recommend that users always use `auto_vars`. This is probably the best long-term solution, but it may take time to implement.
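Option 3 might look roughly like the following pure-Python sketch, assuming generators can report when they are used. Instead of inspecting the function's globals, the hypothetical helper `auto_generator_vars` (not Objax's `Function.auto_vars`) runs the function once in a recording mode, collects the variables of every generator it touched, and rolls their state back. `Generator` and the `"NAME.key"` naming are also illustrative; a real implementation would have to hook into JAX tracing rather than make a trial Python call.

```python
_touched = None  # while recording: maps each generator used to its key before first use


class Generator:
    """Toy random generator that reports itself when used during a recording pass."""

    def __init__(self, name, seed):
        self.name = name
        self.key = seed

    def uniform(self):
        if _touched is not None:
            _touched.setdefault(self, self.key)  # remember the pre-call key once
        self.key = (1103515245 * self.key + 12345) % 2**31
        return self.key / 2**31

    def vars(self):
        # Stands in for a VarCollection entry: variable name -> owning generator.
        return {f"{self.name}.key": self}


def auto_generator_vars(fn, *args):
    """Run fn once, collect the vars of every generator it used, then undo the run."""
    global _touched
    _touched = touched = {}
    try:
        fn(*args)
    finally:
        _touched = None
        for gen, saved_key in touched.items():
            gen.key = saved_key  # trial call leaves no side effects
    collection = {}
    for gen in touched:
        collection.update(gen.vars())
    return collection


DEFAULT_GENERATOR = Generator("DEFAULT_GENERATOR", seed=0)
unused = Generator("unused", seed=1)


def function():
    return DEFAULT_GENERATOR.uniform()


print(auto_generator_vars(function))  # only DEFAULT_GENERATOR's key is collected
```

Option 2 corresponds to skipping the recording pass and always extending the collection with `DEFAULT_GENERATOR`'s variables, which is simpler but cannot find user-created generators such as `unused` if a function does use them.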