aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Make `Scan` convert `Scalar` `Type`s to `TensorType`s #512

Open ricardoV94 opened 3 years ago

ricardoV94 commented 3 years ago

I have been trying to get Scan to work within a gradient expression without success. I wouldn't be surprised if I am using Scan incorrectly, so let me know :)

import aesara
import aesara.tensor as at

from aesara.scalar import ScalarOp, upgrade_to_float_no_complex
from aesara.tensor.elemwise import Elemwise


def add_five(a):
    return a + 5

def add_five_scan(a):    
    def step(last_count):
        return last_count + 1.0

    counter, _ = aesara.scan(
        fn=step,
        sequences=None,
        outputs_info=[a],
        n_steps=5,
    )

    return counter[-1]

a = at.scalar('a')
y = add_five(a)
print(y.eval({a: 5}), y.type, y.broadcastable)
# 10.0 TensorType(float64, scalar) ()

a = at.scalar('a')
y = add_five_scan(a)
print(y.eval({a: 5}), y.type, y.broadcastable)
# 10.0 TensorType(float64, scalar) ()
class UnaryOp(ScalarOp):
    nin = 1

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        return [gz * add_five(x)]

unary_op = Elemwise(
    UnaryOp(upgrade_to_float_no_complex, "unary_op"),
    name="Elemwise{unary_op,no_inplace}"
)

class UnaryOpScan(ScalarOp):
    nin = 1

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        return [gz * add_five_scan(x)]

unary_op_scan = Elemwise(
    UnaryOpScan(upgrade_to_float_no_complex, "unary_op_scan"),
    name="Elemwise{unary_op_scan,no_inplace}"
)
x = at.scalar('x')
out = unary_op(x)
grad = aesara.grad(out, x)
print(grad.eval({x: 3}))
# 8.0

x = at.scalar('x')
out = unary_op_scan(x)
grad = aesara.grad(out, x)  # <-- Raises AttributeError
grad.eval({x: 3})
---------------------------------------------------------------------------

AttributeError                            Traceback (most recent call last)

<ipython-input-16-e2874236308b> in <module>()
      1 x = at.scalar('x')
      2 out = unary_op_scan(x)
----> 3 grad = aesara.grad(out, x)
      4 grad.eval({x: 3})

11 frames

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    628             assert g.type.dtype in aesara.tensor.type.float_dtypes
    629 
--> 630     rval = _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
    631 
    632     for i in range(len(rval)):

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1439         return grad_dict[var]
   1440 
-> 1441     rval = [access_grad_cache(elem) for elem in wrt]
   1442 
   1443     return rval

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in <listcomp>(.0)
   1439         return grad_dict[var]
   1440 
-> 1441     rval = [access_grad_cache(elem) for elem in wrt]
   1442 
   1443     return rval

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in access_grad_cache(var)
   1392                     for idx in node_to_idx[node]:
   1393 
-> 1394                         term = access_term_cache(node)[idx]
   1395 
   1396                         if not isinstance(term, Variable):

/usr/local/lib/python3.7/dist-packages/aesara/gradient.py in access_term_cache(node)
   1219                             )
   1220 
-> 1221                 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1222 
   1223                 if input_grads is None:

/usr/local/lib/python3.7/dist-packages/aesara/tensor/elemwise.py in L_op(self, inputs, outs, ograds)
    549 
    550         # Compute grad with respect to broadcasted input
--> 551         rval = self._bgrad(inputs, outs, ograds)
    552 
    553         # TODO: make sure that zeros are clearly identifiable

/usr/local/lib/python3.7/dist-packages/aesara/tensor/elemwise.py in _bgrad(self, inputs, outputs, ograds)
    608             ).outputs
    609             scalar_igrads = self.scalar_op.L_op(
--> 610                 scalar_inputs, scalar_outputs, scalar_ograds
    611             )
    612             for igrad in scalar_igrads:

/usr/local/lib/python3.7/dist-packages/aesara/scalar/basic.py in L_op(self, inputs, outputs, output_gradients)
   1141 
   1142     def L_op(self, inputs, outputs, output_gradients):
-> 1143         return self.grad(inputs, output_gradients)
   1144 
   1145     def __eq__(self, other):

<ipython-input-14-e66194d4e28c> in grad(self, inp, grads)
     18         (x,) = inp
     19         (gz,) = grads
---> 20         return [gz * add_five_scan(x)]
     21 
     22 unary_op_scan = Elemwise(

<ipython-input-3-34c2b27606c2> in add_five_scan(a)
     10         sequences=None,
     11         outputs_info=[a],
---> 12         n_steps=5,
     13     )
     14 

/usr/local/lib/python3.7/dist-packages/aesara/scan/basic.py in scan(fn, sequences, outputs_info, non_sequences, n_steps, truncate_gradient, go_backwards, mode, name, profile, allow_gc, strict, return_list)
   1058     info["strict"] = strict
   1059 
-> 1060     local_op = Scan(inner_inputs, new_outs, info)
   1061 
   1062     ##

/usr/local/lib/python3.7/dist-packages/aesara/scan/op.py in __init__(self, inputs, outputs, info, typeConstructor)
    177             self.output_types.append(
    178                 typeConstructor(
--> 179                     broadcastable=(False,) + o.type.broadcastable, dtype=o.type.dtype
    180                 )
    181             )

AttributeError: 'Scalar' object has no attribute 'broadcastable'
brandonwillard commented 3 years ago

I'm guessing that this is a real bug in `Scan`: `Scalar` `Type`s don't have a `broadcastable` attribute, and the `Scan` code implicitly requires one.

At the very least, Scan/scan should not accept Types it doesn't actually support.
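
That guard could look something like this rough sketch (plain Python, no Aesara; the helper name and stand-in classes are made up for illustration): fail fast with a clear `TypeError` instead of an `AttributeError` from deep inside `Scan.__init__`.

```python
# Hypothetical up-front validation: reject argument types that lack the
# attributes Scan's output-type construction relies on.
def validate_scan_arg_type(arg_type):
    """TensorType-like types expose `broadcastable`; others are rejected."""
    if not hasattr(arg_type, "broadcastable"):
        raise TypeError(
            f"Scan only supports TensorType arguments, got {type(arg_type).__name__}"
        )

class GoodType:  # stand-in for a TensorType
    broadcastable = ()

class BadType:  # stand-in for a pure Scalar Type
    pass

validate_scan_arg_type(GoodType())  # passes silently
try:
    validate_scan_arg_type(BadType())
except TypeError as e:
    print(e)  # Scan only supports TensorType arguments, got BadType
```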

brandonwillard commented 3 years ago

Also, don't forget to make your MWE code runnable; there are a few imports missing.

brandonwillard commented 3 years ago

Here's a complete MWE that illustrates the Scan-only scalar input situation more directly:

import numpy as np

import aesara
import aesara.tensor as at

def add_five_scan(a):
    def step(last_count):
        return last_count + 1.0

    counter, _ = aesara.scan(
        fn=step,
        sequences=None,
        outputs_info=[a],
        n_steps=5,
    )

    return counter[-1]

a = at.scalar('a')
y = add_five_scan(a)
print(y.eval({a: 5}), y.type, y.broadcastable)
# 10.0 TensorType(float64, scalar) ()

from aesara.scalar import float64
a = float64('a')
y = add_five_scan(a)

The problem is that the first `a` is actually a `TensorType` "scalar", while the second is a `Scalar` `Type` "scalar" (i.e. they are two different `Type`s).
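
For illustration, here's a toy reproduction of why that distinction blows up inside `Scan.__init__` (plain Python, no Aesara required; the class names are made-up stand-ins): a 0-d `TensorType` carries a `broadcastable` attribute (the empty tuple), while a pure `Scalar` `Type` has no such attribute at all.

```python
class FakeTensorType:
    """Stand-in for a 0-d TensorType (a tensor "scalar")."""
    dtype = "float64"
    broadcastable = ()  # no dimensions, hence nothing to broadcast

class FakeScalarType:
    """Stand-in for a pure Scalar Type -- note: no `broadcastable`."""
    dtype = "float64"

tensor_scalar = FakeTensorType()
pure_scalar = FakeScalarType()

# This mirrors Scan's output-type construction, which prepends the
# time dimension to the input's broadcastable pattern:
print((False,) + tensor_scalar.broadcastable)  # (False,)
try:
    (False,) + pure_scalar.broadcastable
except AttributeError as e:
    print(e)  # 'FakeScalarType' object has no attribute 'broadcastable'
```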

ricardoV94 commented 3 years ago

Is there a reason why scan can't / shouldn't be made to work with these scalar types?

brandonwillard commented 3 years ago

No, it definitely should work with Scalar Types (e.g. using TensorFromScalar).

ricardoV94 commented 2 years ago

I investigated this a little bit further. It is pretty straightforward to allow Scan to accept pure scalars by changing this line:

https://github.com/aesara-devs/aesara/blob/7c4871cee2dedd1779935dbeef344e795f3045cd/aesara/scan/basic.py#L596

to:

if not isinstance(actual_arg, TensorVariable):

However, this does not help much when it comes to having Scans inside a gradient expression. The bigger problem is that the Elemwise gradient expects the gradient graph to consist entirely of scalar Ops, and then tries to recursively convert it to a "broadcastable" tensor version, which does not really make sense for Scan graphs...

https://github.com/aesara-devs/aesara/blob/6f6857997f7047d4e770e8869e81ffe802db943a/aesara/tensor/elemwise.py#L622-L653

I tried a couple of hacks to accommodate Scan graphs (such as manually bypassing ScalarFromTensor and Rebroadcast nodes), as well as skipping nodes that are already Elemwise, but couldn't find anything that worked.

Perhaps we would need a more bare-bones scalar Scan, that can be safely "Elemwised" for these situations? I have no idea if that makes sense...

brandonwillard commented 2 years ago

Yes, there are two issues in the original example: 1) Elemwise.[R|L]_op don't work with Scalar Ops whose gradients aren't composed exclusively of scalar Ops (this is largely due to the implementation of Elemwise._bgrad), and 2) Scan doesn't accept or return Scalar-typed arguments.

For your example, the latter is easily fixed with something like

class UnaryOpScan(ScalarOp):
    nin = 1

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        x_at = at.tensor_from_scalar(x)
        add_5_res = at.scalar_from_tensor(add_five_scan(x_at))
        return [gz * add_5_res]

We could update Scan so that it handles Scalar-typed inputs, but that's a minor convenience.
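
As a sketch of what that convenience could look like, here's a plain-Python stand-in (none of these names are Aesara's API; everything is a toy) that converts at the `scan` boundary, in the spirit of `TensorFromScalar`/`ScalarFromTensor`:

```python
class Scalar:          # stand-in for a Scalar-typed variable
    def __init__(self, value): self.value = value

class Tensor:          # stand-in for a 0-d TensorVariable
    def __init__(self, value): self.value = value

def tensor_from_scalar(s):   # plays the role of at.tensor_from_scalar
    return Tensor(s.value)

def scalar_from_tensor(t):   # plays the role of at.scalar_from_tensor
    return Scalar(t.value)

def tensor_only_scan(fn, init, n_steps):
    # stand-in for aesara.scan: only understands Tensor carries
    assert isinstance(init, Tensor)
    out = init
    for _ in range(n_steps):
        out = Tensor(fn(out.value))
    return out

def scalar_friendly_scan(fn, init, n_steps):
    # the proposed convenience: convert on the way in, convert back on
    # the way out, so callers may pass pure scalars
    was_scalar = isinstance(init, Scalar)
    if was_scalar:
        init = tensor_from_scalar(init)
    out = tensor_only_scan(fn, init, n_steps)
    return scalar_from_tensor(out) if was_scalar else out

res = scalar_friendly_scan(lambda v: v + 1.0, Scalar(5.0), 5)
print(res.value)  # 10.0, and the result is a Scalar again
```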

The real problem appears to have little to do with Scan, so either this issue needs to be updated, or a new one needs to be opened for the Elemwise._bgrad issue.