pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Avoid casting all terms to the same dtype in logp #7329

Closed. ricardoV94 closed this 4 months ago

ricardoV94 commented 4 months ago

Description

This change avoids many warnings when compiling a logp to JAX with pytensor.config.floatX = "float32". Since the terms are all scalars, it should also be more efficient, because it avoids an explicit MakeVector.
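To illustrate the difference (a minimal sketch, not PyMC's actual logp-building code; the logp_* names are made up), summing a stack of scalars goes through a MakeVector, while a single variadic add does not:

```python
import pytensor
import pytensor.tensor as pt

# Illustrative scalar logp terms (names are hypothetical)
logp_a = pt.scalar("logp_a")
logp_b = pt.scalar("logp_b")
logp_c = pt.scalar("logp_c")

# Summing a stack of scalars inserts a MakeVector node before the Sum
stacked_sum = pt.sum(pt.stack([logp_a, logp_b, logp_c]))

# A variadic add combines the scalars directly, with no MakeVector
direct_add = pt.add(logp_a, logp_b, logp_c)

# Printing the graphs shows the extra MakeVector in the first version
pytensor.dprint(stacked_sum)
pytensor.dprint(direct_add)
```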

Related Issue

Checklist

Type of change


πŸ“š Documentation preview πŸ“š: https://pymc--7329.org.readthedocs.build/en/7329/

ricardoV94 commented 4 months ago

> Why does this change affect int casting? Aren’t the logp terms already floatX?

This change avoids explicit casts to float64. Some logp terms may be float32 and others float64; if you sum them there will be an explicit cast plus a MakeVector, whereas if you add them there is only implicit casting and JAX doesn't emit a warning about ignoring explicit float64 casts.
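A rough sketch of the dtype behaviour described above (assuming two hand-made scalar terms; this is not code from the PR):

```python
import pytensor.tensor as pt

# Hypothetical logp terms with mixed precision
term32 = pt.scalar("term32", dtype="float32")
term64 = pt.scalar("term64", dtype="float64")

# Stacking forces both terms into one MakeVector, which requires explicit
# casts of the terms to the common dtype (float64) before the Sum
stacked_sum = pt.sum(pt.stack([term32, term64]))

# A variadic add only upcasts implicitly inside the Elemwise, so no explicit
# cast-to-float64 nodes appear in the graph for the JAX backend to warn about
direct_add = pt.add(term32, term64)
```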

This has nothing to do with ints anymore; I have split those changes into another PR.

ricardoV94 commented 4 months ago

Not a PyTensor error, but a cryptic failure when add has no entries
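One way to sidestep that (a hypothetical guard, not necessarily what the PR does) is to special-case empty and single-term inputs before calling add:

```python
import pytensor.tensor as pt
from pytensor import config


def combine_scalar_logp_terms(terms):
    # Hypothetical helper: add() with no entries fails with an unhelpful
    # error, so fall back to a floatX zero for the empty case
    if not terms:
        return pt.constant(0.0, dtype=config.floatX)
    if len(terms) == 1:
        return terms[0]
    return pt.add(*terms)
```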

ricardoV94 commented 4 months ago

Actually we have a rewrite that does this, but it's a bit conservative with dtypes... Maybe we can tweak this rewrite instead.

https://github.com/pymc-devs/pytensor/blob/f7b0a7a48b929605a083e13a12f144040a7fe265/pytensor/tensor/rewriting/basic.py#L919-L956
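For context, this is roughly the shape such a rewrite takes. The sketch below is a simplified, unregistered stand-in for the linked PyTensor rewrite, not its actual code; in particular, the final dtype check is one way the "conservative with dtypes" behaviour can show up (the real rewrite also has to deal with the Sum's dtype/acc_dtype parameters):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.math import Sum, add


@node_rewriter([Sum])
def local_sum_of_make_vector_sketch(fgraph, node):
    # Only handle a complete sum over the vector (scalar output)
    if node.op.axis not in (None, (0,)):
        return None

    (vector,) = node.inputs
    if vector.owner is None or not isinstance(vector.owner.op, MakeVector):
        return None

    elements = vector.owner.inputs
    if not elements:
        # add() with no entries fails, so leave the empty case to the Sum
        return None

    new_out = add(*elements)
    # Bail out if the replacement would change the output dtype: this is the
    # kind of dtype conservatism the comment above refers to
    if new_out.dtype != node.outputs[0].dtype:
        return None
    return [new_out]
```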