aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Consider adding `local_div_to_reciprocal` to canonicalizations #441

Open brandonwillard opened 3 years ago

brandonwillard commented 3 years ago

Currently, expressions like `1 / (1 + x)`, which are expressed using a `true_div` Op, aren't converted to `reciprocal(1 + x)` during canonicalization. Among other things, this means canonicalization doesn't treat the two graphs as equivalent, even though they very much are.

There is an existing rewrite that takes care of this (i.e., `local_div_to_reciprocal`), but it's in the "specialize" set of optimizations.

```python
import aesara
import aesara.tensor as at

from aesara.graph.opt_utils import optimize_graph

x = at.vector("x")
y = 1 / (1 + x)

# With only the canonicalization rewrites, the `true_div` node is left in place
z = optimize_graph(y, include=["canonicalization"], clone=True)

aesara.dprint(z)
# Elemwise{true_div,no_inplace} [id A] ''
#  |InplaceDimShuffle{x} [id B] ''
#  | |TensorConstant{1} [id C]
#  |Elemwise{add,no_inplace} [id D] ''
#    |InplaceDimShuffle{x} [id E] ''
#    | |TensorConstant{1} [id F]
#    |x [id G]

# Adding "specialize" lets `local_div_to_reciprocal` replace it with `reciprocal`
z = optimize_graph(y, include=["canonicalization", "specialize"], clone=True)

aesara.dprint(z)
# Elemwise{reciprocal,no_inplace} [id A] ''
#  |Elemwise{add,no_inplace} [id B] ''
#    |TensorConstant{(1,) of 1} [id C]
#    |x [id D]
```
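The equivalence of the two graphs can be illustrated without Aesara at all. The following is a minimal sketch, not Aesara code: graphs are modeled as nested tuples `(op_name, *inputs)` with numbers and variable names as leaves, and `div_to_reciprocal`, `rewrite_everywhere`, and `evaluate` are hypothetical helpers written only to show what a local rewrite like `local_div_to_reciprocal` does to a graph.

```python
# Toy graphs (hypothetical, not Aesara's API): a leaf is a number or a
# variable name (str); an interior node is a tuple (op_name, *inputs).

def div_to_reciprocal(node):
    """Local rewrite: true_div(1, e) -> reciprocal(e); other nodes pass through."""
    if isinstance(node, tuple) and node[0] == "true_div" and node[1] == 1:
        return ("reciprocal", node[2])
    return node


def rewrite_everywhere(node, local_rewrite):
    """Rewrite the inputs first, then apply `local_rewrite` to the node itself."""
    if not isinstance(node, tuple):
        return node
    op, *inputs = node
    return local_rewrite((op, *(rewrite_everywhere(i, local_rewrite) for i in inputs)))


OPS = {
    "add": lambda a, b: a + b,
    "true_div": lambda a, b: a / b,
    "reciprocal": lambda a: 1 / a,
}


def evaluate(node, env):
    """Evaluate a toy graph numerically, looking variable names up in `env`."""
    if isinstance(node, str):
        return env[node]
    if not isinstance(node, tuple):
        return node
    op, *inputs = node
    return OPS[op](*(evaluate(i, env) for i in inputs))


y = ("true_div", 1, ("add", 1, "x"))  # 1 / (1 + x)
z = rewrite_everywhere(y, div_to_reciprocal)
print(z)  # ('reciprocal', ('add', 1, 'x'))
print(evaluate(y, {"x": 3.0}) == evaluate(z, {"x": 3.0}))  # True
```

Both graphs compute the same value; the rewrite only changes which Op expresses it.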

If we made `local_div_to_reciprocal` a canonicalization rewrite, then we could, for instance, remove `local_reciprocal_1_plus_exp` (after replacing `true_div` with `reciprocal` in a few places), because it would be redundant. The same is true if we chose `true_div` as the canonical form instead of `reciprocal`; either way, we should have only one form for these graphs.
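A toy sketch of the "only one form" argument, again using nested tuples instead of Aesara graphs. `reciprocal_1_plus_exp` here is a hypothetical stand-in for a rule in the spirit of `local_reciprocal_1_plus_exp`, based on the identity 1 / (1 + exp(v)) = sigmoid(-v); it is not the actual Aesara implementation. Because the hypothetical `div_to_reciprocal` pass runs first over the whole graph, the later rule only has to match the `reciprocal` form, yet it still catches graphs written with `true_div`.

```python
# Toy graphs (hypothetical): leaves are numbers or variable names, interior
# nodes are tuples (op_name, *inputs).

def div_to_reciprocal(node):
    """Canonicalization: true_div(1, e) -> reciprocal(e)."""
    if isinstance(node, tuple) and node[0] == "true_div" and node[1] == 1:
        return ("reciprocal", node[2])
    return node


def reciprocal_1_plus_exp(node):
    """reciprocal(add(1, exp(v))) -> sigmoid(neg(v)); matches the reciprocal form only."""
    if isinstance(node, tuple) and node[0] == "reciprocal":
        arg = node[1]
        if (isinstance(arg, tuple) and arg[0] == "add" and arg[1] == 1
                and isinstance(arg[2], tuple) and arg[2][0] == "exp"):
            return ("sigmoid", ("neg", arg[2][1]))
    return node


def rewrite_everywhere(node, local_rewrite):
    """Rewrite the inputs first, then apply `local_rewrite` to the node itself."""
    if not isinstance(node, tuple):
        return node
    op, *inputs = node
    return local_rewrite((op, *(rewrite_everywhere(i, local_rewrite) for i in inputs)))


def optimize(graph):
    """Run the canonicalization pass over the whole graph, then the later rule."""
    canonical = rewrite_everywhere(graph, div_to_reciprocal)
    return rewrite_everywhere(canonical, reciprocal_1_plus_exp)


as_div = ("true_div", 1, ("add", 1, ("exp", "x")))   # 1 / (1 + exp(x))
as_recip = ("reciprocal", ("add", 1, ("exp", "x")))  # reciprocal(1 + exp(x))
print(optimize(as_div))    # ('sigmoid', ('neg', 'x'))
print(optimize(as_recip))  # ('sigmoid', ('neg', 'x'))
```

Without the canonicalization pass, `reciprocal_1_plus_exp` leaves `as_div` untouched, so a second pattern for the `true_div` spelling would otherwise be needed.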

ricardoV94 commented 3 years ago

I think the more general `true_div` should be the canonical form, even though most division rewrites probably apply only to the more restricted `1/x` cases...

Unless, of course, it complicates things too much (e.g., due to other, less important rewrites that may affect the numerator).
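The opposite choice can be sketched the same way, with hypothetical toy helpers rather than Aesara code: canonicalize `reciprocal(e)` to `true_div(1, e)`, and have rewrites that only handle `1/x` check the numerator explicitly. The last call illustrates the numerator concern: if the constant 1 is wrapped in something else, like the `InplaceDimShuffle{x}` around `TensorConstant{1}` in the `dprint` output earlier in this thread (modeled here as an assumed `("dimshuffle_x", 1)` node), a pattern requiring a literal 1 no longer matches.

```python
# Toy graphs (hypothetical): leaves are numbers or variable names, interior
# nodes are tuples (op_name, *inputs).

def reciprocal_to_div(node):
    """Canonicalization in the other direction: reciprocal(e) -> true_div(1, e)."""
    if isinstance(node, tuple) and node[0] == "reciprocal":
        return ("true_div", 1, node[1])
    return node


def match_one_over(node):
    """Return d if node is true_div(1, d) with a literal constant 1, else None."""
    if isinstance(node, tuple) and node[0] == "true_div" and node[1] == 1:
        return node[2]
    return None


canonical = reciprocal_to_div(("reciprocal", "x"))
print(canonical)                                               # ('true_div', 1, 'x')
print(match_one_over(canonical))                               # x
print(match_one_over(("true_div", "y", "x")))                  # None: general division
print(match_one_over(("true_div", ("dimshuffle_x", 1), "x")))  # None: wrapped constant
```

A real rewrite would need to look through such wrappers (or rely on them being removed first), which is the kind of extra complexity the comment above anticipates.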

ricardoV94 commented 3 years ago

Just found out that reciprocal is canonicalized to power(x, -1) here: https://github.com/aesara-devs/aesara/blob/57a1eb72b89ade64920fe5b7196062792f2d6e8a/aesara/tensor/math_opt.py#L1911-L1916
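With `reciprocal` itself canonicalized to `power(x, -1)`, there are effectively three spellings of the same value. Below is a minimal sketch, using hypothetical toy helpers rather than Aesara's rewrite machinery, of a single canonicalization that maps both `true_div(1, e)` and `reciprocal(e)` onto the power form, plus a numeric check that all three spellings agree.

```python
# Toy graphs (hypothetical): leaves are numbers or variable names, interior
# nodes are tuples (op_name, *inputs).

def to_power_form(node):
    """true_div(1, e) -> pow(e, -1) and reciprocal(e) -> pow(e, -1)."""
    if isinstance(node, tuple):
        if node[0] == "true_div" and node[1] == 1:
            return ("pow", node[2], -1)
        if node[0] == "reciprocal":
            return ("pow", node[1], -1)
    return node


def rewrite_everywhere(node, local_rewrite):
    """Rewrite the inputs first, then apply `local_rewrite` to the node itself."""
    if not isinstance(node, tuple):
        return node
    op, *inputs = node
    return local_rewrite((op, *(rewrite_everywhere(i, local_rewrite) for i in inputs)))


OPS = {
    "add": lambda a, b: a + b,
    "true_div": lambda a, b: a / b,
    "reciprocal": lambda a: 1 / a,
    "pow": lambda a, b: a ** b,
}


def evaluate(node, env):
    """Evaluate a toy graph numerically, looking variable names up in `env`."""
    if isinstance(node, str):
        return env[node]
    if not isinstance(node, tuple):
        return node
    op, *inputs = node
    return OPS[op](*(evaluate(i, env) for i in inputs))


denom = ("add", 1, "x")
spellings = [("true_div", 1, denom), ("reciprocal", denom), ("pow", denom, -1)]
canonical = {rewrite_everywhere(g, to_power_form) for g in spellings}
print(canonical)  # {('pow', ('add', 1, 'x'), -1)}
print([evaluate(g, {"x": 3.0}) for g in spellings])  # [0.25, 0.25, 0.25]
```

Which form becomes canonical is exactly the open question in this issue; the sketch picks the power form only to mirror the linked `reciprocal` to `power` rewrite.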