aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

`Subtensor` Ops with step value `ScalarConstant{1}` do not get simplified #1257

Open rlouf opened 2 years ago

rlouf commented 2 years ago

While working on #1202, I noticed that `Subtensor` Ops with a step index equal to 1 do not get simplified:

import aesara
import aesara.tensor as at

a_at = at.dvector("a")
res, _ = aesara.scan(
    lambda a: 2 * a,
    sequences=[a_at],
    outputs_info=[{}],
)

fn = aesara.function((a_at,), res)
aesara.dprint(fn)
# Elemwise{mul,no_inplace} [id A] 5
#  |TensorConstant{(1,) of 2.0} [id B]
#  |Subtensor{int64:int64:int8} [id C] 4
#    |a [id D]
#    |ScalarFromTensor [id E] 3
#    | |Elemwise{Composite{Switch(LE(i0, i1), i1, i2)}}[(0, 0)] [id F] 2
#    |   |Shape_i{0} [id G] 0
#    |   | |a [id D]
#    |   |TensorConstant{0} [id H]
#    |   |TensorConstant{0} [id I]
#    |ScalarFromTensor [id J] 1
#    | |Shape_i{0} [id G] 0
#    |ScalarConstant{1} [id K]

The `Subtensor` node is created by the `local_subtensor_merge` rewrite:

fn.vm.fgraph.toposort()[4].tag
# scratchpad{'imported_by': ['local_subtensor_merge', 'init']}
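
To make the expected simplification concrete, here is a minimal pure-Python sketch on built-in `slice` objects. It does not use Aesara's rewrite machinery, and `canonicalize_slice` is a hypothetical helper invented for illustration. It assumes non-negative integer bounds and only handles the unit-step case seen in the graph above.

```python
def canonicalize_slice(s: slice, length: int) -> slice:
    """Drop slice components that equal their defaults (hypothetical helper)."""
    start, stop, step = s.start, s.stop, s.step
    # A step of 1 is the default, so it carries no information.
    if step == 1:
        step = None
    if step is None:
        # With a unit step, a start of 0 is the default as well.
        if start == 0:
            start = None
        # A stop at or past the end is clipped to the length, like the default.
        if stop is not None and stop >= length:
            stop = None
    return slice(start, stop, step)


a = [0.5, 1.5, 2.5]
n = len(a)
# Mirror the graph above: Switch(LE(n, 0), 0, 0) is 0 on both branches.
start = 0 if n <= 0 else 0
original = slice(start, n, 1)
simplified = canonicalize_slice(original, n)

assert simplified == slice(None, None, None)
assert a[original] == a[simplified] == a
```

In graph terms, the analogous rewrite would presumably turn `Subtensor{int64:int64:int8}` with a constant step of 1 into a `Subtensor` without a step input, and ideally remove the node entirely when the slice covers the whole vector.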

The Elemwise{mul} Op is also doing wasteful computation, but I do not quite understand what sequence of rewrites led to this:

fn.vm.fgraph.toposort()[5].tag
# scratchpad{'imported_by': ['local_mul_canonizer', "('Revert', 'inplace_elemwise_optimizer')", 'init'], 'removed_by': ['inplace_elemwise_optimizer'], 'fake_node': mul(<float64>, <float64>)}

rlouf commented 2 years ago

More generally, many issues with the JAX backend could be resolved by rewrites that simplify slicing operations in the graph and constant-fold slice elements.
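
As a rough illustration of what constant-folding a slice element could look like, here is a toy sketch on a hand-rolled expression tree. The `Const`, `Var`, `Switch` classes and `fold_switch` are all hypothetical and are not Aesara classes or rewrites. It folds `Switch(cond, x, x)` to `x`, which is the pattern in the start index of the graph in the first comment, where both branches are the constant 0.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Switch:
    cond: object
    if_true: object
    if_false: object


def fold_switch(expr):
    """Rewrite Switch(cond, x, x) to x, folding both branches first."""
    if isinstance(expr, Switch):
        if_true = fold_switch(expr.if_true)
        if_false = fold_switch(expr.if_false)
        # Identical branches make the condition irrelevant in a pure graph.
        if if_true == if_false:
            return if_true
        return Switch(expr.cond, if_true, if_false)
    return expr


# Start index from the graph: Switch(LE(Shape_i{0}(a), 0), 0, 0).
start = Switch(("LE", Var("a.shape[0]"), Const(0)), Const(0), Const(0))
folded_start = fold_switch(start)
assert folded_start == Const(0)
```

Once the start index is a known constant, the slice bounds no longer depend on a runtime `Switch`. That is the kind of static information a backend like JAX tends to need for slicing.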

ricardoV94 commented 2 years ago

Possibly related to https://github.com/aesara-devs/aesara/issues/940