ricardoV94 opened this issue 2 years ago
Here's a quick investigation:
import aesara
import aesara.tensor as at
j = at.constant(0, dtype="int64", name="j")
k = at.constant([0], dtype="int64", name="k")
# Indexing with unused variable
xs, _ = aesara.scan(
fn=lambda x: k[0],
outputs_info=[at.zeros((), dtype="int64")],
n_steps=4,
)
xs_fn = aesara.function([], xs)
xs_fn()
# array([0, 0, 0])
aesara.dprint(xs)
# Subtensor{int64::} [id A] ''
# |for{cpu,scan_fn} [id B] ''
# | |TensorConstant{4} [id C]
# | |IncSubtensor{Set;:int64:} [id D] ''
# |   |AllocEmpty{dtype='int64'} [id E] ''
# |   | |Elemwise{add,no_inplace} [id F] ''
# |   |   |TensorConstant{4} [id C]
# |   |   |Subtensor{int64} [id G] ''
# |   |     |Shape [id H] ''
# |   |     | |Rebroadcast{(0, False)} [id I] ''
# |   |     |   |InplaceDimShuffle{x} [id J] ''
# |   |     |     |Alloc [id K] ''
# |   |     |       |TensorConstant{0} [id L]
# |   |     |ScalarConstant{0} [id M]
# |   |Rebroadcast{(0, False)} [id I] ''
# |   |ScalarFromTensor [id N] ''
# |     |Subtensor{int64} [id G] ''
# |ScalarConstant{1} [id O]
#
# Inner graphs:
#
# for{cpu,scan_fn} [id B] ''
# >Subtensor{int64} [id P] ''
# > |k{(1,) of 0} [id Q]
# > |ScalarConstant{0} [id R]
aesara.dprint(xs_fn)
# Alloc [id A] '' 0
# |TensorConstant{0} [id B]
# |TensorConstant{3} [id C]
The Scan was rightly optimized out; however, the end result is clearly wrong, because the resulting shape (i.e. 3) is incorrect: there should be one value per step, i.e. four.
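For reference, here's a plain-Python rendition of what the loop should compute (my own sketch, not part of the original report): with n_steps=4 and a single recurrent output, there should be one value per step.

import numpy as np

k_val = np.array([0], dtype="int64")
prev = np.int64(0)  # the outputs_info initial value
out = []
for _ in range(4):  # n_steps=4
    prev = k_val[0]  # fn=lambda x: k[0] ignores its input
    out.append(prev)
np.stack(out)
# array([0, 0, 0, 0]) -- four values, not three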
For comparison, here's the most similar working example:
# Indexing without unused variable
xs_2, _ = aesara.scan(
    fn=lambda x: k[0] + x,
    outputs_info=[at.zeros((), dtype="int64")],
    n_steps=4,
)
xs_2_fn = aesara.function([], xs_2)
xs_2_fn()
# array([0, 0, 0, 0])
aesara.dprint(xs_2)
# Subtensor{int64::} [id A] ''
# |for{cpu,scan_fn} [id B] ''
# | |TensorConstant{4} [id C]
# | |IncSubtensor{Set;:int64:} [id D] ''
# |   |AllocEmpty{dtype='int64'} [id E] ''
# |   | |Elemwise{add,no_inplace} [id F] ''
# |   |   |TensorConstant{4} [id C]
# |   |   |Subtensor{int64} [id G] ''
# |   |     |Shape [id H] ''
# |   |     | |Rebroadcast{(0, False)} [id I] ''
# |   |     |   |InplaceDimShuffle{x} [id J] ''
# |   |     |     |Alloc [id K] ''
# |   |     |       |TensorConstant{0} [id L]
# |   |     |ScalarConstant{0} [id M]
# |   |Rebroadcast{(0, False)} [id I] ''
# |   |ScalarFromTensor [id N] ''
# |     |Subtensor{int64} [id G] ''
# |ScalarConstant{1} [id O]
#
# Inner graphs:
#
# for{cpu,scan_fn} [id B] ''
# >Elemwise{add,no_inplace} [id P] ''
# > |Subtensor{int64} [id Q] ''
# > | |k{(1,) of 0} [id R]
# > | |ScalarConstant{0} [id S]
# > |<TensorType(int64, ())> [id T] -> [id D]
aesara.dprint(xs_2_fn)
# forall_inplace,cpu,scan_fn} [id A] '' 2
# |TensorConstant{4} [id B]
# |IncSubtensor{InplaceSet;:int64:} [id C] '' 1
#   |AllocEmpty{dtype='int64'} [id D] '' 0
#   | |TensorConstant{4} [id E]
#   |TensorConstant{(1,) of 0} [id F]
#   |ScalarConstant{1} [id G]
#
# Inner graphs:
#
# forall_inplace,cpu,scan_fn} [id A] ''
# >DeepCopyOp [id H] ''
# > |<TensorType(int64, ())> [id I] -> [id C]
In this case, the Scan hasn't been optimized out, but it looks like it should have been, since it's simply copying a value at each step (its inner graph is just a DeepCopyOp).
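For instance (my own sketch, not from the original report), a scan-free graph that computes the same four values is just an Alloc of the constant:

# Each step computes k[0] + x with x always 0, so broadcasting the
# constant over the four steps gives an equivalent Scan-free graph.
equivalent = at.alloc(k[0], 4)
aesara.function([], equivalent)()
# array([0, 0, 0, 0])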
Let's try to find the rewrite that's causing this bug:
with aesara.config.change_flags(optimizer_verbose=True):
    xs_fn = aesara.function([], xs)
# optimizer: rewrite push_out_non_seq_scan replaces for{cpu,scan_fn}.0 of for{cpu,scan_fn}(TensorConstant{4}, IncSubtensor{Set;:int64:}.0) with Alloc.0 of Alloc(Subtensor{int64}.0, TensorConstant{4})
# optimizer: rewrite MergeOptimizer replaces TensorConstant{4} of None with TensorConstant{4} of None
# optimizer: rewrite MergeOptimizer replaces TensorConstant{4} of None with TensorConstant{4} of None
# optimizer: rewrite MergeOptimizer replaces TensorConstant{1} of None with TensorConstant{1} of None
# optimizer: rewrite local_add_canonizer replaces Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, Elemwise{switch,no_inplace}.0) with Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, Elemwise{switch,no_inplace}.0)
# optimizer: rewrite constant_folding replaces Elemwise{lt,no_inplace}.0 of Elemwise{lt,no_inplace}(TensorConstant{1}, TensorConstant{4}) with TensorConstant{True} of None
# optimizer: rewrite constant_folding replaces Elemwise{switch,no_inplace}.0 of Elemwise{switch,no_inplace}(TensorConstant{True}, TensorConstant{1}, TensorConstant{4}) with TensorConstant{1} of None
# optimizer: rewrite constant_folding replaces Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, TensorConstant{1}) with TensorConstant{3} of None
# optimizer: rewrite local_subtensor_of_alloc replaces node Subtensor{int64::}(Alloc.0, ScalarConstant{1}) with [Alloc.0]
# optimizer: rewrite LocalOptGroup(local_useless_fill,local_useless_alloc,local_useless_elemwise,local_remove_useless_assert,local_useless_rebroadcast,local_join_1,local_join_empty,local_join_make_vector,local_useless_switch,local_useless_tile,local_useless_split,local_useless_reshape,local_view_op,local_merge_alloc,local_useless_topk,local_useless_SpecifyShape,local_Shape_of_SpecifyShape,local_Shape_i_of_broadcastable,local_Unique_scalar,local_Unique_Alloc_lift,local_Unique_BroadcastTo_lift,local_Unique_Repeat_lift,local_Unique_second,local_remove_scalar_BroadcastTo,local_useless_elemwise_comparison,local_useless_reduce,local_useless_slice,local_subtensor_of_alloc,local_subtensor_make_vector,local_useless_inc_subtensor,local_useless_inc_subtensor_alloc) replaces Subtensor{int64::}.0 of Subtensor{int64::}(Alloc.0, ScalarConstant{1}) with Alloc.0 of Alloc(Subtensor{}.0, Elemwise{sub,no_inplace}.0)
# optimizer: rewrite MergeOptimizer replaces TensorConstant{1} of None with TensorConstant{1} of None
# optimizer: rewrite local_subtensor_merge replaces Subtensor{}.0 of Subtensor{}(Subtensor{int64}.0) with Subtensor{int64}.0 of Subtensor{int64}(k{(1,) of 0}, ScalarConstant{0})
# optimizer: rewrite local_subtensor_remove_broadcastable_index replaces Subtensor{int64}.0 of Subtensor{int64}(k{(1,) of 0}, ScalarConstant{0}) with InplaceDimShuffle{}.0 of InplaceDimShuffle{}(k{(1,) of 0})
# optimizer: rewrite local_add_canonizer replaces Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, Elemwise{switch,no_inplace}.0) with Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, Elemwise{switch,no_inplace}.0)
# optimizer: rewrite constant_folding replaces Elemwise{lt,no_inplace}.0 of Elemwise{lt,no_inplace}(TensorConstant{1}, TensorConstant{4}) with TensorConstant{True} of None
# optimizer: rewrite constant_folding replaces Elemwise{switch,no_inplace}.0 of Elemwise{switch,no_inplace}(TensorConstant{True}, TensorConstant{1}, TensorConstant{4}) with TensorConstant{1} of None
# optimizer: rewrite constant_folding replaces Elemwise{sub,no_inplace}.0 of Elemwise{sub,no_inplace}(TensorConstant{4}, TensorConstant{1}) with TensorConstant{3} of None
# optimizer: rewrite constant_folding replaces InplaceDimShuffle{}.0 of InplaceDimShuffle{}(k{(1,) of 0}) with TensorConstant{0} of None
As we can see, the rewrite push_out_non_seq_scan is the one removing the Scan Op altogether. It produces something like Alloc(Subtensor{int64}.0, TensorConstant{4}), which appears to have the correct shape (i.e. 4). One of the other rewrites seems to be the problem.
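If we wanted to confirm that directly, one option (a sketch; the exclusion tag is an assumption, since rewrites are often registered under a different name than the one printed in the trace) is to recompile with the scan push-out rewrites disabled and check whether the output length is restored:

# "scan_pushout" is a guess at the registration tag for this rewrite
# family; adjust it if your Aesara version uses a different name.
mode = aesara.compile.mode.get_default_mode().excluding("scan_pushout")
xs_fn_noopt = aesara.function([], xs, mode=mode)
xs_fn_noopt()
# array([0, 0, 0, 0]) if push_out_non_seq_scan was the trigger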
No, the problem is somewhere else: push_out_non_seq_scan is replacing a Scan whose output is being sliced like scan_out[1:] with a vector of length four, so the slice that follows trims the result to a vector of length three.
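Here's the replacement in miniature (my own sketch): the rewrite substitutes an Alloc of length n_steps for the full Scan output, but the pre-existing [1:] slice was written for a buffer that also contains the initial value.

# What the rewritten graph effectively computes: a length-4 Alloc
# followed by the out[1:] slice that scan itself had inserted.
replacement = at.alloc(k[0], 4)
replacement[1:].shape.eval()
# array([3]) -- one element too short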
It looks like the graph that aesara.scan is constructing is incorrect, because it shouldn't be slicing/Subtensoring anything. There are no taps, and that's when Subtensors are added (e.g. to return only the computed iteration results and not also the initial tap values).
The example output in #945 shows that aesara.scan is interpreting the outputs_info as a single sit-sot (i.e. "single-input, single-output") input, and sit-sot outputs are always sliced like out[1:] here. It seems like we need that input to be seen as a nit-sot (i.e. "no-input, single-output") instead; a nit-sot version is sketched below.
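For comparison, here's what an explicit nit-sot formulation looks like (a sketch illustrating the distinction, not a fix for the bug itself): passing outputs_info=[None] tells scan that nothing feeds back between steps, so no initial value is prepended to the output buffer and no out[1:] slice is needed.

# With outputs_info=[None] the output is a nit-sot: fn takes no
# recurrent argument, and scan returns exactly n_steps values.
zs, _ = aesara.scan(
    fn=lambda: k[0],
    outputs_info=[None],
    n_steps=4,
)
zs_fn = aesara.function([], zs)
zs_fn()
# array([0, 0, 0, 0])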
To summarize: the first block returns only 3 values instead of the expected 4, while other combinations (such as the k[0] + x variant above) return the correct outputs.