aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Commutative Elemwise Operations are not merged #958

Open ricardoV94 opened 2 years ago

ricardoV94 commented 2 years ago
```python
import aesara
import aesara.tensor as at

x, y = at.scalars("x", "y")
f = aesara.function([x, y], [x + y, x + y, y + x, y + x])
aesara.dprint(f)
```

```
DeepCopyOp [id A] ''   1
 |Elemwise{add,no_inplace} [id B] ''   0
   |x [id C]
   |y [id D]
Elemwise{add,no_inplace} [id B] ''   0
DeepCopyOp [id E] ''   3
 |Elemwise{add,no_inplace} [id F] ''   2
   |y [id D]
   |x [id C]
Elemwise{add,no_inplace} [id F] ''   2
```
brandonwillard commented 2 years ago

This might require a simple/naive canonicalization that orders the inputs to associative/commutative (AC) Ops by string name or auto-ID. Such a canonicalization should be consistent across recreations of equivalent graphs, so I'm not sure a direct auto-ID approach is actually sound. It might be, though.

All I know for sure is that we don't want to perform anything like AC matching/unification numerous times.
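The soundness concern about auto-IDs can be illustrated with a toy model. This is a minimal sketch, not Aesara's API: `Var`, `AC_OPS`, `canonical_inputs`, `by_name` and `by_auto_id` are hypothetical names, and "auto-ID" is assumed to be a counter assigned at variable creation. Sorting the inputs of an AC op makes `x + y` and `y + x` agree within one graph under either key, but only the name-based order survives rebuilding an equivalent graph in a different creation order.

```python
from dataclasses import dataclass, field
from itertools import count

# Hypothetical toy graph, not Aesara's API: every variable gets a name and an
# auto-incremented ID at creation time.
_next_id = count()


@dataclass(frozen=True)
class Var:
    name: str
    auto_id: int = field(default_factory=lambda: next(_next_id))


# Assumption: these op names stand in for associative/commutative (AC) Ops.
AC_OPS = {"add", "mul"}


def canonical_inputs(op, inputs, key):
    """Sort the inputs of an AC op by `key`; keep other ops' inputs in order."""
    if op in AC_OPS:
        return tuple(sorted(inputs, key=key))
    return tuple(inputs)


def by_name(var):
    return var.name


def by_auto_id(var):
    return var.auto_id


def names(inputs):
    return tuple(v.name for v in inputs)


# Graph 1 creates x before y; graph 2 is an equivalent graph created in the
# opposite order, so its auto-IDs come out reversed.
x1, y1 = Var("x"), Var("y")
y2, x2 = Var("y"), Var("x")

# Within one graph, either key maps x + y and y + x to the same input order.
assert canonical_inputs("add", (x1, y1), by_auto_id) == canonical_inputs("add", (y1, x1), by_auto_id)

# Across recreations, only the name-based order stays the same.
print(names(canonical_inputs("add", (x1, y1), by_name)))     # ('x', 'y')
print(names(canonical_inputs("add", (x2, y2), by_name)))     # ('x', 'y')
print(names(canonical_inputs("add", (x1, y1), by_auto_id)))  # ('x', 'y')
print(names(canonical_inputs("add", (x2, y2), by_auto_id)))  # ('y', 'x')
```

Names are not a complete answer either: variables can be unnamed or share a name, so a real key would presumably need a tie-breaker that is itself stable across recreations.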

ricardoV94 commented 2 years ago

A more direct way would be to include Op information about which inputs are commutative, so that the merge optimization compares those inputs via set equality instead of the order equality it currently uses for all inputs.

Some care would be needed for repeated inputs, though, so maybe a Counter instead of a set.
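The difference between order equality, set equality and Counter (multiset) equality can be sketched with a toy merge pass. This is a hypothetical illustration, not Aesara's merge optimizer: nodes are `(op, inputs)` tuples of strings, `COMMUTATIVE_OPS`, `ordered_key`, `set_key`, `counter_key` and `merge` are invented names, and, as a simplification, every input of a listed op is treated as commutative rather than a per-input flag.

```python
from collections import Counter

# Assumption for this sketch: ops and variables are plain strings, and all
# inputs of these ops are treated as commutative.
COMMUTATIVE_OPS = {"add", "mul"}


def ordered_key(op, inputs):
    """Order-sensitive comparison, as the merge is described as doing for all inputs."""
    return (op, tuple(inputs))


def set_key(op, inputs):
    """Naive commutative key: ignores order but also ignores multiplicity."""
    if op in COMMUTATIVE_OPS:
        return (op, frozenset(inputs))
    return ordered_key(op, inputs)


def counter_key(op, inputs):
    """Commutative key that ignores order but keeps multiplicity (a multiset)."""
    if op in COMMUTATIVE_OPS:
        return (op, frozenset(Counter(inputs).items()))
    return ordered_key(op, inputs)


def merge(nodes, key):
    """Map each node index to the index of the first node with an equal key."""
    seen = {}
    return [seen.setdefault(key(op, inputs), i) for i, (op, inputs) in enumerate(nodes)]


nodes = [
    ("add", ("x", "y")),       # x + y
    ("add", ("x", "y")),       # x + y
    ("add", ("y", "x")),       # y + x
    ("add", ("y", "x")),       # y + x
    ("add", ("x", "x", "y")),  # x + x + y
    ("add", ("x", "y", "y")),  # x + y + y
]

print(merge(nodes, ordered_key))  # [0, 0, 2, 2, 4, 5]
print(merge(nodes, set_key))      # [0, 0, 0, 0, 0, 0]
print(merge(nodes, counter_key))  # [0, 0, 0, 0, 4, 5]
```

The order-sensitive key reproduces the behavior in the `dprint` output (two separate `add` nodes), the set key wrongly collapses `x + x + y` and `x + y + y` into `x + y`, and the Counter key merges only the truly equivalent nodes.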

brandonwillard commented 2 years ago

Yes, that's basically how the permutation goals and AC unification work in kanren. The issue I was alluding to involves the need to construct those sets/Counters multiple times.
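One possible way to bound how often those multisets get built (an assumption for illustration, not necessarily what is meant here or what Aesara does) is to compute each node's key once and bucket it in a hash table, rather than rebuilding Counters every time two nodes are compared. In this hypothetical sketch, `multiset_key`, `pairwise_merge`, `hashed_merge` and the `stats` counter are invented names; the counter only tracks how many keys are constructed.

```python
from collections import Counter

COMMUTATIVE_OPS = {"add", "mul"}  # assumed set of commutative ops for this sketch
stats = {"builds": 0}  # number of keys constructed so far


def multiset_key(op, inputs):
    """Build an order-insensitive key for commutative ops, counting each build."""
    stats["builds"] += 1
    if op in COMMUTATIVE_OPS:
        return (op, frozenset(Counter(inputs).items()))
    return (op, tuple(inputs))


def pairwise_merge(nodes):
    """Quadratic: rebuilds both keys for every pair of nodes compared."""
    result = []
    for i, (op, inputs) in enumerate(nodes):
        target = i
        for j in range(i):
            op_j, inputs_j = nodes[j]
            if multiset_key(op, inputs) == multiset_key(op_j, inputs_j):
                target = result[j]
                break
        result.append(target)
    return result


def hashed_merge(nodes):
    """Linear: builds each node's key exactly once and looks it up in a dict."""
    seen = {}
    return [seen.setdefault(multiset_key(op, inputs), i) for i, (op, inputs) in enumerate(nodes)]


nodes = [
    ("add", ("x", "y")),
    ("add", ("y", "x")),
    ("mul", ("x", "y")),
    ("add", ("y", "x")),
]

stats["builds"] = 0
print(pairwise_merge(nodes), stats["builds"])  # [0, 0, 2, 0] 8
stats["builds"] = 0
print(hashed_merge(nodes), stats["builds"])    # [0, 0, 2, 0] 4
```

Both versions find the same merges; the hashed version builds one key per node instead of two per comparison. Keys could also be cached on the nodes themselves so repeated merge passes reuse them, at the cost of invalidating the cache whenever a node's inputs change.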