Closed — ricardoV94 closed this 4 months ago
Why does this change affect int casting? Aren't the logp terms already floatX?
This change avoids explicit casts to float64. Some logp terms may be float32 and others float64; if you `sum` them, the graph gets an explicit cast plus a `MakeVector`, whereas a variadic `add` only performs implicit dtype promotion, so JAX doesn't emit its warning about ignoring explicit float64 casts.
This has nothing to do with ints anymore; I have split those changes into a separate PR.
Not a PyTensor error, but a cryptic failure when `add` is called with no entries.
Actually, we already have a rewrite that does this, but it's a bit conservative with dtypes... Maybe we can tweak that rewrite instead.
Description
This change avoids many warnings when compiling a logp to JAX with `pytensor.config.floatX = "float32"`. Since the terms are all scalars, it should also be more efficient, because it avoids an explicit `MakeVector`.
Related Issue
Checklist
Type of change
Documentation preview: https://pymc--7329.org.readthedocs.build/en/7329/