aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

`ScalarFromTensor` of a tensor with shape `()` does not get simplified #1258

Open rlouf opened 2 years ago

rlouf commented 2 years ago

While working on #1202 I found the following:

import aesara
import aesara.tensor as at

a_at = at.scalar("a", dtype="int64")
b_at = at.scalar_from_tensor(a_at)

print(a_at.type)
# TensorType(int64, ())

print(b_at.type)
# int64

fn = aesara.function((a_at,), b_at)
aesara.dprint(fn)
# ScalarFromTensor [id A] 0
#  |a [id B]

I would expect the rewrites to remove this useless ScalarFromTensor. It is problematic in the JAX backend because, in the following graph, the result of a ScalarFromTensor applied to a scalar (a tensor with shape ()) is fed to a Subtensor Op:

import aesara
import aesara.tensor as at

a_at = at.vector("a")
res, _ = aesara.scan(
    fn=lambda a: 2 * a,
    sequences=a_at,
)

fn = aesara.function((a_at,), res)

aesara.dprint(fn)
# Elemwise{mul,no_inplace} [id A] 5
#  |TensorConstant{(1,) of 2.0} [id B]
#  |Subtensor{int64:int64:int8} [id C] 4
#    |a [id D]
#    |ScalarFromTensor [id E] 3
#    | |Elemwise{Composite{Switch(LE(i0, i1), i1, i2)}}[(0, 0)] [id F] 2
#    |   |Shape_i{0} [id G] 0
#    |   | |a [id D]
#    |   |TensorConstant{0} [id H]
#    |   |TensorConstant{0} [id I]
#    |ScalarFromTensor [id J] 1
#    | |Shape_i{0} [id G] 0
#    |ScalarConstant{1} [id K]

print(fn.vm.fgraph.toposort()[1].inputs[0].type)
# TensorType(int64, ())

In the JAX backend, ScalarFromTensor is dispatched to:

import jax.numpy as jnp

def scalar_from_tensor(x):
    return jnp.array(x).flatten()[0]

which creates a TracedArray from what was originally a constant, and this TracedArray cannot be used to slice another array. To work around this problem I added some branching logic:

import jax.numpy as jnp

def scalar_from_tensor(x):
    # Pass Python constants through unchanged so they remain usable as slice bounds
    if isinstance(x, (int, float)):
        return x
    return jnp.array(x).flatten()[0]

which shouldn't be necessary in this case.

This probably collides with more general questions about how scalars are represented in Aesara (and whether we really need a separate type for scalars at all), but at the very least this could be handled by a JAX-specific rewrite that circumvents JAX's indexing restrictions.

brandonwillard commented 2 years ago

ScalarFromTensor is a Type conversion Op—specifically, from TensorType to ScalarType—so it can't be simplified in that scenario.

Those two types model numpy.ndarrays and numpy.[number|generic]s, respectively, and each has/is its own distinct scalar type. We can interpret the latter (i.e. numpy.[number|generic] types) as Python/native scalars in some cases, which means that ScalarFromTensor can translate to x.item(). We do this in our Numba conversions.
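The type distinction can be illustrated with plain NumPy, without Aesara installed. This is a minimal sketch, assuming a 0-d numpy.ndarray stands in for a TensorType value with shape () and a numpy.generic (or native Python scalar) stands in for a ScalarType value; scalar_from_tensor_item is a hypothetical name, not an Aesara or Numba API.

```python
import numpy as np


def scalar_from_tensor_item(x: np.ndarray):
    """Hypothetical ScalarFromTensor analogue: 0-d ndarray -> native Python scalar."""
    if x.ndim != 0:
        raise ValueError(f"expected a 0-d array, got shape {x.shape}")
    # .item() copies the single element out as a Python int, float, bool, etc.
    return x.item()


# Stand-in for a TensorType(int64, ()) value: a 0-d numpy.ndarray
tensor_value = np.array(3, dtype=np.int64)
print(type(tensor_value), tensor_value.shape)  # <class 'numpy.ndarray'> ()

# Stand-in for a ScalarType value: a numpy.generic, which is not an ndarray
print(isinstance(np.int64(3), np.generic), isinstance(np.int64(3), np.ndarray))  # True False

# The .item() conversion yields a native Python scalar
scalar_value = scalar_from_tensor_item(tensor_value)
print(type(scalar_value), scalar_value)  # <class 'int'> 3
```

The explicit ndim check mirrors the fact that ScalarFromTensor only accepts tensors with shape (); NumPy's .item() on its own would also accept any size-1 array.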

> And the ScalarFromTensor is dispatched to:
>
> def scalar_from_tensor(x):
>     return jnp.array(x).flatten()[0]
>
> which creates a TracedArray from what was originally a constant, and this TracedArray cannot be used to slice another array.

In NumPy those functions should return a numpy.generic type (in other words, a non-ndarray "scalar"). Perhaps this is a JAX choice or bug. Regardless, if it supports .item(), I would consider using that first.
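To make the NumPy comparison concrete, the sketch below runs the same expression as the JAX dispatch, with numpy substituted for jax.numpy purely for illustration, and inspects what each step produces. Integer indexing of a 1-d NumPy array returns a numpy.generic rather than an array. JAX, as far as we understand, has no separate array-scalar type, so the analogous indexing there yields a 0-d array (or a tracer under jax.jit).

```python
import numpy as np

# 0-d input, like the Shape_i output fed to ScalarFromTensor in the graph above
x = np.array(7, dtype=np.int64)

step1 = np.array(x)      # still a 0-d ndarray
step2 = step1.flatten()  # 1-d ndarray with shape (1,)
step3 = step2[0]         # integer indexing returns a numpy.generic, not an ndarray

print(step1.shape, step2.shape)  # () (1,)
print(type(step3), isinstance(step3, np.generic))  # <class 'numpy.int64'> True
```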

brandonwillard commented 2 years ago

Actually, x[()] is probably the better choice.
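A short NumPy sketch of why x[()] fits ScalarFromTensor: indexing a 0-d array with an empty tuple returns its single element as a numpy.generic that keeps the array's dtype, whereas .item() converts to a native Python scalar and drops the dtype. The helper scalar_from_tensor_index is a hypothetical name used only for illustration.

```python
import numpy as np


def scalar_from_tensor_index(x: np.ndarray) -> np.generic:
    """Hypothetical ScalarFromTensor analogue using empty-tuple indexing."""
    # x[()] only yields a scalar for 0-d input; on a shape-(1,) array it returns
    # a view of the whole array, so check the shape explicitly.
    if x.ndim != 0:
        raise ValueError(f"expected a 0-d array, got shape {x.shape}")
    return x[()]


x = np.array(3, dtype=np.int8)

by_index = scalar_from_tensor_index(x)
by_item = x.item()

print(type(by_index), by_index.dtype)  # <class 'numpy.int8'> int8
print(type(by_item))                   # <class 'int'>

# On a 1-element vector, [()] does not extract a scalar
print(np.array([3])[()].shape)  # (1,)
```

Keeping the dtype is one plausible reason to prefer x[()] over .item() here, since a ScalarType carries a dtype just like the TensorType it was converted from.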