aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Prevent `Elemwise` graphs from violating `Op` arities #26

Open brandonwillard opened 4 years ago

brandonwillard commented 4 years ago

The scalar multiplication `Op`, `Mul`, has an `impl` method that actually uses `np.product`. When used in conjunction with `Elemwise`, optimizations like `local_mul_canonizer` construct graphs that essentially contain `Elemwise(Mul)(a, b, c, ...)` nodes.
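To see why this goes unnoticed, here is a minimal sketch of a `Mul`-style `impl` that, as described above, reduces over all of its inputs with a NumPy product. The function name `mul_impl` is hypothetical, and it uses `np.prod` in place of the deprecated `np.product` alias. Nothing in it enforces a binary signature, so a three-input call succeeds exactly like a two-input one.

```python
import numpy as np


def mul_impl(*inputs):
    """Hypothetical stand-in for a scalar `Mul.impl` that multiplies all inputs.

    Because it reduces over however many arguments it receives, it never
    checks that it was called with exactly two operands.
    """
    return np.prod(inputs)


# Two operands: the intended binary use
binary_result = mul_impl(2.0, 3.0)

# Three operands: an arity violation that still "works"
ternary_result = mul_impl(-1.0, 2.0, 3.0)

print(binary_result, ternary_result)
```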

In other words, these optimizations add nodes to the graph that violate the arity of scalar functions (e.g. by turning the binary operator `Mul` into a variadic one), which puts the graph into a representationally invalid state.

Furthermore, `Elemwise` implements numerous unnecessarily complicated hacks throughout its numerical evaluation stages to make these invalid graphs work (e.g. arity-violating `impl` functions are vectorized in `Op.prepare_node` during calls to `Op.perform`?!).
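The kind of workaround described above can be illustrated with `np.frompyfunc`, which builds a ufunc from a Python function given an explicit number of inputs. This is a sketch under assumptions, not Aesara's actual `prepare_node` code: the input count is taken from the number of arguments on a (hypothetical) graph node rather than from the scalar operator's declared arity, so a three-input node over a binary multiplication vectorizes without complaint.

```python
import numpy as np


def mul_impl(*inputs):
    # Hypothetical scalar impl that multiplies however many inputs it gets
    return np.prod(inputs)


DECLARED_NIN = 2  # the arity a binary Mul is supposed to have

# Hypothetical node inputs produced by a flattening rewrite: mul(-1, tau, mu),
# with the -1 constant already broadcast to the shape of the vectors
node_inputs = [np.array([-1.0, -1.0]), np.array([1.0, 2.0]), np.array([3.0, 4.0])]

# Taking `nin` from the node instead of from DECLARED_NIN hides the violation
ufunc = np.frompyfunc(mul_impl, len(node_inputs), 1)

# frompyfunc returns an object array; convert it back to floats
result = ufunc(*node_inputs).astype(float)

print(len(node_inputs) > DECLARED_NIN, result)
```

The vectorized function happily computes the three-way product element by element, which is exactly why the invalid node never surfaces as an error.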

If anything, to preserve the validity of the graph, any optimization that flattens these nodes should also replace the scalar `Op`s with arity-appropriate ones (e.g. replace `Mul` with `Product`).
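Here is a toy sketch of that idea: nested applications of a binary op are flattened, and the op is swapped for a variadic counterpart so the resulting node still respects its arity. The `ScalarOp`/`Node` classes, the `flatten` function, and the `MUL`/`ADD`/`PRODUCT`/`SUM` objects are all hypothetical simplifications, not Aesara's graph or rewrite API.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class ScalarOp:
    name: str
    nin: Optional[int]  # None means variadic


MUL = ScalarOp("mul", 2)
ADD = ScalarOp("add", 2)
PRODUCT = ScalarOp("product", None)
SUM = ScalarOp("sum", None)

# Binary op -> variadic op that may legally take the flattened inputs
VARIADIC_FOR = {MUL: PRODUCT, ADD: SUM}


@dataclass(frozen=True)
class Node:
    op: ScalarOp
    inputs: Tuple["Term", ...]

    def __post_init__(self):
        # Reject any node whose input count violates its op's arity
        if self.op.nin is not None and len(self.inputs) != self.op.nin:
            raise ValueError(
                f"{self.op.name} expects {self.op.nin} inputs, got {len(self.inputs)}"
            )


Term = Union[str, float, Node]


def flatten(node: Node) -> Node:
    """Flatten nested applications of the same binary op into one variadic node."""
    if node.op not in VARIADIC_FOR:
        return node
    flat = []
    for inp in node.inputs:
        if isinstance(inp, Node) and inp.op == node.op:
            flat.extend(flatten(inp).inputs)
        else:
            flat.append(inp)
    if len(flat) == node.op.nin:
        return node  # nothing to flatten; keep the binary op
    # Swap in the variadic op so the new node respects its arity
    return Node(VARIADIC_FOR[node.op], tuple(flat))


# -tau * mu after rewriting negation as multiplication by -1
expr = Node(MUL, (Node(MUL, (-1.0, "tau")), "mu"))
fused = flatten(expr)
print(fused.op.name, fused.inputs)  # product (-1.0, 'tau', 'mu')

try:
    Node(MUL, (-1.0, "tau", "mu"))  # what the fusion effectively builds
except ValueError as err:
    print(err)  # mul expects 2 inputs, got 3
```

Checking arity at node construction means the invalid `mul(-1, tau, mu)` can never be built, while the rewrite still produces the fused form through `Product`.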

brandonwillard commented 4 years ago

Here's an MWE that demonstrates the problem:

```python
import numpy as np

import theano
import theano.tensor as tt

from theano.printing import debugprint as tt_dprint
from theano.tensor.opt import FusionOptimizer, local_add_mul_fusion, local_neg_to_mul

# Disable C compilation so only the Python implementations are used
theano.config.cxx = ""

mu = tt.vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0]
tau = tt.vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0]

# Graph for -tau * mu
fgraph = theano.gof.FunctionGraph([mu, tau], [-tau * mu])
tt_dprint(fgraph)
```

```
Elemwise{mul,no_inplace} [id A] ''   1
 |Elemwise{neg,no_inplace} [id B] ''   0
 | |tau [id C]
 |mu [id D]
```

```python
# Rewrite negation as multiplication by -1, then fuse the nested multiplications
optimizer = theano.gof.opt.EquilibriumOptimizer(
    [
        local_neg_to_mul,
        FusionOptimizer(local_add_mul_fusion),
    ],
    max_use_ratio=100,
)

optimizer.optimize(fgraph)
tt_dprint(fgraph)
```

```
Elemwise{mul,no_inplace} [id A] ''   1
 |InplaceDimShuffle{x} [id B] ''   0
 | |TensorConstant{-1.0} [id C]
 |tau [id D]
 |mu [id E]
```

This optimized graph is invalid because it applies three arguments to the binary `Mul`. It should use `Product` as the `Op` within the `Elemwise` (if it should use `Elemwise` at all).
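A graph-level check that would flag this node is easy to sketch. The representation below (nested `(op_name, inputs)` tuples), the `ARITY` table, and `arity_violations` are hypothetical simplifications rather than Aesara's graph objects; the two graphs mirror the `debugprint` output before and after optimization, with the `DimShuffle`d constant reduced to a plain `-1.0`.

```python
# Declared arities for the scalar ops involved; None marks a variadic op
ARITY = {"neg": 1, "mul": 2, "add": 2, "product": None, "sum": None}


def arity_violations(node):
    """Return messages for ops in an (op, inputs) tuple graph applied to the wrong number of inputs."""
    if not isinstance(node, tuple):
        return []  # a leaf: variable name or constant
    op, inputs = node
    found = []
    expected = ARITY[op]
    if expected is not None and len(inputs) != expected:
        found.append(f"{op}: expected {expected} inputs, got {len(inputs)}")
    for inp in inputs:
        found.extend(arity_violations(inp))
    return found


before = ("mul", [("neg", ["tau"]), "mu"])
after = ("mul", [-1.0, "tau", "mu"])

print(arity_violations(before))  # []
print(arity_violations(after))  # ['mul: expected 2 inputs, got 3']
```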

From this example, it's clear that `local_add_mul_fusion` introduces the problem, and it probably does the same for `Add` (i.e. keeping a variadic `Add` where it should use `Sum`).