pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

Handle more AdvancedSetSubtensor ops in numba #772

Closed aseyboldt closed 3 days ago

aseyboldt commented 6 months ago

Description

This model runs partially in object mode when compiled with the Numba backend:

import pymc as pm
import numpy as np
import pytensor

with pm.Model() as model:
    values = np.array([np.nan, 0.])
    x = pm.Normal("x", observed=values)
    pm.Normal("y", mu=x, observed=np.ones(2))

logp = model.logp()
func = pm.compile_pymc(model.value_vars, logp, mode="NUMBA")
pytensor.dprint(func)

Composite{(i2 + (i1 * sqr(i0)) + i3)} [id A] '__logp' 7
 ├─ DropDims{axis=0} [id B] 6
 │  └─ x_unobserved [id C]
 ├─ -0.5 [id D]
 ├─ -1.8378770664093453 [id E]
 └─ Sum{axes=None} [id F] 5
    └─ SpecifyShape [id G] 'sigma > 0' 4
       ├─ Composite{((-0.5 * sqr((1.0 - i0))) - 0.9189385332046727)} [id H] 3
       │  └─ AdvancedSetSubtensor [id I] 2
       │     ├─ AdvancedSetSubtensor [id J] 1
       │     │  ├─ AllocEmpty{dtype='float64'} [id K] 0
       │     │  │  └─ 2 [id L]
       │     │  ├─ x_unobserved [id C]
       │     │  └─ [ True False] [id M]
       │     ├─ [0.] [id N]
       │     └─ [False  True] [id O]
       └─ 2 [id P]
...

The AdvancedSetSubtensor nodes currently fall back to object mode in numba, but I think we should be able to implement them natively without too much trouble.
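For illustration, here is a minimal sketch of the kind of loop a native implementation might use for the 1-D boolean-mask case that appears in the graph above. This is not PyTensor's actual implementation: the function name `bool_set_subtensor_1d` is hypothetical, and the code is plain NumPy and Python. It is written in a style that `numba.njit` can typically compile, but that is an assumption, not something tested here.

```python
import numpy as np


def bool_set_subtensor_1d(x, mask, values):
    """Hypothetical helper: return a copy of x with x[mask] replaced by values.

    `values` must have one entry per True element of `mask`.
    """
    out = x.copy()
    k = 0  # position in `values` of the next entry to write
    for i in range(mask.shape[0]):
        if mask[i]:
            out[i] = values[k]
            k += 1
    return out


# Mirror the two nested AdvancedSetSubtensor nodes from the dprint output:
# first write the unobserved value, then the observed one.
buf = np.empty(2, dtype="float64")
x_unobserved = np.array([0.3])  # made-up value for the missing entry
step1 = bool_set_subtensor_1d(buf, np.array([True, False]), x_unobserved)
step2 = bool_set_subtensor_1d(step1, np.array([False, True]), np.array([0.0]))
```

After both steps every slot of the initially uninitialized buffer has been written, so the result does not depend on what `np.empty` returned.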

aseyboldt commented 6 months ago

Simpler example:

import numpy as np
import pytensor.tensor as pt
import pytensor

x = pt.vector("x", shape=(5,))
out = x[np.random.randn(5) > 0].set(1)
func = pytensor.function([x], out, mode="NUMBA")
pytensor.dprint(func)

ricardoV94 commented 5 days ago

That's relying on boolean inc_subtensor. After #1106 we could convert that to integer inc_subtensor with nonzero(), although it may be worth implementing a special case for boolean indexing.
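The equivalence behind the nonzero() idea can be checked in plain NumPy. The sketch below only illustrates the indexing semantics, not the PyTensor rewrite itself, and both helper names are hypothetical.

```python
import numpy as np


def set_with_bool_mask(x, mask, value):
    # Boolean advanced indexing: write `value` wherever mask is True.
    out = x.copy()
    out[mask] = value
    return out


def set_with_int_index(x, mask, value):
    # Same update, with the mask first converted to integer positions.
    out = x.copy()
    (idx,) = np.nonzero(mask)  # 1-D mask -> a single array of indices
    out[idx] = value
    return out


x = np.arange(5, dtype="float64")
mask = np.array([True, False, True, False, False])

a = set_with_bool_mask(x, mask, 1.0)
b = set_with_int_index(x, mask, 1.0)
```

Both paths produce the same array. The integer path pays for an extra nonzero() pass and an index array, which is one possible reason to special-case boolean masks.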