rlouf opened this issue 2 years ago
`ScalarFromTensor` is a `Type` conversion `Op` (specifically, from `TensorType` to `ScalarType`), so it can't be simplified in that scenario. Those two types model `numpy.ndarray`s and `numpy.[number|generic]`s, respectively, and each has/is its own distinct scalar type. We can interpret the latter (i.e. the `numpy.[number|generic]` types) as Python/native scalars in some cases, which means that `ScalarFromTensor` can translate to `x.item()`. We do this in our Numba conversions.
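For illustration (this example is mine, not from the issue): in plain NumPy, a 0-d `ndarray` and a `numpy.generic` both behave like scalars, and `.item()` is what drops either one down to a native Python scalar, which is the translation the Numba conversion relies on.

```python
import numpy as np

# A 0-d ndarray: the TensorType-style "scalar".
x = np.array(3.5)
print(type(x), x.ndim)   # a numpy.ndarray with ndim == 0

# .item() converts it to a native Python scalar.
s = x.item()
print(type(s))           # a plain Python float

# The same works for a numpy.generic scalar.
g = np.int64(2)
print(type(g.item()))    # a plain Python int
```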
And the `ScalarFromTensor` is dispatched to:

```python
def scalar_from_tensor(x):
    return jnp.array(x).flatten()[0]
```

which creates a `TracedArray` from what was originally a constant, and this `TracedArray` cannot be used to slice another array. To work around this problem I added some branching logic:
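A minimal sketch of the kind of branching this implies (hypothetical, shown with NumPy rather than the actual patch): extract a concrete Python scalar when the input is already a constant, and only fall back to the array round-trip for traced values, so that the result can safely appear in a slice.

```python
import numpy as np

def scalar_from_tensor(x):
    # Hypothetical sketch: if x is a concrete NumPy value, return a
    # native Python scalar so it can be used as a slice index.
    if isinstance(x, (np.ndarray, np.generic)):
        return x.item()
    # Otherwise (e.g. a traced JAX value), keep the original behavior.
    return x.flatten()[0]

idx = scalar_from_tensor(np.array(2))
print(np.arange(5)[idx:])  # slicing with a native int works
```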
In NumPy those functions should return a `numpy.generic` type (in other words, a non-`ndarray` "scalar"). Perhaps this is a JAX choice or bug. Regardless, if it supports `.item()`, I would consider using that first.

Actually, `x[()]` is probably the better choice.
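To illustrate the difference between the two suggestions (NumPy shown here; JAX may behave differently): indexing a 0-d array with an empty tuple, `x[()]`, yields a `numpy.generic` scalar and so stays within NumPy's type system, whereas `.item()` drops all the way down to a native Python scalar.

```python
import numpy as np

x = np.array(7)  # 0-d array

a = x[()]        # empty-tuple indexing: a numpy scalar (numpy.generic)
b = x.item()     # a native Python scalar

print(type(a))   # e.g. numpy.int64 on most platforms
print(type(b))   # plain Python int
```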
While working on #1202 I found the following:
Where I would expect rewrites to remove the useless `ScalarFromTensor`. This is problematic in the JAX backend because in the following graph the result of a `ScalarFromTensor` operation on a scalar is fed to a `SubTensor`
`Op`. As noted above, `ScalarFromTensor` is dispatched to a function that creates a `TracedArray` from what was originally a constant, and that `TracedArray` cannot be used to slice another array; the branching logic I added to work around this doesn't need to be there in this case.
This probably collides with more general questions regarding the representation of scalars in Aesara (and whether we really need a distinct type for scalars at all), but at the very least this could be handled by a rewrite specific to the JAX backend that circumvents its indexing restrictions.