pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/
Other

Implement specialized transformed logp dispatch #7188

Open ricardoV94 opened 4 months ago

ricardoV94 commented 4 months ago

Description

Adds a specialized dispatch for transformed logps

TODO

Related Issue

Checklist

Type of change


📚 Documentation preview 📚: https://pymc--7188.org.readthedocs.build/en/7188/

codecov[bot] commented 4 months ago

Codecov Report

Attention: Patch coverage is 96.72131% with 2 lines in your changes missing coverage. Please review.

Project coverage is 92.26%. Comparing base (244fb97) to head (af027ad).

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main    #7188   +/-   ##
=======================================
  Coverage   92.26%   92.26%
=======================================
  Files         100      100
  Lines       16880    16900   +20
=======================================
+ Hits        15574    15593   +19
- Misses       1306     1307    +1
```

| Files | Coverage Δ | |
|---|---|---|
| pymc/distributions/multivariate.py | `93.84% <100.00%> (+0.03%)` | :arrow_up: |
| pymc/initial_point.py | `100.00% <100.00%> (ø)` | |
| pymc/logprob/abstract.py | `96.92% <100.00%> (+1.46%)` | :arrow_up: |
| pymc/logprob/basic.py | `94.36% <ø> (-0.04%)` | :arrow_down: |
| pymc/logprob/transform_value.py | `94.18% <100.00%> (+0.43%)` | :arrow_up: |
| pymc/logprob/transforms.py | `95.29% <100.00%> (-0.18%)` | :arrow_down: |
| pymc/logprob/utils.py | `97.87% <100.00%> (+0.65%)` | :arrow_up: |
| pymc/model/fgraph.py | `97.39% <100.00%> (ø)` | |
| pymc/model/transform/conditioning.py | `95.74% <100.00%> (ø)` | |
| pymc/distributions/transforms.py | `97.01% <66.66%> (-1.46%)` | :arrow_down: |
ricardoV94 commented 4 months ago

Is there any expected advantage to this specialization besides avoiding constraint checks (which the user can already do via the model check_bounds flag)?

aseyboldt commented 4 months ago

Nice :D

In the case of the ZeroSumNormal, probably not much: once we have figured out the expression in the untransformed space, it doesn't change too much, and I guess it is nice to have, for instance so that the logp can be computed. Just being able to do it on the transformed space would have saved us some math...

I think it is quite common for logp expressions to be a bit cleaner, and possibly numerically more stable, on the transformed space.

I guess things like this might also be happening in the Dirichlet distributions, but looking through that is a bit more work... Nothing here is really a game-changer, but I think it might make quite a few graphs a bit cleaner and a bit more stable.
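A small illustration of the stability point, in plain NumPy: the log(1 - x) term of, e.g., a Beta logp under a log-odds transform. Going through x = sigmoid(y) loses everything once x rounds to 1.0, whereas the same term written on the unconstrained scale is just -softplus(y). The function names are ours, chosen for illustration.

```python
import numpy as np


def log1m_sigmoid_via_x(y):
    # Constrained-space route: compute x = sigmoid(y) first, then log(1 - x).
    x = 1.0 / (1.0 + np.exp(-y))
    return np.log1p(-x)


def log1m_sigmoid_on_y(y):
    # The same quantity written on the unconstrained scale: -softplus(y).
    return -np.logaddexp(0.0, y)


with np.errstate(divide="ignore"):
    print(log1m_sigmoid_via_x(40.0))  # -inf, because x rounds to exactly 1.0
print(log1m_sigmoid_on_y(40.0))  # -40.0
```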


There is also an additional thing we could do with this that right now we can't do with transformations at all. I'm not 100% sure if we should do this though:

Right now all our transformations are injective, because the whole trick with the Jacobian determinant doesn't work otherwise. But if we can compute the logp on the transformed space, we could get rid of that requirement. For instance, we could have a Horseshoe distribution whose transformed space contains both the lambda and the x values, with the transformation just multiplying them.

Or we could have a transformation that maps a point in 2D space to an angle, and use that in, for instance, the VonMises distribution to avoid the topology problems we have there right now.
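Both ideas can be sketched in plain NumPy. The names and prior choices are ours and purely illustrative: a unit global scale tau and a HalfCauchy(1) prior on lambda for the Horseshoe, and a Rayleigh(1) density on the radius for the 2D VonMises parametrization (any proper radial density would do). In both cases the sampler would work on the unconstrained values and the constrained value is only computed forward, so neither map needs to be injective.

```python
import numpy as np

# Horseshoe-style: transformed space (log_lam, z), forward map x = tau * lam * z.


def horseshoe_logp_transformed(log_lam, z):
    """Joint logp of the unconstrained values; x is never inverted."""
    lam = np.exp(log_lam)
    halfcauchy = np.log(2.0 / np.pi) - np.log1p(lam**2)  # HalfCauchy(1) on lam
    normal = -0.5 * z**2 - 0.5 * np.log(2.0 * np.pi)  # standard normal on z
    return halfcauchy + log_lam + normal  # + log_lam: log-Jacobian of lam = exp(log_lam)


def horseshoe_forward(log_lam, z, tau=1.0):
    return tau * np.exp(log_lam) * z  # many (log_lam, z) pairs give the same x


# VonMises: transformed space (u, v) in R^2, forward map angle = atan2(v, u).


def vonmises_logp_transformed(u, v, mu, kappa):
    """VonMises on the angle times a Rayleigh(1) density on the radius.

    In Cartesian coordinates the polar Jacobian 1/r cancels the r in the
    Rayleigh density r * exp(-r**2 / 2), leaving exp(-r**2 / 2).
    """
    theta = np.arctan2(v, u)
    log_vm = kappa * np.cos(theta - mu) - np.log(2.0 * np.pi * np.i0(kappa))
    return log_vm - 0.5 * (u**2 + v**2)


def vonmises_forward(u, v):
    return np.arctan2(v, u)
```

Integrating either logp over a fine grid gives 1, i.e. both are proper densities on the transformed space even though neither forward map can be inverted.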

ricardoV94 commented 4 months ago

Thanks for the reply

Re: switches/bound checks: most of those could probably be removed with some domain analysis, e.g. exp(x) is always positive, so any ge(exp(x), 0) is always true. That could be useful beyond PyMC. It could be coupled with user-provided type hints (which is roughly what we are doing with transformed values here) to seed initial information that cannot be extracted from the graph.
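A toy version of that domain analysis, assuming expressions are plain nested tuples (this is not PyTensor's rewrite machinery, just the idea): positivity is inferred bottom-up, optionally seeded with variables the user declares positive, and any ge(e, 0) with provably positive e is folded to True.

```python
# Toy expression trees as nested tuples, e.g. ("ge", ("exp", ("var", "x")), ("const", 0.0)).


def is_positive(expr, positive_vars=frozenset()):
    """True only if `expr` is provably > 0; False means "unknown"."""
    op = expr[0]
    if op == "exp":
        return True
    if op == "const":
        return expr[1] > 0
    if op == "var":
        # Information the graph cannot provide, e.g. a user type hint.
        return expr[1] in positive_vars
    if op in ("add", "mul"):
        return all(is_positive(arg, positive_vars) for arg in expr[1:])
    return False


def simplify(expr, positive_vars=frozenset()):
    """Replace ge(e, 0) by a True constant wherever e is provably positive."""
    op = expr[0]
    if op in ("var", "const"):
        return expr
    args = tuple(simplify(arg, positive_vars) for arg in expr[1:])
    if op == "ge" and args[1] == ("const", 0.0) and is_positive(args[0], positive_vars):
        return ("const", True)
    return (op, *args)
```

A real rewrite would then fold switch(True, logp, -inf) into just logp.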

Re: examples: many of those sound like interesting rewrites that PyTensor already does or could do. That need not be a blocker for letting users work directly on the transformed space if they want, but it may be a good source of ideas for future rewrites.

From my perspective, the best argument for the functionality is "avoiding hard math", which is more than a nice-to-have: it could speed up development, or allow models that are otherwise impossible to write without a hacky Potential?

For our code-base I think we still want to try and provide constrained space logps as much as possible, since not all we do (or may want to do) is logp-based sampling.

twiecki commented 4 months ago

What does this do?

aseyboldt commented 4 months ago

I just had a look at the generated graphs of each of those. Let's say we want the logp of the model:

```python
import pymc as pm

with pm.Model(check_bounds=False) as model:
    pm.ZeroSumNormal("x", shape=(10, 10), n_zerosum_axes=2)
```

The logp + grad with the PR:

```
>>> func1 = model.logp_dlogp_function(mode="NUMBA")
>>> pytensor.dprint(func1._pytensor_function)
Sum{axes=None} [id A] 1
 └─ Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)} [id B] 'x_zerosum___logprob' 0
    └─ x_zerosum__ [id C]
Neg [id D] 'x_zerosum___grad' 2
 └─ x_zerosum__ [id C]

Inner graphs:

Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)} [id B]
 ← sub [id E] 'o0'
    ├─ mul [id F]
    │  ├─ -0.5 [id G]
    │  └─ sqr [id H]
    │     └─ i0 [id I]
    └─ 0.9189385332046727 [id J]
```
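Reading the graph above (assuming the default sigma=1): with the PR, the logp is an i.i.d. standard normal over the 81 unconstrained values in x_zerosum__, and its gradient is just the Neg node.

```math
\log p(z) = \sum_{i=1}^{81} \left( -\tfrac{1}{2} z_i^2 - \log\sqrt{2\pi} \right),
\qquad \log\sqrt{2\pi} \approx 0.9189385332046727,
\qquad \frac{\partial \log p(z)}{\partial z_i} = -z_i
```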

And without:

```
Switch [id A] 'mean(value, axis=n_zerosum_axes) = 0' 31
 ├─ All{axes=None} [id B] 27
 │  └─ MakeVector{dtype='bool'} [id C] 24
 │     ├─ All{axes=None} [id D] 20
 │     │  └─ Composite{and(le((0.1 * abs(i0)), 1e-09), invert(or(isnan((0.1 * i0)), isinf((0.1 * i0)))))} [id E] 16
 │     │     └─ Sum{axis=1} [id F] 11
 │     │        └─ Composite{...}.0 [id G] 9
 │     │           ├─ Join [id H] 8
 │     │           │  ├─ 1 [id I]
 │     │           │  ├─ Sub [id J] 4
 │     │           │  │  ├─ Join [id K] 3
 │     │           │  │  │  ├─ 0 [id L]
 │     │           │  │  │  ├─ x_zerosum__ [id M]
 │     │           │  │  │  └─ Composite{...}.1 [id N] 2
 │     │           │  │  │     └─ ExpandDims{axis=0} [id O] 1
 │     │           │  │  │        └─ Sum{axis=0} [id P] 0
 │     │           │  │  │           └─ x_zerosum__ [id M]
 │     │           │  │  └─ Composite{...}.0 [id N] 2
 │     │           │  │     └─ ···
 │     │           │  └─ Composite{...}.1 [id Q] 7
 │     │           │     └─ ExpandDims{axis=1} [id R] 6
 │     │           │        └─ Sum{axis=1} [id S] 5
 │     │           │           └─ Sub [id J] 4
 │     │           │              └─ ···
 │     │           └─ Composite{...}.0 [id Q] 7
 │     │              └─ ···
 │     └─ All{axes=None} [id T] 22
 │        └─ Composite{and(le((0.1 * abs(i0)), 1e-09), invert(or(isnan((0.1 * i0)), isinf((0.1 * i0)))))} [id U] 18
 │           └─ Sum{axis=0} [id V] 12
 │              └─ Composite{...}.0 [id G] 9
 │                 └─ ···
 ├─ Sum{axes=None} [id W] 13
 │  └─ Composite{...}.2 [id G] 9
 │     └─ ···
 └─ -inf [id X]
Add [id Y] 'x_zerosum___grad' 34
 ├─ SpecifyShape [id Z] 30
 │  ├─ Split{2}.0 [id BA] 26
 │  │  ├─ Add [id BB] 23
 │  │  │  ├─ SpecifyShape [id BC] 15
 │  │  │  │  ├─ Split{2}.0 [id BD] 10
 │  │  │  │  │  ├─ Composite{...}.1 [id G] 9
 │  │  │  │  │  │  └─ ···
 │  │  │  │  │  ├─ 1 [id I]
 │  │  │  │  │  └─ [9 1] [id BE]
 │  │  │  │  ├─ 10 [id BF]
 │  │  │  │  └─ 9 [id BG]
 │  │  │  ├─ Composite{(0.07597469266479578 * (i0 + i1))} [id BH] 21
 │  │  │  │  ├─ SpecifyShape [id BI] 14
 │  │  │  │  │  ├─ Split{2}.1 [id BD] 10
 │  │  │  │  │  │  └─ ···
 │  │  │  │  │  ├─ NoneConst{None} [id BJ]
 │  │  │  │  │  └─ 1 [id I]
 │  │  │  │  └─ ExpandDims{axis=1} [id BK] 17
 │  │  │  │     └─ Sum{axis=1} [id F] 11
 │  │  │  │        └─ ···
 │  │  │  └─ Mul [id BL] 19
 │  │  │     ├─ [[-0.31622777]] [id BM]
 │  │  │     └─ SpecifyShape [id BI] 14
 │  │  │        └─ ···
 │  │  ├─ 0 [id L]
 │  │  └─ [9 1] [id BE]
 │  ├─ 9 [id BG]
 │  └─ 9 [id BG]
 ├─ Composite{(0.07597469266479578 * (i0 - i1))} [id BN] 32
 │  ├─ SpecifyShape [id BO] 29
 │  │  ├─ Split{2}.1 [id BA] 26
 │  │  │  └─ ···
 │  │  ├─ 1 [id I]
 │  │  └─ NoneConst{None} [id BJ]
 │  └─ ExpandDims{axis=0} [id BP] 28
 │     └─ Sum{axis=0} [id BQ] 25
 │        └─ Add [id BB] 23
 │           └─ ···
 └─ Mul [id BR] 33
    ├─ [[-0.31622777]] [id BM]
    └─ SpecifyShape [id BO] 29
       └─ ···

Inner graphs:

Composite{and(le((0.1 * abs(i0)), 1e-09), invert(or(isnan((0.1 * i0)), isinf((0.1 * i0)))))} [id E]
 ← AND [id BS] 'o0'
    ├─ LE [id BT]
    │  ├─ mul [id BU]
    │  │  ├─ t0{0.1} [id BV]
    │  │  └─ Abs [id BW]
    │  │     └─ i0 [id BX]
    │  └─ 1e-09 [id BY]
    └─ Invert [id BZ]
       └─ OR [id CA]
          ├─ IsNan [id CB]
          │  └─ mul [id CC] 't9'
          │     ├─ t0{0.1} [id BV]
          │     └─ i0 [id BX]
          └─ IsInf [id CD]
             └─ mul [id CC] 't9'
                └─ ···

Composite{...} [id G]
 ← sub [id CE] 'o0'
    ├─ i0 [id CF]
    └─ i1 [id CG]
 ← neg [id CH] 'o1'
    └─ sub [id CE] 'o0'
       └─ ···
 ← sub [id CI] 'o2'
    ├─ mul [id CJ]
    │  ├─ -0.5 [id CK]
    │  └─ sqr [id CL]
    │     └─ sub [id CE] 'o0'
    │        └─ ···
    └─ 0.7443402118957849 [id CM]

Composite{...} [id N]
 ← mul [id CN] 'o0'
    ├─ 0.07597469266479578 [id CO]
    └─ i0 [id CP]
 ← sub [id CQ] 'o1'
    ├─ mul [id CN] 'o0'
    │  └─ ···
    └─ mul [id CR]
       ├─ 0.31622776601683794 [id CS]
       └─ i0 [id CP]

Composite{...} [id Q]
 ← mul [id CT] 'o0'
    ├─ 0.07597469266479578 [id CU]
    └─ i0 [id CV]
 ← sub [id CW] 'o1'
    ├─ mul [id CT] 'o0'
    │  └─ ···
    └─ mul [id CX]
       ├─ 0.31622776601683794 [id CY]
       └─ i0 [id CV]

Composite{and(le((0.1 * abs(i0)), 1e-09), invert(or(isnan((0.1 * i0)), isinf((0.1 * i0)))))} [id U]
 ← AND [id CZ] 'o0'
    ├─ LE [id DA]
    │  ├─ mul [id DB]
    │  │  ├─ t7{0.1} [id DC]
    │  │  └─ Abs [id DD]
    │  │     └─ i0 [id DE]
    │  └─ 1e-09 [id DF]
    └─ Invert [id DG]
       └─ OR [id DH]
          ├─ IsNan [id DI]
          │  └─ mul [id DJ] 't11'
          │     ├─ t7{0.1} [id DC]
          │     └─ i0 [id DE]
          └─ IsInf [id DK]
             └─ mul [id DJ] 't11'
                └─ ···

Composite{(0.07597469266479578 * (i0 + i1))} [id BH]
 ← mul [id DL] 'o0'
    ├─ 0.07597469266479578 [id DM]
    └─ add [id DN]
       ├─ i0 [id DO]
       └─ i1 [id DP]

Composite{(0.07597469266479578 * (i0 - i1))} [id BN]
 ← mul [id DQ] 'o0'
    ├─ 0.07597469266479578 [id DR]
    └─ sub [id DS]
       ├─ i0 [id DT]
       └─ i1 [id DU]
```
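To make the comparison a bit more concrete: the constants in the graph without the PR, 0.07597469266479578 = 1/(√10 + 10) and 0.31622776601683794 = 1/√10, suggest the forward zero-sum map below. This is a NumPy re-implementation that follows the Composite nodes [id N] and [id Q] (fill value, Join, then subtract), not a copy of PyMC's transform code, so treat it as a sketch. The map turns out to be norm-preserving, which is why a standard normal on the 81 free values already is the ZeroSumNormal logp. The unspecialized path instead maps to x, appears to check that the means along both axes are within 1e-9 of zero (the `0.1 * abs(...)` Composites), and spreads the normalizing constant over all 100 entries as ½·log(2π)·81/100 ≈ 0.7443402118957849.

```python
import numpy as np


def extend_axis(z, axis):
    """Map n-1 free values along `axis` to n values that sum to zero.

    Follows the Composite nodes [id N]/[id Q] above:
    norm = sum / (sqrt(n) + n), fill = norm - sum / sqrt(n), then Join and subtract norm.
    """
    n = z.shape[axis] + 1
    s = z.sum(axis=axis, keepdims=True)
    norm = s / (np.sqrt(n) + n)
    fill = norm - s / np.sqrt(n)
    return np.concatenate([z, fill], axis=axis) - norm


def zerosum_forward(z):
    # n_zerosum_axes=2: extend along axis 0, then axis 1 (same order as the graph).
    for axis in (0, 1):
        z = extend_axis(z, axis)
    return z


rng = np.random.default_rng(0)
z = rng.normal(size=(9, 9))  # the 81 unconstrained values
x = zerosum_forward(z)  # the 10 x 10 constrained value

half_log_2pi = 0.5 * np.log(2 * np.pi)  # 0.9189385332046727 in the graphs
logp_specialized = np.sum(-0.5 * z**2 - half_log_2pi)  # 81 terms, as with the PR
logp_generic = np.sum(-0.5 * x**2 - half_log_2pi * 81 / 100)  # 100 terms, as without
```

Because the map preserves the norm, both sums give the same number; the specialized version simply skips the transform, the zero-sum check, and the extra elements.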

I think I like this PR :D

Maybe we can also add implementations for the Beta and LogitNormal distributions:

```python
import pymc as pm
import pytensor.tensor as pt
from pymc.distributions.dist_math import check_parameters


# `_transformed_logprob` is the dispatcher added by this PR.
@pm.logprob.abstract._transformed_logprob.register(
    pt.random.basic.BetaRV, pm.distributions.transforms.LogOddsTransform
)
def transformed_beta_logp(op, transform, unconstrained_value, rv_inputs):
    *_, alpha, beta = rv_inputs
    logit_x = unconstrained_value

    # log Beta(x | alpha, beta) + log|dx/dlogit_x|, simplified on the logit scale
    normalizing_factor = pt.gammaln(alpha + beta) - pt.gammaln(alpha) - pt.gammaln(beta)
    logp = normalizing_factor - pt.log1p(pt.exp(-logit_x)) * (alpha + beta) - logit_x * beta
    return check_parameters(
        logp,
        alpha > 0,
        beta > 0,
        msg="Alpha and beta parameters must be positive in Beta distribution",
    )


@pm.logprob.abstract._transformed_logprob.register(
    pm.distributions.continuous.LogitNormalRV, pm.distributions.transforms.LogOddsTransform
)
def transformed_logit_normal_logp(op, transform, unconstrained_value, rv_inputs):
    *_, mu, sigma = rv_inputs
    # logit(x) ~ Normal(mu, sigma), so on the logit scale this is just a Normal logp
    return pm.logp(pm.Normal.dist(mu, sigma), unconstrained_value)
```
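A quick numeric sanity check of the two rules above, in plain Python without PyTensor (the function names are ours): the Beta expression should equal the constrained Beta logp plus the log-Jacobian log(x(1 - x)) of the logit transform, and the LogitNormal logp plus the same Jacobian should reduce to a Normal logp on the logit scale.

```python
import math


def beta_logp(x, alpha, beta):
    # Beta(alpha, beta) log-density on (0, 1)
    log_norm = math.lgamma(alpha + beta) - math.lgamma(alpha) - math.lgamma(beta)
    return log_norm + (alpha - 1) * math.log(x) + (beta - 1) * math.log1p(-x)


def beta_logp_on_logit_naive(y, alpha, beta):
    # Map back to x, then add log|dx/dy| = log(x * (1 - x)).
    x = 1.0 / (1.0 + math.exp(-y))
    return beta_logp(x, alpha, beta) + math.log(x) + math.log1p(-x)


def beta_logp_on_logit(y, alpha, beta):
    # Same expression as transformed_beta_logp above.
    log_norm = math.lgamma(alpha + beta) - math.lgamma(alpha) - math.lgamma(beta)
    return log_norm - math.log1p(math.exp(-y)) * (alpha + beta) - y * beta


def normal_logp(y, mu, sigma):
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))


def logit_normal_logp(x, mu, sigma):
    # LogitNormal(mu, sigma) log-density on (0, 1)
    logit_x = math.log(x) - math.log1p(-x)
    return normal_logp(logit_x, mu, sigma) - math.log(x) - math.log1p(-x)
```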
ricardoV94 commented 4 months ago

@aseyboldt perhaps a fairer graph comparison would be with Model(check_bounds=False). The mean switch, for instance, is just a bound check that could equally be removed in a model with the transform and the standard logp.

aseyboldt commented 4 months ago

That was with check_bounds=False. For most distributions that only disables the checks on the parameters, not on the domain of the value. But I don't think that's the main source of the difference either. This comparison wasn't entirely fair, though: with the PR the transformation wasn't needed at all anymore, while in most real applications it would still be needed when the variable is used downstream in the model. I guess this one is a fairer comparison:

```python
import pymc as pm

with pm.Model(check_bounds=False) as model:
    x = pm.ZeroSumNormal("x", shape=(100, 100), n_zerosum_axes=2)
    pm.Normal("y", mu=x, sigma=1, shape=(100, 100))
```

With PR:

```
Add [id A] 18
 ├─ Sum{axes=None} [id B] 3
 │  └─ Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)} [id C] 'x_zerosum___logprob' 1
 │     └─ x_zerosum__ [id D]
 └─ Sum{axes=None} [id E] 14
    └─ Composite{...}.2 [id F] 'y_logprob' 11
       ├─ Join [id G] 10
       │  ├─ 1 [id H]
       │  ├─ Sub [id I] 6
       │  │  ├─ Join [id J] 5
       │  │  │  ├─ 0 [id K]
       │  │  │  ├─ x_zerosum__ [id D]
       │  │  │  └─ Composite{...}.1 [id L] 4
       │  │  │     └─ ExpandDims{axis=0} [id M] 2
       │  │  │        └─ Sum{axis=0} [id N] 0
       │  │  │           └─ x_zerosum__ [id D]
       │  │  └─ Composite{...}.0 [id L] 4
       │  │     └─ ···
       │  └─ Composite{...}.1 [id O] 9
       │     └─ ExpandDims{axis=1} [id P] 8
       │        └─ Sum{axis=1} [id Q] 7
       │           └─ Sub [id I] 6
       │              └─ ···
       ├─ Composite{...}.0 [id O] 9
       │  └─ ···
       └─ y [id R]
Composite{((-i0) + i1 + i2 + i3)} [id S] 'x_zerosum___grad' 28
 ├─ x_zerosum__ [id D]
 ├─ Split{2}.0 [id T] 22
 │  ├─ Add [id U] 21
 │  │  ├─ SpecifyShape [id V] 17
 │  │  │  ├─ Split{2}.0 [id W] 13
 │  │  │  │  ├─ Composite{...}.0 [id F] 11
 │  │  │  │  │  └─ ···
 │  │  │  │  ├─ 1 [id H]
 │  │  │  │  └─ [99 1] [id X]
 │  │  │  ├─ 100 [id Y]
 │  │  │  └─ 99 [id Z]
 │  │  ├─ Composite{(0.00909090909090909 * (i0 - i1))} [id BA] 19
 │  │  │  ├─ SpecifyShape [id BB] 16
 │  │  │  │  ├─ Split{2}.1 [id W] 13
 │  │  │  │  │  └─ ···
 │  │  │  │  ├─ NoneConst{None} [id BC]
 │  │  │  │  └─ 1 [id H]
 │  │  │  └─ ExpandDims{axis=1} [id BD] 15
 │  │  │     └─ Sum{axis=1} [id BE] 12
 │  │  │        └─ Composite{...}.0 [id F] 11
 │  │  │           └─ ···
 │  │  └─ Mul [id BF] 20
 │  │     ├─ [[-0.1]] [id BG]
 │  │     └─ SpecifyShape [id BB] 16
 │  │        └─ ···
 │  ├─ 0 [id K]
 │  └─ [99 1] [id X]
 ├─ Composite{(0.00909090909090909 * (i0 - i1))} [id BH] 27
 │  ├─ SpecifyShape [id BI] 24
 │  │  ├─ Split{2}.1 [id T] 22
 │  │  │  └─ ···
 │  │  ├─ 1 [id H]
 │  │  └─ NoneConst{None} [id BC]
 │  └─ ExpandDims{axis=0} [id BJ] 25
 │     └─ Sum{axis=0} [id BK] 23
 │        └─ Add [id U] 21
 │           └─ ···
 └─ Mul [id BL] 26
    ├─ [[-0.1]] [id BG]
    └─ SpecifyShape [id BI] 24
       └─ ···
Composite{...}.1 [id F] 'y_grad' 11
 └─ ···

Inner graphs:

Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)} [id C]
 ← sub [id BM] 'o0'
    ├─ mul [id BN]
    │  ├─ -0.5 [id BO]
    │  └─ sqr [id BP]
    │     └─ i0 [id BQ]
    └─ 0.9189385332046727 [id BR]

Composite{...} [id F]
 ← sub [id BS] 'o0'
    ├─ i2 [id BT]
    └─ sub [id BU]
       ├─ i0 [id BV]
       └─ i1 [id BW]
 ← neg [id BX] 'o1'
    └─ sub [id BS] 'o0'
       └─ ···
 ← sub [id BY] 'o2'
    ├─ mul [id BZ]
    │  ├─ -0.5 [id CA]
    │  └─ sqr [id CB]
    │     └─ sub [id BS] 'o0'
    │        └─ ···
    └─ 0.9189385332046727 [id CC]

Composite{...} [id L]
 ← mul [id CD] 'o0'
    ├─ 0.00909090909090909 [id CE]
    └─ i0 [id CF]
 ← sub [id CG] 'o1'
    ├─ mul [id CD] 'o0'
    │  └─ ···
    └─ mul [id CH]
       ├─ 0.1 [id CI]
       └─ i0 [id CF]

Composite{...} [id O]
 ← mul [id CJ] 'o0'
    ├─ 0.00909090909090909 [id CK]
    └─ i0 [id CL]
 ← sub [id CM] 'o1'
    ├─ mul [id CJ] 'o0'
    │  └─ ···
    └─ mul [id CN]
       ├─ 0.1 [id CO]
       └─ i0 [id CL]

Composite{((-i0) + i1 + i2 + i3)} [id S]
 ← add [id CP] 'o0'
    ├─ neg [id CQ]
    │  └─ i0 [id CR]
    ├─ i1 [id CS]
    ├─ i2 [id CT]
    └─ i3 [id CU]

Composite{(0.00909090909090909 * (i0 - i1))} [id BA]
 ← mul [id CV] 'o0'
    ├─ 0.00909090909090909 [id CW]
    └─ sub [id CX]
       ├─ i0 [id CY]
       └─ i1 [id CZ]

Composite{(0.00909090909090909 * (i0 - i1))} [id BH]
 ← mul [id DA] 'o0'
    ├─ 0.00909090909090909 [id DB]
    └─ sub [id DC]
       ├─ i0 [id DD]
       └─ i1 [id DE]
```

Without PR:

```
Composite{(switch(i0, i1, i2) + i3)} [id A] 33
 ├─ All{axes=None} [id B] 29
 │  └─ MakeVector{dtype='bool'} [id C] 26
 │     ├─ All{axes=None} [id D] 24
 │     │  └─ Composite{and(le((0.01 * abs(i0)), 1e-09), invert(or(isnan((0.01 * i0)), isinf((0.01 * i0)))))} [id E] 20
 │     │     └─ Sum{axis=1} [id F] 15
 │     │        └─ Composite{...}.0 [id G] 9
 │     │           ├─ Join [id H] 8
 │     │           │  ├─ 1 [id I]
 │     │           │  ├─ Sub [id J] 4
 │     │           │  │  ├─ Join [id K] 3
 │     │           │  │  │  ├─ 0 [id L]
 │     │           │  │  │  ├─ x_zerosum__ [id M]
 │     │           │  │  │  └─ Composite{...}.1 [id N] 2
 │     │           │  │  │     └─ ExpandDims{axis=0} [id O] 1
 │     │           │  │  │        └─ Sum{axis=0} [id P] 0
 │     │           │  │  │           └─ x_zerosum__ [id M]
 │     │           │  │  └─ Composite{...}.0 [id N] 2
 │     │           │  │     └─ ···
 │     │           │  └─ Composite{...}.1 [id Q] 7
 │     │           │     └─ ExpandDims{axis=1} [id R] 6
 │     │           │        └─ Sum{axis=1} [id S] 5
 │     │           │           └─ Sub [id J] 4
 │     │           │              └─ ···
 │     │           ├─ Composite{...}.0 [id Q] 7
 │     │           │  └─ ···
 │     │           └─ y [id T]
 │     └─ All{axes=None} [id U] 23
 │        └─ Composite{and(le((0.01 * abs(i0)), 1e-09), invert(or(isnan((0.01 * i0)), isinf((0.01 * i0)))))} [id V] 19
 │           └─ Sum{axis=0} [id W] 14
 │              └─ Composite{...}.0 [id G] 9
 │                 └─ ···
 ├─ Sum{axes=None} [id X] 13
 │  └─ Composite{...}.3 [id G] 9
 │     └─ ···
 ├─ -inf [id Y]
 └─ Sum{axes=None} [id Z] 12
    └─ Composite{...}.4 [id G] 'y_logprob' 9
       └─ ···
Add [id BA] 'x_zerosum___grad' 36
 ├─ SpecifyShape [id BB] 32
 │  ├─ Split{2}.0 [id BC] 28
 │  │  ├─ Add [id BD] 25
 │  │  │  ├─ SpecifyShape [id BE] 18
 │  │  │  │  ├─ Split{2}.0 [id BF] 11
 │  │  │  │  │  ├─ Composite{...}.1 [id G] 9
 │  │  │  │  │  │  └─ ···
 │  │  │  │  │  ├─ 1 [id I]
 │  │  │  │  │  └─ [99 1] [id BG]
 │  │  │  │  ├─ 100 [id BH]
 │  │  │  │  └─ 99 [id BI]
 │  │  │  ├─ Composite{(0.00909090909090909 * (i0 - i1))} [id BJ] 21
 │  │  │  │  ├─ SpecifyShape [id BK] 17
 │  │  │  │  │  ├─ Split{2}.1 [id BF] 11
 │  │  │  │  │  │  └─ ···
 │  │  │  │  │  ├─ NoneConst{None} [id BL]
 │  │  │  │  │  └─ 1 [id I]
 │  │  │  │  └─ ExpandDims{axis=1} [id BM] 16
 │  │  │  │     └─ Sum{axis=1} [id BN] 10
 │  │  │  │        └─ Composite{...}.1 [id G] 9
 │  │  │  │           └─ ···
 │  │  │  └─ Mul [id BO] 22
 │  │  │     ├─ [[-0.1]] [id BP]
 │  │  │     └─ SpecifyShape [id BK] 17
 │  │  │        └─ ···
 │  │  ├─ 0 [id L]
 │  │  └─ [99 1] [id BG]
 │  ├─ 99 [id BI]
 │  └─ 99 [id BI]
 ├─ Composite{(0.00909090909090909 * (i0 - i1))} [id BQ] 34
 │  ├─ SpecifyShape [id BR] 31
 │  │  ├─ Split{2}.1 [id BC] 28
 │  │  │  └─ ···
 │  │  ├─ 1 [id I]
 │  │  └─ NoneConst{None} [id BL]
 │  └─ ExpandDims{axis=0} [id BS] 30
 │     └─ Sum{axis=0} [id BT] 27
 │        └─ Add [id BD] 25
 │           └─ ···
 └─ Mul [id BU] 35
    ├─ [[-0.1]] [id BP]
    └─ SpecifyShape [id BR] 31
       └─ ···
Composite{...}.2 [id G] 'y_grad' 9
 └─ ···

Inner graphs:

Composite{(switch(i0, i1, i2) + i3)} [id A]
 ← add [id BV] 'o0'
    ├─ Switch [id BW]
    │  ├─ i0 [id BX]
    │  ├─ i1 [id BY]
    │  └─ i2 [id BZ]
    └─ i3 [id CA]

Composite{and(le((0.01 * abs(i0)), 1e-09), invert(or(isnan((0.01 * i0)), isinf((0.01 * i0)))))} [id E]
 ← AND [id CB] 'o0'
    ├─ LE [id CC]
    │  ├─ mul [id CD]
    │  │  ├─ t7{0.01} [id CE]
    │  │  └─ Abs [id CF]
    │  │     └─ i0 [id CG]
    │  └─ 1e-09 [id CH]
    └─ Invert [id CI]
       └─ OR [id CJ]
          ├─ IsNan [id CK]
          │  └─ mul [id CL] 't4'
          │     ├─ t7{0.01} [id CE]
          │     └─ i0 [id CG]
          └─ IsInf [id CM]
             └─ mul [id CL] 't4'
                └─ ···

Composite{...} [id G]
 ← sub [id CN] 'o0'
    ├─ i0 [id CO]
    └─ i1 [id CP]
 ← sub [id CQ] 'o1'
    ├─ sub [id CR] 't13'
    │  ├─ i2 [id CS]
    │  └─ sub [id CN] 'o0'
    │     └─ ···
    └─ sub [id CN] 'o0'
       └─ ···
 ← neg [id CT] 'o2'
    └─ sub [id CR] 't13'
       └─ ···
 ← sub [id CU] 'o3'
    ├─ mul [id CV]
    │  ├─ t6{-0.5} [id CW]
    │  └─ sqr [id CX]
    │     └─ sub [id CN] 'o0'
    │        └─ ···
    └─ 0.9006516563938997 [id CY]
 ← sub [id CZ] 'o4'
    ├─ mul [id DA]
    │  ├─ t6{-0.5} [id CW]
    │  └─ sqr [id DB]
    │     └─ sub [id CR] 't13'
    │        └─ ···
    └─ 0.9189385332046727 [id DC]

Composite{...} [id N]
 ← mul [id DD] 'o0'
    ├─ 0.00909090909090909 [id DE]
    └─ i0 [id DF]
 ← sub [id DG] 'o1'
    ├─ mul [id DD] 'o0'
    │  └─ ···
    └─ mul [id DH]
       ├─ 0.1 [id DI]
       └─ i0 [id DF]

Composite{...} [id Q]
 ← mul [id DJ] 'o0'
    ├─ 0.00909090909090909 [id DK]
    └─ i0 [id DL]
 ← sub [id DM] 'o1'
    ├─ mul [id DJ] 'o0'
    │  └─ ···
    └─ mul [id DN]
       ├─ 0.1 [id DO]
       └─ i0 [id DL]

Composite{and(le((0.01 * abs(i0)), 1e-09), invert(or(isnan((0.01 * i0)), isinf((0.01 * i0)))))} [id V]
 ← AND [id DP] 'o0'
    ├─ LE [id DQ]
    │  ├─ mul [id DR]
    │  │  ├─ t3{0.01} [id DS]
    │  │  └─ Abs [id DT]
    │  │     └─ i0 [id DU]
    │  └─ 1e-09 [id DV]
    └─ Invert [id DW]
       └─ OR [id DX]
          ├─ IsNan [id DY]
          │  └─ mul [id DZ] 't10'
          │     ├─ t3{0.01} [id DS]
          │     └─ i0 [id DU]
          └─ IsInf [id EA]
             └─ mul [id DZ] 't10'
                └─ ···

Composite{(0.00909090909090909 * (i0 - i1))} [id BJ]
 ← mul [id EB] 'o0'
    ├─ 0.00909090909090909 [id EC]
    └─ sub [id ED]
       ├─ i0 [id EE]
       └─ i1 [id EF]

Composite{(0.00909090909090909 * (i0 - i1))} [id BQ]
 ← mul [id EG] 'o0'
    ├─ 0.00909090909090909 [id EH]
    └─ sub [id EI]
       ├─ i0 [id EJ]
       └─ i1 [id EK]
```

That's 140 vs. 200 lines. I have to admit I'm not 100% sure what the difference really is, though...

It's also a bit faster: 70 μs vs. 90 μs.

ricardoV94 commented 4 months ago

That was with check_bounds=False

Then I'm surprised the isclose(mean, 0) switch was not removed

aseyboldt commented 4 months ago

You are right, it seems check_bounds=False didn't do anything, because I wasn't compiling the pytensor functions in a model context. I had no idea I had to do that...