pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

Softmax fails with integer dtypes only at runtime #857

Open ricardoV94 opened 1 week ago

ricardoV94 commented 1 week ago

Description

Brought up in #846

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="int64")
out = pt.special.softmax(x)

# The output is typed as int64, which doesn't seem right
out.dprint(print_type=True)
# Softmax{axis=None} [id A] <Vector(int64, shape=(?,))>
# └─ x [id B] <Vector(int64, shape=(?,))>

# Compiles without complaint
fn = pytensor.function([x], out)

fn([1, 2, 3])  # TypeError: not a float
```

We should either raise at graph definition time or cast the input to float. SciPy's softmax happily accepts integers (and returns floats), so we could do the same.

ricardoV94 commented 1 week ago

This problem would go away if we represented Softmax with an OpFromGraph, since exp of an integer tensor is well defined for the underlying elemwise Ops (it upcasts to float).