brandonwillard opened this issue 4 years ago
Here's an MWE (minimal working example) that demonstrates the problem:
```python
import numpy as np
import theano
import theano.tensor as tt
from theano.printing import debugprint as tt_dprint
from theano.tensor.opt import local_add_mul_fusion, local_neg_to_mul

# Disable C compilation so that only the Python implementations are used.
theano.config.cxx = ""

mu = tt.vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0]
tau = tt.vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0]

# Build a graph for -tau * mu, i.e. mul(neg(tau), mu).
fgraph = theano.gof.FunctionGraph([mu, tau], [-tau * mu])
```

```
>>> tt_dprint(fgraph)
Elemwise{mul,no_inplace} [id A] '' 1
|Elemwise{neg,no_inplace} [id B] '' 0
| |tau [id C]
|mu [id D]
```

```python
# Rewrite neg(x) as a multiplication by -1, then fuse nested Elemwise{mul} nodes.
optimizer = theano.gof.opt.EquilibriumOptimizer(
    [
        local_neg_to_mul,
        tt.opt.FusionOptimizer(local_add_mul_fusion),
    ],
    max_use_ratio=100,
)
optimizer.optimize(fgraph)
```

```
>>> tt_dprint(fgraph)
Elemwise{mul,no_inplace} [id A] '' 1
|InplaceDimShuffle{x} [id B] '' 0
| |TensorConstant{-1.0} [id C]
|tau [id D]
|mu [id E]
```
This optimized graph is invalid, because it applies three arguments to the binary `Mul`. It should use `Product` as the `Op` within the `Elemwise` (if it should use `Elemwise` at all).
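To make the arity argument concrete, here is a minimal, Theano-independent sketch. `ScalarOp`, `mul`, and `product` are hypothetical stand-ins, not Theano's classes: each op declares its number of inputs (`nin`, with `None` meaning variadic) and refuses to be applied to any other number of inputs. Under this model, the fused node in the optimized graph above corresponds to calling the binary `mul` on three inputs, which fails, whereas a variadic `product` accepts them.

```python
import math
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class ScalarOp:
    """Hypothetical scalar op that enforces a declared arity (None = variadic)."""

    name: str
    nin: Optional[int]
    impl: Callable[..., float]

    def __call__(self, *inputs: float) -> float:
        # Reject applications that violate the declared arity.
        if self.nin is not None and len(inputs) != self.nin:
            raise TypeError(
                f"{self.name} expects {self.nin} inputs, got {len(inputs)}"
            )
        return self.impl(*inputs)


# Binary multiplication: exactly two inputs.
mul = ScalarOp("mul", 2, lambda x, y: x * y)
# Variadic product: any number of inputs.
product = ScalarOp("product", None, lambda *xs: math.prod(xs))

print(mul(-1.0, 2.0))           # -2.0
print(product(-1.0, 2.0, 3.0))  # -6.0
try:
    mul(-1.0, 2.0, 3.0)  # same shape as the fused Elemwise{mul} node
except TypeError as err:
    print(err)  # mul expects 2 inputs, got 3
```

If `mul`'s `impl` were instead written as a variadic product (as `Mul.impl` is, via `np.product`), the three-input call would silently return a value rather than raising, which is how the invalid graph goes unnoticed.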
From this example, it's clear that `local_add_mul_fusion` introduces the problem, and it probably does the same for `Add` (i.e. it produces a variadic `Add` where a `Sum` should be used).
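This kind of invalid state can also be detected mechanically. The following is a minimal sketch on a toy expression tree rather than Theano's `FunctionGraph`; `Node`, `ARITY`, and `arity_violations` are hypothetical names, and the fused graph above is simplified by writing the `InplaceDimShuffle{x}(-1.0)` input as the constant `-1.0`. The checker flags both the fused `mul` node and the analogous three-input `add`.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

# Hypothetical declared arities: binary ops vs. their variadic counterparts.
ARITY: Dict[str, Optional[int]] = {
    "mul": 2, "add": 2, "neg": 1, "product": None, "sum": None,
}


@dataclass
class Node:
    """Hypothetical expression node: an op name applied to input terms."""

    op: str
    inputs: List[Union["Node", str, float]] = field(default_factory=list)


def arity_violations(node: Node) -> List[str]:
    """Describe every node whose input count disagrees with its op's arity."""
    found: List[str] = []
    expected = ARITY[node.op]
    if expected is not None and len(node.inputs) != expected:
        found.append(
            f"{node.op} applied to {len(node.inputs)} inputs (expects {expected})"
        )
    # Recurse into child nodes; leaves (names, constants) are skipped.
    for child in node.inputs:
        if isinstance(child, Node):
            found.extend(arity_violations(child))
    return found


before = Node("mul", [Node("neg", ["tau"]), "mu"])  # mul(neg(tau), mu): valid
after = Node("mul", [-1.0, "tau", "mu"])           # fused mul(-1, tau, mu)
added = Node("add", ["a", "b", "c"])               # the analogous Add case

print(arity_violations(before))  # []
print(arity_violations(after))   # ['mul applied to 3 inputs (expects 2)']
print(arity_violations(added))   # ['add applied to 3 inputs (expects 2)']
```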
The scalar multiplication `Op`, `Mul`, has an `impl` method that actually uses `np.product`. When used in conjunction with `Elemwise`, optimizations like `local_mul_canonizer` construct graphs that essentially have `Elemwise(Mul)(a, b, c, ...)` nodes. In other words, these optimizations add nodes to the graph that violate the arity of scalar functions (e.g. by turning the binary operator `Mul` into a variadic operator) and put the graph into a representationally invalid state.

As well, `Elemwise` implements numerous unnecessarily complicated hacks throughout its numerical evaluation stages in order to make these invalid graphs work (e.g. vectorizations of arity-violating `impl` functions occurring in `Op.prepare_node` during calls to `Op.perform`?!).

If anything, to preserve the validity of the graph, any optimization that flattens these nodes should also replace the scalar `Op`s with arity-appropriate ones (e.g. replace `Mul` with `Product`).
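A minimal sketch of that alternative, on a toy expression tree rather than Theano's `FunctionGraph`: a flattening rewrite that, whenever it merges nested binary `mul` (or `add`) nodes into a single node with more than two inputs, also swaps the op for a variadic counterpart. The `Node` class, the `VARIADIC` mapping, and the op names `product` and `sum` are hypothetical; they are not Theano's API.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# A term is either a (hypothetical) Node or a leaf: a variable name or constant.
Term = Union["Node", str, float]


@dataclass(frozen=True)
class Node:
    """Hypothetical expression node: an op name applied to input terms."""

    op: str
    inputs: Tuple[Term, ...]


# Hypothetical mapping from a binary associative op to its variadic counterpart.
VARIADIC = {"mul": "product", "add": "sum"}


def flatten(term: Term) -> Term:
    """Merge nested associative binary ops, switching to the variadic op when needed."""
    if not isinstance(term, Node):
        return term
    # Rewrite the inputs first (bottom-up).
    inputs = tuple(flatten(t) for t in term.inputs)
    if term.op not in VARIADIC:
        return Node(term.op, inputs)
    variadic = VARIADIC[term.op]
    flat = []
    for t in inputs:
        # Absorb children from the same op family into this node's inputs.
        if isinstance(t, Node) and t.op in (term.op, variadic):
            flat.extend(t.inputs)
        else:
            flat.append(t)
    # Keep the binary op only when it is still applied to exactly two inputs.
    if len(flat) == 2:
        return Node(term.op, tuple(flat))
    return Node(variadic, tuple(flat))


# (-1 * tau) * mu, i.e. the graph after local_neg_to_mul.
expr = Node("mul", (Node("mul", (-1.0, "tau")), "mu"))
print(flatten(expr))
# Node(op='product', inputs=(-1.0, 'tau', 'mu'))
```

The design choice here is that the rewrite never produces a `mul` or `add` with more than two inputs: either the binary op is kept as-is, or the merged node is re-labelled with the variadic op, so every node in the result respects its op's arity.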