aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

`Scan` output is shorter with combination of indexing and unused variables #921

Open ricardoV94 opened 2 years ago

ricardoV94 commented 2 years ago

This block returns only 3 values instead of the expected 4:

import aesara
import aesara.tensor as at

j = at.constant(0, dtype="int64")
k = at.constant([0], dtype="int64")

# Indexing with unused variable
xs, _ = aesara.scan(
    fn=lambda _: k[0],
    outputs_info=[at.zeros((), dtype="int64")],
    n_steps=4,
)

print(xs.eval())  # [0, 0, 0]

Other combinations return the correct output:

# Indexing without unused variable
xs, _ = aesara.scan(
    fn=lambda: k[0],
    n_steps=4,
)
print(xs.eval())  # [0, 0, 0, 0]

# Indexing while using previous variable
xs, _ = aesara.scan(
    fn=lambda x: k[0] + x,
    outputs_info=[at.zeros((), dtype="int64")],
    n_steps=4,
)
print(xs.eval())  # [0, 0, 0, 0]

# No indexing with unused variable
xs, _ = aesara.scan(
    fn=lambda _: j,
    outputs_info=[at.zeros((), dtype="int64")],
    n_steps=4,
)
print(xs.eval())  # [0, 0, 0, 0]
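For reference, here's a plain-Python model of what all four cases should return (our own simplification of scan semantics, not Aesara code; `reference_scan` is a hypothetical helper, and plain lists/ints stand in for `k` and `j`). Each call yields exactly `n_steps` values, and the initial value from `outputs_info` is never part of the result:

```python
from typing import Callable, List, Optional


def reference_scan(fn: Callable, n_steps: int, init: Optional[int] = None) -> List[int]:
    """Plain-Python model of a scan with at most one recurrent output.

    If ``init`` is given, ``fn`` receives the previous output (like passing
    ``outputs_info``); otherwise ``fn`` takes no arguments. Either way exactly
    ``n_steps`` results are returned and the initial value is not included.
    """
    outputs = []
    prev = init
    for _ in range(n_steps):
        cur = fn(prev) if init is not None else fn()
        outputs.append(cur)
        prev = cur
    return outputs


k = [0]  # stand-in for the constant vector k
j = 0    # stand-in for the constant scalar j

print(reference_scan(lambda _: k[0], n_steps=4, init=0))      # [0, 0, 0, 0]
print(reference_scan(lambda: k[0], n_steps=4))                # [0, 0, 0, 0]
print(reference_scan(lambda x: k[0] + x, n_steps=4, init=0))  # [0, 0, 0, 0]
print(reference_scan(lambda _: j, n_steps=4, init=0))         # [0, 0, 0, 0]
```

Only the first case deviates from this model in Aesara, returning three values.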
brandonwillard commented 2 years ago

Here's a quick investigation:

import aesara
import aesara.tensor as at

j = at.constant(0, dtype="int64", name="j")
k = at.constant([0], dtype="int64", name="k")

# Indexing with unused variable
xs, _ = aesara.scan(
    fn=lambda x: k[0],
    outputs_info=[at.zeros((), dtype="int64")],
    n_steps=4,
)

xs_fn = aesara.function([], xs)
xs_fn()
# array([0, 0, 0])

aesara.dprint(xs)
# Subtensor{int64::} [id A] ''
#  |for{cpu,scan_fn} [id B] ''
#  | |TensorConstant{4} [id C]
#  | |IncSubtensor{Set;:int64:} [id D] ''
#  |   |AllocEmpty{dtype='int64'} [id E] ''
#  |   | |Elemwise{add,no_inplace} [id F] ''
#  |   |   |TensorConstant{4} [id C]
#  |   |   |Subtensor{int64} [id G] ''
#  |   |     |Shape [id H] ''
#  |   |     | |Rebroadcast{(0, False)} [id I] ''
#  |   |     |   |InplaceDimShuffle{x} [id J] ''
#  |   |     |     |Alloc [id K] ''
#  |   |     |       |TensorConstant{0} [id L]
#  |   |     |ScalarConstant{0} [id M]
#  |   |Rebroadcast{(0, False)} [id I] ''
#  |   |ScalarFromTensor [id N] ''
#  |     |Subtensor{int64} [id G] ''
#  |ScalarConstant{1} [id O]
#
# Inner graphs:
#
# for{cpu,scan_fn} [id B] ''
#  >Subtensor{int64} [id P] ''
#  > |k{(1,) of 0} [id Q]
#  > |ScalarConstant{0} [id R]

aesara.dprint(xs_fn)
# Alloc [id A] ''   0
#  |TensorConstant{0} [id B]
#  |TensorConstant{3} [id C]

The Scan was rightly optimized out; however, the result is clearly wrong, because the resulting vector has length 3 instead of 4.

For comparison, here's the most similar working example:

# Indexing while using previous variable
xs_2, _ = aesara.scan(
    fn=lambda x: k[0] + x,
    outputs_info=[at.zeros((), dtype="int64")],
    n_steps=4,
)

xs_2_fn = aesara.function([], xs_2)
xs_2_fn()
# array([0, 0, 0, 0])

aesara.dprint(xs_2)
# Subtensor{int64::} [id A] ''
#  |for{cpu,scan_fn} [id B] ''
#  | |TensorConstant{4} [id C]
#  | |IncSubtensor{Set;:int64:} [id D] ''
#  |   |AllocEmpty{dtype='int64'} [id E] ''
#  |   | |Elemwise{add,no_inplace} [id F] ''
#  |   |   |TensorConstant{4} [id C]
#  |   |   |Subtensor{int64} [id G] ''
#  |   |     |Shape [id H] ''
#  |   |     | |Rebroadcast{(0, False)} [id I] ''
#  |   |     |   |InplaceDimShuffle{x} [id J] ''
#  |   |     |     |Alloc [id K] ''
#  |   |     |       |TensorConstant{0} [id L]
#  |   |     |ScalarConstant{0} [id M]
#  |   |Rebroadcast{(0, False)} [id I] ''
#  |   |ScalarFromTensor [id N] ''
#  |     |Subtensor{int64} [id G] ''
#  |ScalarConstant{1} [id O]
#
# Inner graphs:
#
# for{cpu,scan_fn} [id B] ''
#  >Elemwise{add,no_inplace} [id P] ''
#  > |Subtensor{int64} [id Q] ''
#  > | |k{(1,) of 0} [id R]
#  > | |ScalarConstant{0} [id S]
#  > |<TensorType(int64, ())> [id T] -> [id D]

aesara.dprint(xs_2_fn)
# forall_inplace,cpu,scan_fn} [id A] ''   2
#  |TensorConstant{4} [id B]
#  |IncSubtensor{InplaceSet;:int64:} [id C] ''   1
#    |AllocEmpty{dtype='int64'} [id D] ''   0
#    | |TensorConstant{4} [id E]
#    |TensorConstant{(1,) of 0} [id F]
#    |ScalarConstant{1} [id G]
#
# Inner graphs:
#
# forall_inplace,cpu,scan_fn} [id A] ''
#  >DeepCopyOp [id H] ''
#  > |<TensorType(int64, ())> [id I] -> [id C]

In this case, the Scan hasn't been optimized out, but it looks like it should have been, since it's simply looping and copying each value.
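A plain-Python sketch of why that Scan looks removable (a model, not Aesara code; `sit_sot_scan` and `copy_step` are hypothetical names): with the copy-only inner graph (the `DeepCopyOp` above), the loop's result is just the initial value repeated `n_steps` times, which an `Alloc` could produce directly.

```python
def sit_sot_scan(step, init, n_steps):
    """Run ``step`` n_steps times, feeding each output back in as the next state."""
    outputs = []
    state = init
    for _ in range(n_steps):
        state = step(state)
        outputs.append(state)
    return outputs


def copy_step(x):
    # Mirrors the DeepCopyOp inner graph: the new state is a copy of the old one
    return x


init, n_steps = 0, 4
looped = sit_sot_scan(copy_step, init, n_steps)
# The same values without any loop: the initial value repeated n_steps times
closed_form = [init] * n_steps
assert looped == closed_form  # both are [0, 0, 0, 0]
```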

Let's try to find the rewrite that's causing this bug:

with aesara.config.change_flags(optimizer_verbose=True):
    xs_fn = aesara.function([], xs)

# optimizer: rewrite push_out_non_seq_scan replaces for{cpu,scan_fn}.0 of for{cpu,scan_fn}(TensorConstant{4}, IncSubtensor{Set;:int64:}.0) with Alloc.0 of Alloc(Subtensor{int64}.0, TensorConstant{4})
# optimizer: rewrite MergeOptimizer replaces TensorConstant{4} of None with TensorConstant{4} of None
# optimizer: rewrite MergeOptimizer replaces TensorConstant{4} of None with TensorConstant{4} of None
# optimizer: rewrite MergeOptimizer replaces TensorConstant{1} of None with TensorConstant{1} of None
# optimizer: rewrite local_add_canonizer replaces Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, Elemwise{switch,no_inplace}.0) with Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, Elemwise{switch,no_inplace}.0)
# optimizer: rewrite constant_folding replaces Elemwise{lt,no_inplace}.0 of Elemwise{lt,no_inplace}(TensorConstant{1}, TensorConstant{4}) with TensorConstant{True} of None
# optimizer: rewrite constant_folding replaces Elemwise{switch,no_inplace}.0 of Elemwise{switch,no_inplace}(TensorConstant{True}, TensorConstant{1}, TensorConstant{4}) with TensorConstant{1} of None
# optimizer: rewrite constant_folding replaces Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, TensorConstant{1}) with TensorConstant{3} of None
# optimizer: rewrite local_subtensor_of_alloc replaces node Subtensor{int64::}(Alloc.0, ScalarConstant{1}) with [Alloc.0]
# optimizer: rewrite LocalOptGroup(local_useless_fill,local_useless_alloc,local_useless_elemwise,local_remove_useless_assert,local_useless_rebroadcast,local_join_1,local_join_empty,local_join_make_vector,local_useless_switch,local_useless_tile,local_useless_split,local_useless_reshape,local_view_op,local_merge_alloc,local_useless_topk,local_useless_SpecifyShape,local_Shape_of_SpecifyShape,local_Shape_i_of_broadcastable,local_Unique_scalar,local_Unique_Alloc_lift,local_Unique_BroadcastTo_lift,local_Unique_Repeat_lift,local_Unique_second,local_remove_scalar_BroadcastTo,local_useless_elemwise_comparison,local_useless_reduce,local_useless_slice,local_subtensor_of_alloc,local_subtensor_make_vector,local_useless_inc_subtensor,local_useless_inc_subtensor_alloc) replaces Subtensor{int64::}.0 of Subtensor{int64::}(Alloc.0, ScalarConstant{1}) with Alloc.0 of Alloc(Subtensor{}.0, Elemwise{sub,no_inplace}.0)
# optimizer: rewrite MergeOptimizer replaces TensorConstant{1} of None with TensorConstant{1} of None
# optimizer: rewrite local_subtensor_merge replaces Subtensor{}.0 of Subtensor{}(Subtensor{int64}.0) with Subtensor{int64}.0 of Subtensor{int64}(k{(1,) of 0}, ScalarConstant{0})
# optimizer: rewrite local_subtensor_remove_broadcastable_index replaces Subtensor{int64}.0 of Subtensor{int64}(k{(1,) of 0}, ScalarConstant{0}) with InplaceDimShuffle{}.0 of InplaceDimShuffle{}(k{(1,) of 0})
# optimizer: rewrite local_add_canonizer replaces Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, Elemwise{switch,no_inplace}.0) with Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, Elemwise{switch,no_inplace}.0)
# optimizer: rewrite constant_folding replaces Elemwise{lt,no_inplace}.0 of Elemwise{lt,no_inplace}(TensorConstant{1}, TensorConstant{4}) with TensorConstant{True} of None
# optimizer: rewrite constant_folding replaces Elemwise{switch,no_inplace}.0 of Elemwise{switch,no_inplace}(TensorConstant{True}, TensorConstant{1}, TensorConstant{4}) with TensorConstant{1} of None
# optimizer: rewrite constant_folding replaces Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, TensorConstant{1}) with TensorConstant{3} of None
# optimizer: rewrite constant_folding replaces InplaceDimShuffle{}.0 of InplaceDimShuffle{}(k{(1,) of 0}) with TensorConstant{0} of None

As we can see, the rewrite push_out_non_seq_scan is the one removing the Scan Op altogether. It produces something like Alloc(Subtensor{int64}.0, TensorConstant{4}), which appears to have the correct length (i.e. 4).
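As a rough plain-Python model of that replacement (our own simplification; we haven't checked the rewrite's exact matching conditions, and `alloc`/`push_out_invariant` are hypothetical helpers), a step result that doesn't depend on any loop state can be evaluated once and broadcast to `n_steps` entries, which is what `Alloc(Subtensor{int64}.0, TensorConstant{4})` expresses. On its own, the replacement has the expected length:

```python
def alloc(value, length):
    """Stand-in for Alloc: a vector holding ``value`` repeated ``length`` times."""
    return [value] * length


def push_out_invariant(step_value_fn, n_steps):
    """Model of hoisting a loop-invariant step result out of a loop.

    ``step_value_fn`` takes no loop state, so evaluating it once and
    allocating ``n_steps`` copies matches what the loop would compute.
    """
    value = step_value_fn()  # evaluated once, outside the loop
    return alloc(value, n_steps)


k = [0]
looped = [k[0] for _ in range(4)]                 # what the Scan computes per step
pushed_out = push_out_invariant(lambda: k[0], 4)  # models Alloc(k[0], 4)
assert looped == pushed_out and len(pushed_out) == 4
```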

One of the other rewrites seems to be the problem.

brandonwillard commented 2 years ago

No, the problem is somewhere else: push_out_non_seq_scan is replacing a Scan whose output is being sliced like scan_out[1:] with a vector of length four, so the slice produces a vector of length three.
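The length bookkeeping can be checked with plain Python lists standing in for tensors (a model of the shapes only, not of Aesara's implementation). The buffer size follows the `AllocEmpty` of size `4 + 1` (n_steps plus the initial value) in the dprint output earlier in the thread:

```python
n_steps = 4
init = 0        # the outputs_info value
step_value = 0  # k[0]

# The raw Scan output for this recurrent output: the initial value followed
# by one entry per step, i.e. n_steps + 1 entries
scan_buffer = [init] + [step_value] * n_steps
assert len(scan_buffer) == n_steps + 1

# The graph then strips the initial value with out[1:]
correct = scan_buffer[1:]
assert len(correct) == 4

# The rewrite substitutes a length-n_steps vector for the whole buffer,
# but the out[1:] slice is still applied afterwards
replacement = [step_value] * n_steps  # models Alloc(k[0], 4)
buggy = replacement[1:]
assert len(buggy) == 3  # matches the Alloc of length 3 in the compiled graph
```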

It looks like the graph that aesara.scan is constructing is incorrect, because it shouldn't be slicing/Subtensoring anything. There are no taps, and taps are when Subtensors get added (e.g. to return only the computed iteration results and not also the initial tap values).
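To make the taps/slicing relationship concrete, here's a hypothetical helper (`output_slice_offset` is not an Aesara function) that models how many leading initial-value rows a recurrent output's buffer carries, assuming the initial values occupy `max(|tap|)` rows at the front and are sliced off before the output is returned:

```python
def output_slice_offset(taps):
    """Number of leading initial-value rows to drop from a recurrent output buffer.

    Assumes past taps are negative integers (e.g. -1 for x[t-1]) and that the
    buffer starts with as many rows as the deepest tap needs. An output with
    no taps has no initial rows, so nothing is sliced off.
    """
    if not taps:
        return 0
    return abs(min(taps))


assert output_slice_offset([]) == 0        # no taps: no slicing
assert output_slice_offset([-1]) == 1      # single tap: sliced like out[1:]
assert output_slice_offset([-3, -1]) == 3  # multiple taps: sliced like out[3:]
```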

brandonwillard commented 2 years ago

The example output in #945 shows that aesara.scan is interpreting the outputs_info as a single sit-sot (i.e. "single-input, single-output") input, and sit-sot outputs are always sliced like out[1:] when aesara.scan builds its return values.

It seems like we need that input to be seen as a nit-sot (i.e. "no-input, single-output").
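A plain-Python sketch of why the nit-sot interpretation would be safe here (our own model; `run_sit_sot` and `run_nit_sot` are hypothetical): when `fn` never reads its argument, treating the output as nit-sot produces the same `n_steps` values without an initial buffer row that later has to be sliced off.

```python
def run_sit_sot(fn, init, n_steps):
    """Sit-sot model: buffer holds the initial value plus one row per step; return out[1:]."""
    buffer = [init]
    for _ in range(n_steps):
        buffer.append(fn(buffer[-1]))
    return buffer[1:]


def run_nit_sot(fn, n_steps):
    """Nit-sot model: no initial value, one row per step, nothing to slice off."""
    return [fn() for _ in range(n_steps)]


k = [0]
as_sit_sot = run_sit_sot(lambda _: k[0], init=0, n_steps=4)
as_nit_sot = run_nit_sot(lambda: k[0], n_steps=4)
assert as_sit_sot == as_nit_sot == [0, 0, 0, 0]
```

With no initial row in the nit-sot case, there is no `out[1:]` for a later rewrite to interact with.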