pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

Simplify dots with 1 #638

Open ricardoV94 opened 5 months ago

ricardoV94 commented 5 months ago

Description

We have a local_0_dot_x rewrite that removes useless dots with zeroed inputs. We don't seem to have anything for dots with ones, as reported in https://github.com/pymc-devs/pytensor/discussions/637#discussioncomment-8405862:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.col('x')
f = x @ [[1.]]
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], f, mode=get_default_mode().excluding("BlasOpt"))

pytensor.dprint(fn)
dot [id A] 0
 ├─ x [id B]
 └─ [[1.]] [id C]

I excluded BlasOpt just to get a simpler graph; even with the BLAS rewrites enabled, the dot is still not rewritten away, it is just replaced by the more complex BLAS Op.

https://github.com/pymc-devs/pytensor/blob/d3dd34e7ea78eb1f125dd771d06436d49bf2ce5d/pytensor/tensor/rewriting/math.py#L155-L190
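For reference, a minimal NumPy sketch of the identity that local_0_dot_x exploits (shapes are arbitrary, chosen only for illustration):

import numpy as np

x = np.random.randn(3, 4)

# A dot with zeros is just zeros of the output shape; no matmul is needed.
# This is the identity local_0_dot_x rewrites on.
assert np.allclose(x @ np.zeros((4, 2)), np.zeros((3, 2)))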

Dhruvanshu-Joshi commented 2 months ago

Looks like an interesting issue. We'd just have to replace the 0 with a 1 in local_0_dot_x, right? Here's what I have in mind:

@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
def local_1_dot_x(fgraph, node):
    if not isinstance(node.op, Dot):
        return False

    x = node.inputs[0]
    y = node.inputs[1]
    replace = False
    try:
        if get_underlying_scalar_constant_value(x, only_process_constants=True) == 1:
            replace = True
            var = y
    except NotScalarConstantError:
        pass

    try:
        if get_underlying_scalar_constant_value(y, only_process_constants=True) == 1:
            replace = True
            var = x
    except NotScalarConstantError:
        pass

    if replace:
        constant_value = constant(
            get_underlying_scalar_constant_value(var, only_process_constants=True),
            dtype=node.outputs[0].type.dtype,
        )
        if x.ndim == 2 and y.ndim == 2:
            constant_value = assert_op(constant_value, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_value, x.shape[0], y.shape[1])]
        elif x.ndim == 1 and y.ndim == 2:
            constant_value = assert_op(constant_value, eq(x.shape[0], y.shape[0]))
            return [alloc(constant_value, y.shape[1])]
        elif x.ndim == 2 and y.ndim == 1:
            constant_value = assert_op(constant_value, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_value, x.shape[0])]
        elif x.ndim == 1 and y.ndim == 1:
            constant_value = assert_op(constant_value, eq(x.shape[0], y.shape[0]))
            return [constant_value]

However, I think using the constant value might be wrong here. Would I have to replace the output with the entire var instead? If so, is something like this the correct way forward?

var = assert_op(var, eq(...))
alloc(var, shape)
ricardoV94 commented 2 months ago

No, the rule is slightly different for ones: it amounts to summing the left matrix over the contracted axis. You also have to reason about broadcasting.

I suggest playing with NumPy to get a feel for what it should do.
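For instance, a minimal NumPy sketch of the 2D cases (shapes are arbitrary):

import numpy as np

x = np.random.randn(3, 4)

# A column of ones on the right sums the rows of x
assert np.allclose(x @ np.ones((4, 1)), x.sum(axis=1, keepdims=True))

# With more columns, the row sums are broadcast across the output
assert np.allclose(
    x @ np.ones((4, 2)),
    np.broadcast_to(x.sum(axis=1, keepdims=True), (3, 2)),
)

# Ones on the left instead sum the columns of the right operand
assert np.allclose(
    np.ones((2, 3)) @ x,
    np.broadcast_to(x.sum(axis=0), (2, 4)),
)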

Dhruvanshu-Joshi commented 2 months ago

OK. Just so that I get it correctly: for a given graph, say

Sub [id A]
 ├─ dot [id B]
 │  ├─ dot [id C]
 │  │  ├─ Transpose{axes=[1, 0]} [id D] 'A.T'
 │  │  │  └─ A [id E]
 │  │  └─ Neg [id F]
 │  │     └─ x [id G]
 │  └─ [[1.]] [id H]
 └─ dot [id I]
    ├─ A [id E]
    └─ dot [id J]
       ├─ x [id G]
       └─ [[1.]] [id H]

we want the output of the rewrite to be:

Sub [id A]
 ├─ dot [id B]
 │  ├─ Transpose{axes=[1, 0]} [id C] 'A.T'
 │  │  └─ A [id D]
 │  └─ Neg [id E]
 │     └─ x [id F]
 └─ dot [id G]
    ├─ A [id D]
    └─ x [id F]

Is this correct? And if yes, how do summing the left matrix and broadcasting come into the picture here?
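Playing with NumPy as suggested, here is my understanding of the 1x1 case (x is a column, as in the original example):

import numpy as np

x = np.random.randn(3, 1)  # a column, matching pt.col('x') in the example

# With a 1x1 ones matrix the contracted axis has length 1, so the
# "sum over the contracted axis" is trivial and the dot is x itself.
assert np.allclose(x @ np.array([[1.0]]), x.sum(axis=1, keepdims=True))
assert np.allclose(x @ np.array([[1.0]]), x)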