rwth-i6 / returnn_common

Common building blocks for RETURNN configs, such as models, training concepts, etc.

`nn.Random` for multiple ops #148

Closed: albertz closed this issue 2 years ago

albertz commented 2 years ago

In principle, a single nn.Random instance can be used in multiple places. In that case, all uses share the same random state, which is advanced by every op.

On the TF side, every time a new random output is requested, it uses rng_read_and_skip, which is an atomic op (RngReadAndSkip), so this is safe.
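To make the "atomic, so this is safe" part concrete, here is a minimal pure-Python sketch. It is only loosely modeled on the idea behind TF's RngReadAndSkip and is not how TF implements it. `SharedRngState` and `read_and_skip` are hypothetical names. The lock turns "read the counter, then advance it" into a single step, so no two ops ever reuse a counter value, but which op receives which value is up to the scheduler.

```python
import threading


class SharedRngState:
    """Hypothetical stand-in for a shared counter-based RNG state.

    Loosely mimics read-and-skip: read the current counter and advance it
    in one atomic step, so no two ops ever see the same counter value.
    """

    def __init__(self, seed: int) -> None:
        self.seed = seed
        self.counter = 0
        self._lock = threading.Lock()

    def read_and_skip(self, delta: int) -> int:
        # Atomic "read old counter, then advance by delta".
        with self._lock:
            old = self.counter
            self.counter += delta
            return old


state = SharedRngState(seed=42)
results = []


def worker() -> None:
    # Each "op" reserves one counter value from the shared state.
    results.append(state.read_and_skip(1))


threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every op got a distinct counter value (safe), but the assignment of
# values to threads depends on scheduling (order undefined).
print(sorted(results))  # [0, 1, 2, 3, 4, 5, 6, 7]
```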

However, the order in which these ops are executed is undefined.

So the result of this code is non-deterministic:

```python
rnd = nn.Random()
result = rnd.normal(()) - rnd.normal(())
```
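A small simulation of why this is non-deterministic, assuming a counter-based generator where each sample is a pure function of (seed, counter). `SharedRngState` and its seed mixing are hypothetical, not the nn.Random or TF API. Both ops draw from one shared state, so swapping their execution order swaps which counter each one reads, and the sign of the result flips.

```python
import random


class SharedRngState:
    """Hypothetical shared counter-based RNG state (illustration only)."""

    def __init__(self, seed: int) -> None:
        self.seed = seed
        self.counter = 0

    def normal(self) -> float:
        # Read-and-skip: take the current counter, then advance it.
        counter = self.counter
        self.counter += 1
        # The sample is a pure function of (seed, counter).
        return random.Random(self.seed * 1_000_003 + counter).gauss(0.0, 1.0)


def run(a_first: bool) -> float:
    rnd = SharedRngState(seed=42)  # one state shared by both ops
    if a_first:
        a = rnd.normal()
        b = rnd.normal()
    else:
        b = rnd.normal()
        a = rnd.normal()
    return a - b


x = run(a_first=True)
y = run(a_first=False)
print(x == -y)  # True: swapping the execution order flips the sign
```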

Whereas this alone is deterministic:

```python
rnd = nn.Random()
result = rnd.normal(())
```

And this is deterministic as well:

```python
rnd1 = nn.Random()
rnd2 = nn.Random()
result = rnd1.normal(()) - rnd2.normal(())
```
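The same kind of simulation with two independent states (again a hypothetical counter-based `RngState`, with arbitrarily chosen seeds; how nn.Random actually seeds its states is not covered here). Each op only advances its own counter, so the result no longer depends on execution order.

```python
import random


class RngState:
    """Hypothetical counter-based RNG state (illustration only)."""

    def __init__(self, seed: int) -> None:
        self.seed = seed
        self.counter = 0

    def normal(self) -> float:
        # Read-and-skip on this state only.
        counter = self.counter
        self.counter += 1
        return random.Random(self.seed * 1_000_003 + counter).gauss(0.0, 1.0)


def run(a_first: bool) -> float:
    rnd1 = RngState(seed=1)  # separate state for the first op
    rnd2 = RngState(seed=2)  # separate state for the second op
    if a_first:
        a = rnd1.normal()
        b = rnd2.normal()
    else:
        b = rnd2.normal()
        a = rnd1.normal()
    return a - b


print(run(True) == run(False))  # True: each op owns its state, so order does not matter
```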

I assume it is not obvious to the user that the first snippet is non-deterministic.

So, should we do something about this?

In our earlier code, every use of tf.random had its own separate state anyway, so everything was fully deterministic. Introducing nn.Random and allowing a single instance to be used multiple times is what causes this problem.

I think we should avoid non-deterministic code, especially when it is not really needed, and also because we did not have this non-deterministic behavior before.

Disallow using nn.Random multiple times? But this would break the param sharing logic: e.g. a module that uses dropout could then no longer be called multiple times. (Related: #147)

Automatically introduce control dependencies? But this could be tricky in certain cases, e.g. when it's used inside nn.Cond or nn.Loop.

Or nn.Random does not have a single state, but every time it is called it creates a new separate state (thus behaving like tf.random), and this state var is properly assigned as an attribute, like state0, state1, etc.
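A rough sketch of this option, assuming graph-style ops that are created first and executed later. `PerCallRandom` and its seed derivation from a base seed plus call index are hypothetical, not the actual returnn_common code. Each `normal()` call registers a fresh state as `state0`, `state1`, etc., so one instance can still be shared (e.g. by a module that is called multiple times), while every op stays deterministic regardless of execution order.

```python
import random
from typing import Callable


class PerCallRandom:
    """Hypothetical sketch of "a new separate state per call".

    Seeds are derived from a base seed and the call index purely for
    illustration; this is not the returnn_common implementation.
    """

    def __init__(self, seed: int) -> None:
        self.seed = seed
        self._num_states = 0

    def normal(self) -> Callable[[], float]:
        idx = self._num_states
        self._num_states += 1
        state = random.Random(self.seed * 1_000_003 + idx)
        # Register the state as an attribute: state0, state1, ...
        setattr(self, f"state{idx}", state)
        # Return a deferred "op" that draws from its own state when executed.
        return lambda: state.gauss(0.0, 1.0)


def run(a_first: bool) -> float:
    rnd = PerCallRandom(seed=42)
    op_a = rnd.normal()  # creates rnd.state0
    op_b = rnd.normal()  # creates rnd.state1
    if a_first:
        a = op_a()
        b = op_b()
    else:
        b = op_b()
        a = op_a()
    return a - b


print(run(True) == run(False))  # True: same instance, but no shared state

rnd = PerCallRandom(seed=0)
rnd.normal()
rnd.normal()
print(hasattr(rnd, "state0"), hasattr(rnd, "state1"))  # True True
```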

Other options?

albertz commented 2 years ago

@JackTemaki @Atticus1806 @mmz33 @Zettelkasten opinions?

albertz commented 2 years ago

> Or nn.Random does not have a single state, but every time it is called it creates a new separate state (thus behaving like tf.random), and this state var is properly assigned as an attribute, like state0, state1, etc.

I went with this option now.