aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

Type mismatch when computing the `joint_logprob` of a mixture model #214

Closed: rlouf closed this issue 1 year ago.

rlouf commented 1 year ago

I was working on a zero-inflated model example for the documentation when I ran into the following error (note that I added a `dtype` attribute to `DiracDelta`):

import aesara
import aesara.tensor as at

import aeppl
from aeppl.dists import DiracDelta

srng = at.random.RandomStream(0)

dirac_delta = DiracDelta(dtype='int64')  # `dtype` is the local addition mentioned above

p = at.scalar("p")
i_rv = srng.bernoulli(p)
zero_inflated_rv = at.stack([dirac_delta(0), srng.poisson(1.)])[i_rv]

print(dirac_delta(0).type.dtype)
# int64
print(srng.poisson(1.).type.dtype)
# int64

try:
    logprob, (z_vv,) = aeppl.joint_logprob(zero_inflated_rv)
except Exception as e:
    print(e)
    # IfElse requires compatible dtypes for both branches: got true_branch=float32, false_branch=int64
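For reference, indexing the stacked pair with a Bernoulli draw selects the point mass at zero when `i_rv` is 0 and the Poisson draw when it is 1. The NumPy sketch below mirrors that generative process outside Aesara; `sample_zero_inflated` is a hypothetical name, and NumPy's `Generator` stands in for `RandomStream`.

```python
import numpy as np


def sample_zero_inflated(rng: np.random.Generator, p: float, mu: float = 1.0) -> np.int64:
    """Draw one value the way `at.stack([dirac_delta(0), poisson(mu)])[i_rv]` does."""
    # i_rv ~ Bernoulli(p): 1 with probability p, 0 otherwise
    i = rng.binomial(1, p)
    # Component 0 is a point mass at 0, component 1 is a Poisson(mu) draw;
    # both are int64, so the stacked array is int64 too
    components = np.stack([np.int64(0), np.int64(rng.poisson(mu))])
    return components[i]


rng = np.random.default_rng(0)
draws = [sample_zero_inflated(rng, p=0.3) for _ in range(1000)]
# In expectation, the fraction of zeros is 0.7 + 0.3 * exp(-1), about 0.81
zero_fraction = np.mean(np.array(draws) == 0)
```

A zero can come from either component, which is why the zero fraction exceeds `1 - p`.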

The problem comes from this line:

logp_val += ifelse(
    at.eq(indices[0], i),
    comp_logp,
    at.zeros_like(value),
)

It looks like the dtype is not being propagated properly.
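To make the `ifelse` loop concrete, here is a plain-Python sketch of the same accumulation for this two-component mixture, conditional on a known index: each component contributes its log-probability when its position matches the index and zero otherwise. The helpers `dirac_logp`, `poisson_logp` and `mixture_logp` are hypothetical (not aeppl functions), and `math` replaces the symbolic graph. Note that the "otherwise" branch is the float `0.0`, so both branches agree in type.

```python
import math


def dirac_logp(value: int, c: int = 0) -> float:
    # Point mass at c: log(1) = 0 at c, log(0) = -inf elsewhere
    return 0.0 if value == c else -math.inf


def poisson_logp(value: int, mu: float) -> float:
    # log Poisson pmf: value * log(mu) - mu - log(value!)
    return value * math.log(mu) - mu - math.lgamma(value + 1)


def mixture_logp(value: int, index: int, mu: float = 1.0) -> float:
    component_logps = [dirac_logp(value), poisson_logp(value, mu)]
    logp_val = 0.0
    for i, comp_logp in enumerate(component_logps):
        # Python analogue of `ifelse(at.eq(indices[0], i), comp_logp, zeros)`:
        # both branches are floats, so the running sum stays a float
        logp_val += comp_logp if index == i else 0.0
    return logp_val
```

For example, `mixture_logp(0, 1)` is the Poisson(1) log-probability of 0, i.e. `-1.0`.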

brandonwillard commented 1 year ago

> The problem comes from this line:
>
> logp_val += ifelse(
>     at.eq(indices[0], i),
>     comp_logp,
>     at.zeros_like(value),
> )

~Ah, yeah, looks like we need to add a dtype to the at.zeros_like.~ Actually, both branches should be floats, no?
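The mismatch can be reproduced outside Aesara: `zeros_like` copies the dtype of its argument, so when `value` is an integer the "else" branch is an integer zero while the log-probability is a float. The NumPy sketch below uses a hypothetical `strict_ifelse` helper that imitates the dtype check behind the reported error; passing an explicit `dtype` to `zeros_like` makes both branches floats. This is an analogy under those assumptions, not aeppl's actual fix.

```python
import numpy as np


def strict_ifelse(cond: bool, true_branch: np.ndarray, false_branch: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for IfElse's check: refuse to mix dtypes
    if true_branch.dtype != false_branch.dtype:
        raise TypeError(
            "IfElse requires compatible dtypes for both branches: "
            f"got true_branch={true_branch.dtype}, false_branch={false_branch.dtype}"
        )
    return true_branch if cond else false_branch


value = np.array(3, dtype=np.int64)           # an integer count, like the mixture value
comp_logp = np.array(-2.5, dtype=np.float32)  # a log-probability is a float

# zeros_like inherits int64 from `value`, reproducing the reported error
try:
    strict_ifelse(False, comp_logp, np.zeros_like(value))
except TypeError as e:
    print(e)

# Giving the zeros the log-probability's dtype makes both branches float32
fixed = strict_ifelse(False, comp_logp, np.zeros_like(value, dtype=comp_logp.dtype))
print(fixed.dtype)  # float32
```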

brandonwillard commented 1 year ago

> I was working on a zero-inflated model example for the documentation when I stumbled upon the following error (I added a dtype attribute to DiracDelta)

How did you add a dtype attribute to DiracDelta? The way it currently works should adequately reflect the dtype of the input, so it seems like the input should be cast to the desired dtype instead.
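The suggestion rests on `DiracDelta`'s output dtype following the dtype of its input, so the dtype is chosen by casting the constant rather than by adding an attribute to the `Op`. Here is a NumPy sketch of that principle, with a hypothetical `dirac_delta_draw` function standing in for the `Op`. In Aesara the cast would presumably use something like `at.cast`, though that is an assumption, not code from this thread.

```python
import numpy as np


def dirac_delta_draw(c) -> np.ndarray:
    # Hypothetical stand-in for DiracDelta: the "draw" is always c,
    # and its dtype is whatever dtype the input carries
    c = np.asarray(c)
    return np.full(c.shape, c, dtype=c.dtype)


narrow = dirac_delta_draw(np.int8(0))   # dtype follows the input: int8
wide = dirac_delta_draw(np.int64(0))    # cast the input instead of adding a dtype attribute
print(narrow.dtype, wide.dtype)         # int8 int64

# With an int64 constant, stacking with an int64 Poisson draw keeps int64
poisson_draw = np.int64(np.random.default_rng(0).poisson(1.0))
stacked = np.stack([wide, poisson_draw])
```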