aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Change `RandomVariable`-in-`Scan` semantics #898

Open brandonwillard opened 2 years ago

brandonwillard commented 2 years ago

I propose that every RandomVariable constructed within the body of a Scan draw a new sample on each iteration of that Scan.

Background

The current behavior of a naive loop with a random variable is demonstrated in the following graph:

import aesara
import aesara.tensor as at

def inner_fn():
    return at.random.normal()

out, updates = aesara.scan(inner_fn, n_steps=10)

fn = aesara.function([], out)
fn()
# array([-0.50043598, -0.50043598, -0.50043598, -0.50043598, -0.50043598,
#        -0.50043598, -0.50043598, -0.50043598, -0.50043598, -0.50043598])

The repeated output is the single sample value of at.random.normal(), which is equivalent to the following, more explicit graph:

X = at.random.normal()

def inner_fn(x):
    return x

out, updates = aesara.scan(inner_fn, non_sequences=[X], n_steps=10)

fn = aesara.function([], out)
fn()
# array([1.51209684, 1.51209684, 1.51209684, 1.51209684, 1.51209684,
#        1.51209684, 1.51209684, 1.51209684, 1.51209684, 1.51209684])

My contention is that the former graph should always produce a distinct sample for each iteration performed by the Scan Op. If one desires a repeated value, as is currently produced, then constructing a graph like the latter is the appropriate approach.

In other words, an expression like at.random.normal() always indicates a distinct sample within the body of a Scan, just as it would in a plain Python loop.
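For comparison, here is a minimal sketch of the plain Python loop analogy. It uses the standard library's `random` module rather than Aesara, and the seed and variable names are chosen only for illustration:

```python
import random

rng = random.Random(2022)  # seeded only so the sketch is reproducible

# Sampling inside the loop body: a new draw on every iteration.
per_iteration = [rng.gauss(0.0, 1.0) for _ in range(10)]

# Sampling once outside the loop and reusing the value: the same
# number ten times, like the explicit non_sequences graph above.
x = rng.gauss(0.0, 1.0)
repeated = [x for _ in range(10)]

print(len(set(per_iteration)), len(set(repeated)))
```

The first list is what most readers would expect from calling a sampler inside a loop body; the second corresponds to what the naive Scan graph currently produces.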

As of now, I don't see a reason to preserve the current semantics, which map the first example above to the second. This issue doesn't arise in any other case, because no other Op depends so opaquely on a state object as RandomVariable does. To preserve the common semantics of random sampling, Scan should therefore handle this case specifically and guarantee those semantics.

One Possible Solution

One quick way to accomplish this is to change the RandomVariable.inplace attribute of all RandomVariables in an inner-graph to True—regardless of whether or not their RandomType instances are shared.

This somewhat breaks the consistency of the RNG object/RandomType instance "evolution", which allows one to provide an RNG to a RandomVariable and get a new RNG corresponding to the updated RNG state after sampling. The input and output RNG are supposed to be distinct—unless the in-place optimization/rewrite is performed, in which case they are the same RNG state.
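The difference between the functional "evolution" and an in-place draw can be sketched in plain Python. This is only an analogy built on the standard library's `random.Random`, not Aesara's RandomType or RandomVariable implementation, and the helper names `draw_functional` and `draw_inplace` are hypothetical:

```python
import copy
import random


def draw_functional(rng):
    """Return (next_rng, sample) and leave the input rng untouched."""
    next_rng = copy.deepcopy(rng)
    sample = next_rng.gauss(0.0, 1.0)
    return next_rng, sample


def draw_inplace(rng):
    """Advance the input rng itself and return it with the sample."""
    sample = rng.gauss(0.0, 1.0)
    return rng, sample


rng_in = random.Random(898)
state_before = rng_in.getstate()

# Functional draw: the output RNG is a distinct object, the input is unchanged.
rng_out, s1 = draw_functional(rng_in)
functional_keeps_input = rng_in.getstate() == state_before

# In-place draw: input and output are the same object, now in a new state.
rng_same, s2 = draw_inplace(rng_in)
inplace_mutates_input = rng_in.getstate() != state_before
```

Both calls return the same sample here because they start from the same state; they differ only in whether the caller's RNG object is modified.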

If we cloned the initial RNG states in a Scan inner-graph and performed the iterations in-place on the cloned states, then returned those, I believe that the standard RNG "evolution" semantics would be preserved. Even so, these input/output RNG states are essentially hidden from the user-level and almost never used explicitly (e.g. chaining input/output RNGs results in convoluted graphs and many cloned RNG states at runtime), so the utility of fully preserving these semantics is itself questionable.
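A rough sketch of the clone-then-in-place idea, again as a plain Python analogy with the standard library's `random` module rather than an actual Scan implementation (the function name `loop_with_cloned_rng` is hypothetical):

```python
import copy
import random


def loop_with_cloned_rng(rng, n_steps):
    """Clone the RNG once, draw in-place on the clone, return the final state."""
    local_rng = copy.deepcopy(rng)  # one copy for the whole loop
    samples = [local_rng.gauss(0.0, 1.0) for _ in range(n_steps)]
    return samples, local_rng


rng0 = random.Random(898)
state0 = rng0.getstate()

samples, rng1 = loop_with_cloned_rng(rng0, n_steps=10)

# From the outside, the loop still looks functional: the caller's RNG is
# untouched, and the evolved state is available as an explicit output.
caller_rng_unchanged = rng0.getstate() == state0
```

Only one copy is made regardless of the number of iterations, and each iteration still produces its own draw.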

Also, Scans do not have a means of returning RNG states, so they've never been able to preserve these semantics faithfully (see #738).

The same thing could be accomplished without in-placing, for example by adding an extra tap that carries the RNG output by one iteration into the next. That has little to no foreseeable utility, except when one wants the RNG state at every iteration as a Scan output. Regardless, that case would require #738 and could easily work alongside the proposed in-place changes (e.g. by copying the RNG state when/if it's an explicit output of the inner-graph).
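To illustrate the alternative of carrying the RNG from one iteration to the next without in-placing, here is a plain Python sketch. It assumes a functional step that copies the RNG before drawing; `step` and `loop_with_rng_carry` are hypothetical names, and the standard library's `random.Random` stands in for an RNG state:

```python
import copy
import random


def step(rng):
    """Functional step: copy the RNG, draw from the copy, return both."""
    next_rng = copy.deepcopy(rng)
    return next_rng, next_rng.gauss(0.0, 1.0)


def loop_with_rng_carry(rng, n_steps):
    """Feed each step's output RNG into the next step, keeping every state."""
    samples, states = [], []
    for _ in range(n_steps):
        rng, sample = step(rng)  # rebinds the local name only
        samples.append(sample)
        states.append(rng)  # one RNG copy per iteration
    return samples, states


rng0 = random.Random(898)
state0 = rng0.getstate()
samples, states = loop_with_rng_carry(rng0, n_steps=5)

# The same draws come out of a single copy that is advanced in-place.
inplace_rng = copy.deepcopy(rng0)
inplace_samples = [inplace_rng.gauss(0.0, 1.0) for _ in range(5)]
```

The samples match the in-place version, so the per-iteration copies only pay off when the intermediate RNG states themselves are wanted as outputs.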


ricardoV94 commented 2 years ago

Is there any reason to make only the RVs in a Scan always in-place, and not all of them?

brandonwillard commented 2 years ago

All of the RandomVariables in a graph are made in-place (when allowed by DestroyHandler and the like) in FAST_RUN mode; otherwise, in-placing would break the functional (i.e. no side effects) expectations of Aesara graphs and, specifically, the RNG state evolutions performed by RandomVariables.

As I mentioned above, I don't think those are real concerns if in-place updates are used on a copy of the input RNG state(s) and only in the body of a Scan loop. If Scan returned an output RNG state (and it probably should), I believe the end result would be consistent enough with RandomVariable and its state evolutions.