tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone
https://tensorflow.org
Apache License 2.0

tf.set_random_seed does not reset random op state #9171

Closed msmsajjadi closed 5 years ago

msmsajjadi commented 7 years ago

TF Version: 1.1.0rc1 (installed from nightly: Apr 10, 2017 1:03 AM) (run on CPU, Python 2)

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# First draw: set the graph-level seed, leave the op-level seed at None.
tf.set_random_seed(1)
a = tf.truncated_normal_initializer(seed=None)([1])
print(a.eval())

# Second draw: re-set the same graph-level seed before creating the op.
tf.set_random_seed(1)
b = tf.truncated_normal_initializer(seed=None)([1])
print(b.eval())

Output:

[ 1.05293429]
[-0.4487586]

Expected: The same value, since...

If the graph-level seed is set, but the operation seed is not: The system deterministically picks an operation seed in conjunction with the graph-level seed so that it gets a unique random sequence.

The values are identical in repeated runs of the whole script, but not after resetting the graph-level seed (as in the example above).

(possibly related to https://github.com/tensorflow/tensorflow/issues/9003)

jart commented 7 years ago

The release notes say determinism is only promised for tf.set_random_seed(0). Does using 0 fix things for you? If not, then please tell me more and I'll re-open.

msmsajjadi commented 7 years ago

Thanks, unfortunately it doesn't seem to work with seed=0 either. Tested on the latest nightly from Build #457 (Apr 14, 2017 1:03:00 AM), nightly whl binary downloaded from the repository (Mac, CPU-only, Python 2).

I get the following output for the above code when using seed 0 in both cases:

[ 1.53692079]
[ 1.01875198]

Is there a reason why it should only be deterministic for that specific seed? Edit: It appears commit e9786df5e89f0345b2eb32d688c7be31c5259ba0 only fixed a specific bug for tf.set_random_seed(0), so it seems unrelated; I've tried various seeds and always get different results for a and b.

msmsajjadi commented 7 years ago

@jart I found the relevant part of the code and I'm suspecting this to be intended behavior. I'll open a new ticket concerning the random seed logic.

jart commented 7 years ago

Or we can reopen this one, as promised. Let me find someone to rope into this conversation while you do some extra investigation.

jart commented 7 years ago

@girving Is the code example in this bug relating to tf.set_random_seed(0) nondeterminism WAI? See: https://github.com/tensorflow/tensorflow/commit/e9786df5e89f0345b2eb32d688c7be31c5259ba0

msmsajjadi commented 7 years ago

Alright!

Simpler, more low-level code example:

for i in range(2):
    tf.set_random_seed(1)
    print(tf.random_uniform([1], seed=None).eval())

Output:

[ 0.77878559]
[ 0.54316521]

Judging from the current implementation, this seems to be intended behavior. In random_seed.py, op_seed is set to ops.get_default_graph()._last_id if the op_seed argument is None. That value depends on how many ops have been created so far, so it changes between the two draws (as in the example above). graph_seed stays constant as long as it is not set again via tf.set_random_seed.
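
A simplified sketch of that logic (paraphrased from random_seed.get_seed in TF 1.x, not the exact source):

import tensorflow as tf
from tensorflow.python.framework import ops

DEFAULT_GRAPH_SEED = 87654321  # constant defined in random_seed.py

def get_seed(op_seed):
    # Effective (graph_seed, op_seed) pair used to seed a random op.
    graph_seed = ops.get_default_graph().seed
    if graph_seed is not None:
        if op_seed is None:
            # Default op seed: the graph's current op counter, which keeps
            # growing as ops are added, hence different values per draw.
            op_seed = ops.get_default_graph()._last_id
        return graph_seed, op_seed
    if op_seed is not None:
        return DEFAULT_GRAPH_SEED, op_seed
    return None, None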

The wording in the documentation could be understood either way:

  1. If the graph-level seed is not set, but the operation seed is set: A default graph-level seed and the specified operation seed are used to determine the random sequence.

I think the following behavior would be more natural: calling tf.set_random_seed again should also reset the default op-seed counter, so that default op seeds are counted from the point where the graph-level seed was last set.

Motivation: There doesn't seem to be an elegant way to initialize, say, two neural network models to the same random values within the same graph. Current workarounds are to use update_ops to copy values from one network to the other, or to set op_seeds for all operations, which leads to hard-to-maintain code (the op_seeds need to be deterministic but different for each layer, to avoid getting the same values for all variables).

girving commented 7 years ago

Correct, a seed of None without a graph level seed is intentional nondeterministic behavior. Have you considered setting the seed to a value other than None? If you set two random ops to the same seed, they will produce the same stream of random numbers.
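
For illustration, two ops given the same explicit op-level seed (42 is an arbitrary value here) do produce identical values:

import tensorflow as tf

tf.set_random_seed(1)
a = tf.random_uniform([1], seed=42)  # same (graph seed, op seed) pair...
b = tf.random_uniform([1], seed=42)  # ...so both ops get the same stream

with tf.Session() as sess:
    print(sess.run([a, b]))  # both entries are identical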

msmsajjadi commented 7 years ago

@girving

Correct, a seed of None without a graph level seed is intentional nondeterministic behavior.

But the graph level seed is set in the example. Fixing the op seeds is possible, but then the user needs to produce pseudo-random op_seeds for all parts of the code to avoid identical random values in the same section (see also the motivation part of my reply above).

In other words, the graph seed is currently only useful for identical results over repetitions of the whole code, while the proposed solution would be more flexible and additionally make the creation of identical values much easier.

girving commented 7 years ago

Ah, your issue is poorly named; fixed. The thing you want is still impossible, since the state you're trying to set is not part of the graph: it is part of the session. Resetting the state on the session would be a nice feature (@langmore wants it too).

The other option is to use random ops with custom seed control, which I am about to add as tf.contrib.stateless.stateless_random_uniform, etc.

girving commented 7 years ago

To set the state on the session and reuse the existing random ops, we would need some plumbing to query the OpKernels for their internal state.

girving commented 7 years ago

Oh, I see: yes, the thing you want to do is possible after all; we'd just have to make the default per-op seeds count from the last time tf.set_random_seed is called.
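
A purely hypothetical sketch of that proposal (the helper names and the use of the private _last_id attribute are illustrative only; none of this is an existing API):

import tensorflow as tf

_reset_point = {}  # graph -> op counter value at the last seed reset

def set_random_seed_and_reset(seed):
    # Hypothetical variant of tf.set_random_seed that also resets op seeds.
    g = tf.get_default_graph()
    tf.set_random_seed(seed)
    _reset_point[g] = g._last_id  # remember where the op counter stood

def default_op_seed():
    # Default op seed counted from the last reset instead of from graph creation.
    g = tf.get_default_graph()
    return g._last_id - _reset_point.get(g, 0)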

msmsajjadi commented 7 years ago

That would be great!

girving commented 7 years ago

Well, I didn't volunteer to do it. :)

@michaelisard, @aselle: Do you have thoughts about this? Modifying tf.set_random_seed to reset the per-op seed counter is trivial to do, but it would probably be quite surprising to users when it didn't work to reset states after graph construction is complete.

msmsajjadi commented 7 years ago

What about a dedicated function reset_op_seeds which can only be called during graph construction? Or is it possible to reset those seeds already (manually)?

msmsajjadi commented 7 years ago

It appears @girving's stateless random ops are available in TF 1.2rc0: https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/stateless

This is a useful addition; however, it does not simplify the use case of initializing two neural network models with identical weights, since the weights are generally built via the tf.layers/tf.contrib modules, which call stateful random ops. I might be missing an obvious easy solution for this, but I guess Stack Overflow is the place for those questions.

Any updates @michaelisard @aselle?

michaelisard commented 7 years ago

@asimshankar assigning to you for triage. It seems as if the current way we set random seeds has missing use cases: do you want to look at future redesigns?

asimshankar commented 7 years ago

I think adding a reset_op_seed parameter to tf.set_random_seed could be acceptable. Something like: tf.set_random_seed(seed, reset_op_seed=False) (I'll ask around if anyone has objections for it defaulting to True).

@ekelsen has been looking at determinism in general and might have other ideas.

@girving : RE your comment - I don't think it would be any more or less surprising for the op-seed reset to not take effect after graph construction is complete than it is for tf.set_random_seed to have no effect after graph construction is complete today. Or am I missing something?

hillst commented 6 years ago

Any progress on this? I'm looking for a simple approach to implementing dropout as a Bayesian approximation: https://arxiv.org/pdf/1506.02142.pdf

frthjf commented 6 years ago

I ran into the same issue. However, for my use case I was able to use the stateless randomness:

for i in range(2):
    seed=(1,1)
    print(tf.contrib.stateless.stateless_random_uniform([1], seed).eval())

msmsajjadi commented 6 years ago

@frthjf These new additions are useful, but as mentioned above, they are not being used by the common layer building blocks (yet?), so one would need to pass these manually to all layers (including special initialization schemes).

TYS11 commented 6 years ago

@seaotterman have you found a way to initialize a neural network with the same weights multiple times across different sessions?

msmsajjadi commented 6 years ago

@TYS11 I haven't had the time to look into this any further, unfortunately. I agree with @asimshankar that a reset_op_seed parameter could be a decent solution that doesn't break backwards compatibility. Until then, there's likely no better way than to write a custom function that creates 2 neural networks and then copies over the initialization weights from one to the other.
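
For reference, a minimal sketch of that copy-weights workaround (the layer sizes, placeholder shape, and scope names are just for illustration):

import tensorflow as tf

def build_net(scope, x):
    # Two identically structured networks under separate variable scopes.
    with tf.variable_scope(scope):
        h = tf.layers.dense(x, 8, activation=tf.nn.relu)
        return tf.layers.dense(h, 2)

x = tf.placeholder(tf.float32, [None, 4])
net_a = build_net("net_a", x)
net_b = build_net("net_b", x)

vars_a = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net_a")
vars_b = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net_b")
copy_ops = [vb.assign(va) for va, vb in zip(vars_a, vars_b)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(copy_ops)  # net_b now starts with net_a's initial weights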

wangpengmit commented 5 years ago

We are addressing this problem with a comprehensive revamp in RFC tensorflow/community#38. Please comment there.

crobarcro commented 5 years ago

We are addressing this problem with a comprehensive revamp in RFC tensorflow/community#38. Please comment there.

Just to be clear, does this mean nothing will happen on this issue until v2.0 is released? It really hinders debugging.

wangpengmit commented 5 years ago

We are addressing this problem with a comprehensive revamp in RFC tensorflow/community#38. Please comment there.

Just to be clear, does this mean nothing will happen on this issue until v2.0 is released? It really hinders debugging.

I'm afraid so.

lostmsu commented 5 years ago

What are my options prior to 2.0? Do I need to create a new graph every time I need to reproduce a result with some fixed seed?

wangpengmit commented 5 years ago

What are my options prior to 2.0? Do I need to create a new graph every time I need to reproduce a result with some fixed seed?

Always providing the 'seed' argument to ops that accept it, along with manually setting the global seed via tf.set_random_seed, can mitigate the problem to some degree.
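
For example (the variable shape and seed values are illustrative):

import tensorflow as tf

tf.set_random_seed(0)  # graph-level seed, for run-to-run reproducibility
# An explicit op-level seed pins this particular op's random sequence.
w = tf.get_variable("w", shape=[3, 3],
                    initializer=tf.truncated_normal_initializer(seed=123))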

lostmsu commented 5 years ago

Is there a way to reset all the internal state in the ops to its original value without recreating graph?

wangpengmit commented 5 years ago

Is there a way to reset all the internal state in the ops to its original value without recreating graph?

No, the op's state is stored in a C++ member variable with no access from outside. Exposing it is among the main motivations of the RFC.

wangpengmit commented 5 years ago

I'm closing this since the new RNGs (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/stateful_random_ops.py) are ready.
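
The new generators carry their state explicitly, so a fixed seed reproduces the same values without rebuilding the graph. A minimal sketch assuming TF 2.x (the class lived under tf.random.experimental.Generator at the time and is tf.random.Generator in later releases):

import tensorflow as tf  # TF 2.x

g1 = tf.random.Generator.from_seed(1)
print(g1.normal([1]))

g2 = tf.random.Generator.from_seed(1)  # fresh generator, same seed
print(g2.normal([1]))                  # reproduces the first value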

tensorflow-bot[bot] commented 5 years ago

Are you satisfied with the resolution of your issue?

huajingyun commented 4 years ago

Try using tf.reset_default_graph():

import tensorflow as tf

for i in range(2):
    with tf.Session() as sess:
        tf.set_random_seed(2)
        var = tf.Variable(tf.random_normal([1, 1], 0.0, 0.01))
        init = tf.global_variables_initializer()
        sess.run(init)
        print(var.eval())
    # Wiping the default graph resets the op counter, so the next iteration
    # rebuilds an identical graph and reproduces the same value.
    tf.reset_default_graph()

Output:

[[-0.0142362]]
[[-0.0142362]]