aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

Implement logprob for SpecifyShape #171

Open ricardoV94 opened 2 years ago

ricardoV94 commented 2 years ago

As in the DimShuffle case, the best approach is to let canonicalization move SpecifyShape Ops out of the way (and to improve Aesara wherever there are obvious missing cases), and to apply a logprob rewrite only as a last resort.
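As a toy illustration of what "moving SpecifyShape out of the way" can mean, here is a minimal sketch in plain Python. It does not use Aesara. NormalRV, SpecifyShape, static_shape and remove_redundant_specify_shape are all hypothetical stand-ins, and the RV parameters are assumed to be scalars, so the RV's shape is fully determined by its size. The rewrite drops a SpecifyShape that the RV's own static shape already implies, and leaves every other case for the logprob fallback.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class NormalRV:
    # Hypothetical toy node: a normal RV with scalar parameters and an optional size.
    mu: float
    sigma: float
    size: Optional[Tuple[int, ...]] = None


@dataclass(frozen=True)
class SpecifyShape:
    # Hypothetical toy node: asserts the shape of `x`; None means "unknown dimension".
    x: "Node"
    shape: Tuple[Optional[int], ...]


Node = Union[NormalRV, SpecifyShape]


def static_shape(rv: NormalRV) -> Tuple[int, ...]:
    # With scalar parameters, the shape comes entirely from `size` (or () if absent).
    return rv.size if rv.size is not None else ()


def remove_redundant_specify_shape(node: Node) -> Node:
    # Canonicalization-style rewrite: if the RV already guarantees the specified
    # shape, the SpecifyShape is redundant and the bare RV can be returned.
    if isinstance(node, SpecifyShape) and isinstance(node.x, NormalRV):
        known = static_shape(node.x)
        if len(known) == len(node.shape) and all(
            s is None or s == k for s, k in zip(node.shape, known)
        ):
            return node.x
    # Not provably redundant: keep it and let a dedicated logprob rewrite handle it.
    return node
```

For example, SpecifyShape(NormalRV(0.0, 1.0, size=(3,)), (3,)) reduces to the bare NormalRV. A constraint that the RV cannot be shown to satisfy is left in place.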

The logprob rewrite itself should be fairly simple:

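import aesara.tensor as at
from aeppl.logprob import logprob


def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):
    (value,) = values
    # Transfer the shape constraint from the random variable to its value
    value = at.specify_shape(value, shapes)
    # SpecifyShape does not change the density, so defer to the inner variable
    return logprob(inner_rv, value)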
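The point of the rewrite is that SpecifyShape only restricts which values are admissible. It leaves the density unchanged, so the logprob is simply the inner variable's logprob evaluated at a shape-checked value. The following minimal sketch shows that semantics in plain Python, without Aesara or aeppl. Here specify_shape is a hypothetical 1-D stand-in that raises on a mismatch, and normal_logpdf plays the role of logprob(inner_rv, value).

```python
import math
from typing import Callable, List, Optional, Sequence


def specify_shape(value: Sequence[float], shape: Sequence[Optional[int]]) -> Sequence[float]:
    # Hypothetical 1-D stand-in for at.specify_shape: check a known length, pass value through.
    if len(shape) != 1:
        raise ValueError("this sketch only handles 1-D values")
    if shape[0] is not None and len(value) != shape[0]:
        raise ValueError(f"expected length {shape[0]}, got {len(value)}")
    return value


def normal_logpdf(value: Sequence[float], mu: float = 0.0, sigma: float = 1.0) -> List[float]:
    # Elementwise log-density of N(mu, sigma^2); stands in for logprob(inner_rv, value).
    return [
        -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)
        for x in value
    ]


def logprob_specify_shape(inner_logprob: Callable, value, shape) -> List[float]:
    # Move the shape constraint onto the value, then defer to the inner density.
    return inner_logprob(specify_shape(value, shape))


lp = logprob_specify_shape(normal_logpdf, [0.0, 1.0, -1.0], (3,))
```

A value with the wrong length fails the shape check before any density is computed. That mirrors the runtime assertion at.specify_shape adds to the value graph.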