Open ricardoV94 opened 2 years ago
They serve the same goal the "standard" Distribution class is doing right now: basically manage all non distribution kwargs: shape/dims, transformed, observed; and behind the scenes magic things like value variables and rngs
The base `Distribution` class was only ever supposed to make the hand-off to the `Model` class, which would then handle all the generic tasks (e.g. value variables, RNGs, shape/dims, etc.), and the rest was just for backward compatibility (e.g. the now completely unnecessary `Distribution.dist` interface). Notice how instances of these classes are never made, and that only static/class/type-level functions are used.

By extending `Distribution` and adding more logic to those classes, it could become considerably more difficult to unwind those temporary backward compatibility-only choices, and they'll eventually become a permanent and confusing part of v4's design.

This isn't the time or place to start addressing all that, but these are the kinds of design considerations that need to go alongside changes/additions to the relevant areas of code (i.e. `Distribution`-related code).
Originally posted by @brandonwillard in https://github.com/pymc-devs/pymc/issues/5169#issuecomment-1004289756
Do we have a spec in mind for a replacement class? Maybe an explicit list of what needs to be removed and added with respect to the current class would be a good place to start.
> Do we have a spec in mind for a replacement class? Maybe an explicit list of what needs to be removed and added with respect to the current class would be a good place to start.
I will update the issue tomorrow with those
> ...many things that existed to accommodate V3 limitations no longer hold.

If you're easily able to list these no-longer-valid limitations while writing the doc, I would also appreciate that (to help with my personal understanding). No obligation if it'll take you a lot of time.
Updated the top post to mention every function of `Distribution` classes that I could find. Also highlighted some of the points I think will be a bit more challenging, or may need more drastic API changes.
`eval()` should be aliased to something more intuitive like `sample()` or `draw()` where appropriate.
> `eval()` should be aliased to something more intuitive like `sample()` or `draw()` where appropriate.
`eval` is just the standard Aesara debug feature that exists for any node. We can have a function wrapper that compiles the "proper" Aesara function and takes a given number of draws. For analogy with V3 it could be named `pm.random`, but I like the `draw` name better.
The question of RNGs with default updates exists anyway, but could arguably be offloaded to `pymc.aesaraf.compile_pymc`, which already does this for Simulator variables, and which `pm.random` / `pm.draw` would call. That would be one less thing that `Distribution` has to be concerned with.
https://github.com/pymc-devs/pymc/blob/75ea2a80cb27773e93b7b207043077953940e6ff/pymc/aesaraf.py#L968
Yeah, I like a functional approach there, and I agree `draw` is better than `sample` since the latter is already associated with MCMC sampling.
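The wrapper idea discussed here can be sketched without Aesara at all. The toy below is only an illustration: `compile_sampler` is a hypothetical stand-in for a compile step such as `compile_pymc`, and `draw` is a hypothetical helper, not the PyMC implementation. It shows the two ingredients mentioned above: compile once, then evaluate repeatedly while the RNG state advances between calls.

```python
import random
from typing import Callable, List


def compile_sampler(seed: int) -> Callable[[], float]:
    """Hypothetical stand-in for a compile step: build the sampler once."""
    rng = random.Random(seed)  # state lives in the closure and advances on every call

    def sample() -> float:
        return rng.gauss(0.0, 1.0)

    return sample


def draw(
    compile_fn: Callable[[int], Callable[[], float]],
    draws: int = 1,
    seed: int = 0,
) -> List[float]:
    """Compile once, then evaluate the compiled function `draws` times."""
    fn = compile_fn(seed)
    return [fn() for _ in range(draws)]


samples = draw(compile_sampler, draws=5, seed=42)
repeat = draw(compile_sampler, draws=5, seed=42)
```

Because the RNG state is owned by the compiled function, successive draws differ from one another, while re-running with the same seed reproduces the same sequence.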
What is the selling point of it? It's a big big refactoring that forces us into an API breakage. And after adding all those decorators we'd probably end up with more lines of code per distribution; considerably harder to comprehend.
First of all, if we create no instances of any PyMC distribution, why do they exist at all?
We distinguish between variables that are `pm.MyDistribution("name", some, params)` or `pm.MyDistribution.dist(some params)`.
Pure Aesara RV `Op`s don't make this distinction (point 7).
Then PyMC distributions have a much richer API & behavior compared to Aesara `RandomVariable` `Op`s (points 1,2,3,5).
So we need to wrap the `Op`.
But should these adapters be a `class MyDistribution(...)` or a `def MyDistribution(...)`?
Neither `MyDistribution` nor `MyDistribution.dist()` seem to get much love, and generally it feels like there's a recent hype for functional API designs.
But class vs. function aside, let's consider just the resulting user API syntax:

1. `pm.Normal("name", 0, 1)` and `pm.Normal.dist(0, 1)` → the current API. Done with `class Normal(pm.Distribution)` because otherwise we'd be monkey-patching things onto functions and nobody wants that. We can still rename `.dist()` to `.cool()` or something.
2. `pm.Normal("name", 0, 1)` and `pm.Normal(0, 1)` → doesn't work as long as `name, mu, sd` etc. are positional. Same for having a kwarg like `pm.Normal(0, 1, register=False)`. In C# this would be trivial, by the way.
3. `pm.Normal("name", 0, 1)` and `pm.normal(0, 1)` → Both could be functions, but maybe also easy to confuse?

So we could move away from `class MyDistribution(Distribution)` only if we switch to syntax 3.
Then every distribution needs to get a `def mydistribution(...)` and `def MyDistribution(...)`.
The latter could be a one-liner: `MyDistribution = pm.as_rich_RV(mydistribution)`.
Then we'd have to dispatch `logp` and `get_moment` etc. onto the RV Op as @ricardoV94 suggested.
The resulting code would no longer be grouped by distribution, and have a ton of dispatch decorators doing the job that's currently done by `DistributionMeta`.
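To make the "dispatch decorators" concrete, here is a minimal sketch using the standard library's `functools.singledispatch`. The `RandomVariableOp`, `NormalRV` and `ExponentialRV` classes below are hypothetical stand-ins defined only for this example; they are not the Aesara classes, and the real dispatch functions may have different signatures.

```python
import math
from functools import singledispatch


class RandomVariableOp:
    """Hypothetical stand-in for an RV Op base class."""


class NormalRV(RandomVariableOp):
    pass


class ExponentialRV(RandomVariableOp):
    pass


@singledispatch
def logp(op, value, *params):
    """Fallback used when no implementation is registered for the Op type."""
    raise NotImplementedError(f"No logp registered for {type(op).__name__}")


@logp.register(NormalRV)
def _normal_logp(op, value, mu, sigma):
    # Log-density of a normal distribution in its (mu, sigma) parametrization
    return -0.5 * ((value - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


@logp.register(ExponentialRV)
def _exponential_logp(op, value, lam):
    # Log-density of an exponential distribution with rate `lam`
    return math.log(lam) - lam * value


normal_value = logp(NormalRV(), 0.0, 0.0, 1.0)
exponential_value = logp(ExponentialRV(), 0.0, 2.0)
```

Each distribution contributes one decorated function per dispatched method, so the grouping is by method rather than by distribution, which is the trade-off raised above.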
- Taking care of default updates for the RandomState variables so that returned variables "look" random by default.
This once bit me really hard when I tried to demo Aesara to someone.
I wanted to show how conveniently one can do some symbolic calculations and `.eval()` things as random numbers.
...took me an hour to understand that I was used to a really awesome behavior that was actually a PyMC, not an Aesara feature.
If Aesara made the non-deterministic `.eval()` the default, like with PyMC RVs, it would be a lot more beginner friendly & interesting for people doing, for example, Monte Carlo sampling.
> `pm.Normal("name", 0, 1)` and `pm.Normal(0, 1)` → doesn't work as long as `name, mu, sd` etc. are positional. Same for having a kwarg like `pm.Normal(0, 1, register=False)`. In C# this would be trivial, by the way.
Can't we just check if the first variable is a string and react accordingly?
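A minimal sketch of that check, assuming a hypothetical `normal` entry point that returns a plain dictionary instead of a real random variable. It is only meant to show the branching; it ignores the case where a distribution's first parameter could itself legitimately be a string.

```python
from typing import Any, Dict


def normal(*args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical single entry point: register only if the first argument is a name."""
    if args and isinstance(args[0], str):
        # Named call, e.g. normal("x", 0, 1): would be registered in the model
        name, params = args[0], args[1:]
    else:
        # Unnamed call, e.g. normal(0, 1): behaves like today's `.dist()`
        name, params = None, args
    return {"name": name, "params": params, "kwargs": kwargs, "registered": name is not None}


named = normal("x", 0, 1)
unnamed = normal(0, 1)
```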
> Then we'd have to dispatch `logp` and `get_moment` etc. onto the RV Op as @ricardoV94 suggested. The resulting code would no longer be grouped by distribution, and have a ton of dispatch decorators doing the job that's currently done by `DistributionMeta`.
We already do that anyway, just via the cryptic `DistributionMeta`. Structurally, what difference does it make if the `logp` and `logcdf` are fake methods inside fake classes or real functions one after the other?
Does this seem so bad? https://github.com/aesara-devs/aeppl/blob/333108143a1f4b63d9fcd9842dbe35457025c180/aeppl/logprob.py#L104-130
More importantly, users should not be calling these pseudo methods themselves because they expect as inputs the already parsed and symbolic canonical parameters of the distribution. For some (many) distributions these have nothing to do with what they would pass into `.dist`, so making those accessible from the distribution classes (via autocompletion) is actually error prone.
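A small numeric illustration of why this is error prone, using plain Python. The helper `tau_to_sigma` and the function `normal_logp` are hypothetical and written only for this example; they assume a normal distribution whose canonical parametrization is `(mu, sigma)` while the user thinks in terms of the precision `tau`.

```python
import math


def tau_to_sigma(tau: float) -> float:
    """Hypothetical conversion from precision to standard deviation."""
    return 1.0 / math.sqrt(tau)


def normal_logp(value: float, mu: float, sigma: float) -> float:
    """Log-density written in terms of the canonical (mu, sigma) parameters."""
    return -0.5 * ((value - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


user_tau = 4.0
sigma = tau_to_sigma(user_tau)

# Correct: the user-facing parameter is converted before calling the pseudo method
correct = normal_logp(0.0, 0.0, sigma)

# Silent mistake: `tau` is passed where the canonical `sigma` is expected
wrong = normal_logp(0.0, 0.0, user_tau)
```

Both calls succeed, but they return different values, and nothing signals that the second one used the wrong parametrization.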
Posting this in case it helps... I came across an issue like this one on Discourse while trying to use v3 code in v4. To produce some of the plots in BCB, the authors use `FreeRV.distribution.all_trees[..].predict_output()`. In v4 the error is `'TensorVariable' object has no attribute 'distribution'`.
I tried to solve but it seems in v3, model.py used def var() and returned a <>RV pymc object where distribution=dist. However in v4 we have def register_rv() which returns an aesara tensor variable? Are we breaking this functionality in v4? Sorry if it is something obvious I am missing.
@mitch-at-orika yes, there is a breaking change w.r.t. what's returned. In most cases that's not a problem, but it looks like `BART` monkey-patched some information onto the `TensorVariable`. This is no longer possible with `v4`, where these calls don't necessarily return the same tensor that was created by the underlying distribution.
@aloctavodia I believe you can link to relevant `BART` issues, or even explain how to access the `all_trees` object in `v4`?
Thanks for the explanation Michael, it is reassuring it was a BART only option, I originally thought I had just missed this functionality of pymc vars until now.
Here is some pseudo-code that might suffice?
    from functools import wraps

    import pytensor.tensor.random.basic as ptr

    from pymc.distributions.continuous import get_tau_sigma
    from pymc.distributions.shape_utils import convert_dims, shape_from_dims
    from pymc.model import modelcontext
    from pymc.pytensorf import convert_observed_data
    from pymc.util import UNSET


    def handle_shape(ndim_supp=None):
        """Convert the shape argument to size used by PyTensor."""

        def inner_decorator(dist):
            @wraps(dist)
            def inner_func(*args, size=None, shape=None, **kwargs):
                if shape is not None and size is not None:
                    raise ValueError("Cannot pass both size and shape")
                if shape is not None:
                    # If needed, call dist without size to find out ndim_supp
                    local_ndim_supp = dist(*args).owner.op.ndim_supp if ndim_supp is None else ndim_supp
                    size = shape if local_ndim_supp == 0 else shape[:-local_ndim_supp]
                return dist(*args, size=size, **kwargs)

            return inner_func

        return inner_decorator


    def register_model_rv(dist, rv_type=None):
        """Register a random variable in a model context."""

        @wraps(dist)
        def inner_func(name, *args, dims=None, transform=UNSET, observed=None, model=None, **kwargs):
            # Resolve the model first, since the dims lookup below needs it
            model = modelcontext(model)
            if dims is not None:
                dims = convert_dims(dims)
            if observed is not None:
                observed = convert_observed_data(observed)
            # The shape of the variable is determined from the following sources:
            # size or shape, otherwise dims, otherwise observed.
            if kwargs.get("size") is None and kwargs.get("shape") is None:
                if dims is not None:
                    kwargs["shape"] = shape_from_dims(dims, model)
                elif observed is not None:
                    kwargs["shape"] = tuple(observed.shape)
            rv = dist(*args, **kwargs)
            return model.register_rv(rv, name=name, dims=dims, transform=transform, observed=observed)

        # Monkey-patch useful attributes
        if rv_type is not None:
            inner_func.rv_type = rv_type
        inner_func.dist = dist
        return inner_func


    @handle_shape(ndim_supp=0)
    def normal_dist(mu=0, sigma=None, tau=None, **kwargs):
        _, sigma = get_tau_sigma(sigma=sigma, tau=tau)
        return ptr.normal(mu, sigma, **kwargs)


    Normal = register_model_rv(normal_dist, rv_type=ptr.NormalRV)
This also makes writing distribution helpers simpler. Right now we have to define a redundant `__new__` method to preserve the usual API:
Instead this could be done like this:
    from pymc.distributions.mixture import Mixture


    def normal_mixture_dist(w, mu, sigma=None, tau=None, **kwargs):
        return Mixture.dist(w, Normal.dist(mu, sigma=sigma, tau=tau), **kwargs)


    NormalMixture = register_model_rv(normal_mixture_dist)
@ricardoV94 I think I understand this, but just to be certain:
You have written a function that constructs the correct pytensor TensorVariable, and then you have a wrapper class that associates that variable with whatever model context manager this is created within.
Yes, I believe this should work. Here is an example of how you would type `register_model_rv` in order to propagate the function signature of `normal_dist` onto `Normal`.
I think this introduces another potential problem, however. This new `Normal` object does not have the `.dist` method or any other methods or properties associated with it. Or perhaps that's ok, and this would be a breaking change?
> I think this introduces another potential problem, however. This new `Normal` object does not have the `.dist` method or any other methods or properties associated with it. Or perhaps that's ok, and this would be a breaking change?
They are being monkey-patched here:
        # Monkey-patch useful attributes
        if rv_type is not None:
            inner_func.rv_type = rv_type
        inner_func.dist = dist
        return inner_func
I am not sure that's the best approach, but the current fake classes also seem odd. Maybe what's done by `Distribution.__new__()` now should be done by `Distribution().__call__()`? It's still a pretty useless object with only static methods, but it binds the two methods more transparently.
There is no other method that should be attached to `Normal`, including `logp` and `logcdf`. Those should all be accessed by `pm.logp()` and the like, because their signature is an implementation detail (i.e., which canonical parametrization we decide to use) that the user shouldn't need to be aware of. In that regard it's a good thing that something like `Normal.logp` will cease to exist.
Here is a non-fake class that does the same:
    import pytensor.tensor.random.basic as ptr

    from pymc.distributions.continuous import get_tau_sigma
    from pymc.distributions.mixture import Mixture
    from pymc.distributions.shape_utils import convert_dims, shape_from_dims
    from pymc.model import modelcontext
    from pymc.pytensorf import convert_observed_data
    from pymc.util import UNSET


    class Distribution:
        rv_type = None
        rv_op = None

        @classmethod
        def dist(cls, *args, size=None, shape=None, **kwargs):
            if shape is not None and size is not None:
                raise ValueError("Cannot pass both size and shape")
            if shape is not None:
                ndim_supp = getattr(cls.rv_type, "ndim_supp", None)
                if ndim_supp is None:
                    # If needed, call rv_op without size to find out ndim_supp
                    ndim_supp = cls.rv_op(*args).owner.op.ndim_supp
                size = shape if ndim_supp == 0 else shape[:-ndim_supp]
            return cls.rv_op(*args, size=size, **kwargs)

        def __call__(self, name, *args, dims=None, transform=UNSET, observed=None, model=None, **kwargs):
            # Resolve the model first, since the dims lookup below needs it
            model = modelcontext(model)
            if dims is not None:
                dims = convert_dims(dims)
            if observed is not None:
                observed = convert_observed_data(observed)
            # The shape of the variable is determined from the sources:
            # size or shape, otherwise dims, otherwise observed.
            if kwargs.get("size") is None and kwargs.get("shape") is None:
                if dims is not None:
                    kwargs["shape"] = shape_from_dims(dims, model)
                elif observed is not None:
                    kwargs["shape"] = tuple(observed.shape)
            rv = self.dist(*args, **kwargs)
            return model.register_rv(rv, name=name, dims=dims, transform=transform, observed=observed)


    class NormalDist(Distribution):
        rv_type = ptr.NormalRV

        @staticmethod
        def rv_op(mu=0, sigma=None, tau=None, **kwargs):
            _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
            return ptr.normal(mu, sigma, **kwargs)


    class NormalMixtureDist(Distribution):
        # If we subclass from a refactored `Mixture`, this `rv_type` would be obtained automatically
        rv_type = Mixture.rv_type

        @staticmethod
        def rv_op(w, mu, sigma=None, tau=None, **kwargs):
            _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
            return Mixture.dist(w, Normal.dist(mu=mu, sigma=sigma), **kwargs)


    Normal = NormalDist()
    NormalMixture = NormalMixtureDist()
The sole point of it is that it provides a `Normal.dist` method! Nothing else.
I believe we desire two interfaces per distribution that require similar (but not identical) signatures:
We want a `dist` function, and a "create a distribution and register it with a model" function. Is that a correct interpretation?
For the case of the normal distribution, and only using a minimal number of variables to represent the difference:
    def normal_dist(mu, sigma, **kwargs):
        ...

    def Normal(name, mu, sigma, dims, observed, **kwargs):
        ...
I do not believe there is a way to programmatically produce a static function signature for `Normal` by using the one for `normal_dist` through such mechanisms as `ParamSpec`, as discussed in this gist.
In this case `Normal` is a function, but it could easily be a class like it is today, and we could keep `.dist` as its method. I like your suggestions for a simpler inheritance structure. But wouldn't it be better to stick the contents of `__call__` inside `__init__` instead? Then rename `NormalDist` to `Normal` and call `super().__init__()` inside that class's `__init__`? Then we only refer to `Distribution` and `Normal`, rather than `Distribution`, `NormalDist` and `Normal`.
The only solution that I think makes sense is to type the signature out in both cases, using mechanisms like `TypedDict` to keep kwargs nice and minimize the amount of code duplication. I'm working on an MVP in a gist to demonstrate.
> But wouldn't it be better to stick the contents of `__call__` inside `__init__` instead?
`__init__` is not allowed to return anything other than `None`. PyMC distributions are returning completely new objects (`TensorVariable`s), and in that sense `__new__` as done now is appropriate, although overkill because it's just a fancy function call.
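The difference is easy to see in plain Python. The two classes below are hypothetical and exist only to illustrate the point: returning a value from `__init__` raises a `TypeError`, while `__new__` may return an arbitrary object, in which case no instance of the class is ever created.

```python
class ReturnsFromInit:
    def __init__(self):
        # Returning anything other than None from __init__ is an error
        return 42


class FakeClass:
    def __new__(cls, name, value):
        # Returns an unrelated object; __init__ is skipped because the
        # result is not an instance of `cls`
        return (name, value)


try:
    ReturnsFromInit()
    init_failed = False
except TypeError:
    init_failed = True

result = FakeClass("x", 1.0)
```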
> I believe we desire two interfaces per distribution that require similar (but not identical) signatures: We want a `dist` function, and a "create a distribution and register it with a model" function. Is that a correct interpretation?
Yes that's correct
Ah yes, that makes sense! Thank you for clarifying these things for me!
PyMC distribution classes are weird objects that hold RandomVariables, logp, logcdf and moment methods together (basically doing runtime dispatching) and manage most of the non-RandomVariable kwargs that users are familiar with (observed, transformed, size/dims) and behind the scenes actions like registration in the model.
This exists mostly for backwards compatibility with V3 and ease of developer refactoring, but the current result is far from pretty.
We need to figure out a more elegant/permanent architecture now that many things that existed to accommodate V3 limitations no longer hold.
Distribution

`Distribution` is currently performing the following tasks: https://github.com/pymc-devs/pymc/blob/75ea2a80cb27773e93b7b207043077953940e6ff/pymc/distributions/distribution.py#L135

- [...] `FutureWarnings` for `testval` kwarg
- [...] `tau` -> `sigma`). This is done by the `.dist` methods
- [...] `logp`, `logcdf`, `random` methods
- [...] `.dist()` API to create an unnamed RV that is not registered in the model. This type of variable is necessary for use in Potentials and other distribution factories that use RVs as building blocks such as Bound and Censored distributions, as well as Mixtures and Timeseries once they get refactored for V4

DistributionMeta

In addition we have a `DistributionMeta` that does the following: https://github.com/pymc-devs/pymc/blob/75ea2a80cb27773e93b7b207043077953940e6ff/pymc/distributions/distribution.py#L70

- [...] `logp`, `logcdf`, `moment`, `default_transform` methods defined in the old PyMC distributions to apply to the respective `rv_op`
- [...] `rv_op` type as subclass of the old style PyMC distribution, so that V3 Discrete/Continuous subclass checks still work?

If we want to get rid of `Distribution` we probably need to statically dispatch our methods to the respective `rv_op`. That is nothing special, and is how we do it for aeppl from the get go: https://github.com/aesara-devs/aeppl/blob/38d0c2ea4ecf8505f85317047089ab9999d2f78e/aeppl/logprob.py#L104-L130