Open ricardoV94 opened 3 years ago
I'm guessing that this is a real bug in Scan, because Scalars don't have a broadcastable attribute and it looks like the scan code is implicitly requiring that.
At the very least, Scan/scan should not accept Types it doesn't actually support.
Also, don't forget to make your MWE code runnable; there are a few imports missing.
Here's a complete MWE that illustrates the Scan-only scalar input situation more directly:
import numpy as np
import aesara
import aesara.tensor as at


def add_five_scan(a):
    def step(last_count):
        return last_count + 1.0

    counter, _ = aesara.scan(
        fn=step,
        sequences=None,
        outputs_info=[a],
        n_steps=5,
    )
    return counter[-1]


a = at.scalar('a')
y = add_five_scan(a)
print(y.eval({a: 5}), y.type, y.broadcastable)
# 10.0 TensorType(float64, scalar) ()

from aesara.scalar import float64

a = float64('a')
y = add_five_scan(a)
The problem is that the first a is actually a TensorType "scalar" and the second is a Scalar Type "scalar" (i.e. these are actually two different Types).
Is there a reason why scan can't / shouldn't be made to work with these scalar types?
No, it definitely should work with Scalar Types (e.g. using TensorFromScalar).
I investigated this a little bit further. It is pretty straightforward to allow Scan to accept pure scalars by changing this line:
to:
if not isinstance(actual_arg, TensorVariable):
However, this is not of much help when it comes to having Scans inside a gradient expression. The bigger problem is that the Elemwise gradient expects the gradient graph to be entirely scalar, and then tries to recursively convert it to a "broadcastable" tensor version, which does not really make sense for scan graphs...
I tried a couple of hacks to accommodate Scan graphs (such as manually bypassing ScalarFromTensors, and Rebroadcasts) as well as not trying to convert nodes that are already Elemwise... but couldn't find anything that worked.
Perhaps we would need a more bare-bones scalar Scan, that can be safely "Elemwised" for these situations? I have no idea if that makes sense...
Yes, there are two issues in the original example:
1) Elemwise.[R|L]_op don't work with Scalar Ops that have gradients that aren't comprised of exclusively scalar Ops (this is largely due to the implementation of Elemwise._bgrad), and
2) Scan doesn't accept or return Scalar typed arguments.
For your example, the latter is easily fixed with something like
class UnaryOpScan(ScalarOp):
    nin = 1

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        x_at = at.tensor_from_scalar(x)
        add_5_res = at.scalar_from_tensor(add_five_scan(x_at))
        return [gz * add_5_res]
We could update Scan so that it handles Scalar-typed inputs, but that's a minor convenience.
The real problem appears to have little to do with Scan, so either this issue needs to be updated, or a new one needs to be opened for the Elemwise._bgrad issue.
I have been trying to get Scan to work within a gradient expression without success. I wouldn't be surprised if I am using Scan incorrectly, so let me know :)