pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io
Other
77 stars 49 forks source link

Autoreparam Exponential distribution #364

Closed ferrine closed 2 months ago

ferrine commented 2 months ago

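I came up with a reparametrization of the Exponential distribution that interpolates between the centered and non-centered sides of the identity

$$ X \sim \text{Exponential}(c) \iff X \overset{d}{=} \frac{Y}{c}, \qquad Y \sim \text{Exponential}(1), $$

where $c$ is the rate. Equivalently, in the scale parametrization used by the code below, $\text{Exponential}(\text{scale}=s) \overset{d}{=} s \cdot \text{Exponential}(\text{scale}=1)$.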

from typing import List, Optional

import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.logprob.transforms import Transform
from pymc.model.fgraph import (
    ModelDeterministic,
    ModelNamed,
    fgraph_from_model,
    model_deterministic,
    model_free_rv,
    model_from_fgraph,
    model_named,
)
from pytensor.graph import Apply, Variable
from pytensor.tensor.random.basic import ExponentialRV

from pymc_experimental.model.transforms.autoreparam import _vip_reparam_node

@_vip_reparam_node.register
def _(
    op: ExponentialRV,
    node: Apply,
    name: str,
    dims: List[Variable],
    transform: Optional[Transform],
    lam: pt.TensorVariable,
) -> ModelDeterministic:
    """Reparametrize an Exponential random variable with interpolation weight ``lam``.

    Uses ``Exponential(scale=s) =d= s * Exponential(scale=1)`` and splits the
    scale as ``s = s**lam * s**(1 - lam)``:

    * a free variable ``tau ~ Exponential(scale=s**lam)`` named
      ``"{name}::tau_"`` is added to the model, and
    * the original variable is replaced by the deterministic
      ``name = s**(1 - lam) * tau``.

    With ``lam == 1`` this is the centered form (``tau`` has the original
    distribution and is returned unscaled). With ``lam == 0`` it is the
    non-centered form (``tau ~ Exponential(scale=1)``, rescaled by ``s``).

    Parameters
    ----------
    op
        The Exponential random-variable op of ``node``. ``singledispatch``
        selects this handler by the *type* of this argument, so the annotation
        must be the op class ``ExponentialRV``, not the ``pm.Exponential``
        distribution class.
    node
        ``Apply`` node of the original random variable.
    name
        Name of the original random variable; reused for the deterministic.
    dims
        Dims forwarded to both the new free RV and the deterministic.
    transform
        Optional value transform for ``tau``. When given, the value variable is
        named ``"{name}::tau__{transform.name}__"``.
    lam
        Interpolation weight, broadcast against the scale (presumably in
        ``[0, 1]`` as supplied by the caller -- confirm).

    Returns
    -------
    ModelDeterministic
        The model deterministic that replaces the original variable.
    """
    # ``rng`` and ``size`` are the first two inputs and ``scale`` (the only
    # Exponential parameter) is the last. Older PyTensor releases also put a
    # ``dtype`` input in between; the star pattern skips it, whereas a fixed
    # three-way unpack would raise there.
    rng, size, *_, scale = node.inputs

    scale_centered = scale**lam
    scale_noncentered = scale ** (1 - lam)

    vip_rv_ = pm.Exponential.dist(
        scale=scale_centered,
        size=size,
        rng=rng,
    )
    # The value variable is a clone of the RV; both are named explicitly below.
    vip_rv_value_ = vip_rv_.clone()
    vip_rv_.name = f"{name}::tau_"
    if transform is not None:
        # Mirrors PyMC's "<rv name>_<transform name>__" value-variable naming.
        vip_rv_value_.name = f"{vip_rv_.name}_{transform.name}__"
    else:
        vip_rv_value_.name = vip_rv_.name
    vip_rv = model_free_rv(
        vip_rv_,
        vip_rv_value_,
        transform,
        *dims,
    )

    # s**(1 - lam) * Exponential(scale=s**lam) has the original Exponential(scale=s) law.
    vip_rep_ = scale_noncentered * vip_rv
    vip_rep_.name = name

    return model_deterministic(vip_rep_, *dims)