aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Implement `Scan`'s `as_while` in JAX #710

Open sokol11 opened 2 years ago

sokol11 commented 2 years ago

Getting this traceback when trying to run sampling_jax.sample_numpyro_nuts(...) (see below). I installed Aesara and PyMC from source, using their development branches. This is a Colab machine with Python 3.7. Any help would be greatly appreciated!

UnfilteredStackTrace: TypeError: __init__() missing 1 required positional argument: 'as_while'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)

/usr/local/lib/python3.7/dist-packages/aesara/link/jax/dispatch.py in scan(*outer_inputs)
    418     def scan(*outer_inputs):
    419         scan_args = ScanArgs(
--> 420             list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
    421         )
    422 

TypeError: __init__() missing 1 required positional argument: 'as_while'

ricardoV94 commented 2 years ago

Hi @sokol11 can you provide a small reproducible code snippet that triggers this error?

sokol11 commented 2 years ago

Hi @ricardoV94, and thanks for the fast response. I'm trying to run this regime-switching model (previously written for PyMC3/Theano), which I tweaked a little to comply with the new syntax. The error is triggered right after the .sample_numpyro_nuts(...) call, after JAX compiles. Here is the code; sorry it is not very short, but it is a rather complicated model:

First import libraries and load data:

import aesara
import aesara.tensor as tt
import numpy as np
import pymc as pm

# load data
import quandl
googl = quandl.get("WIKI/GOOGL", collapse = "weekly")
googl = (googl.pipe(lambda x: x.assign(l_ac=np.log(x['Adj. Close'])))
              .pipe(lambda x: x.assign(dl_ac=np.hstack([np.nan, np.diff(x['l_ac'])])))
              .query('index > "2010-01-01"')
        )
y = googl['dl_ac'].values * 100 # this is just a 1-D array of floats
yshared = aesara.shared(y)
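
The quandl "WIKI" download is only needed for the real data. To reproduce without it, a minimal sketch, assuming a synthetic price path in place of 'Adj. Close' (the helper name weekly_log_returns_pct is hypothetical), computes the same quantity as the pipeline: 100 times the first difference of the log prices.

```python
import numpy as np


def weekly_log_returns_pct(prices):
    """Percentage log returns: 100 * diff(log(prices)).

    Mirrors the `l_ac` / `dl_ac` columns built by the pandas pipeline,
    minus the leading NaN that `np.hstack([np.nan, ...])` prepends.
    """
    prices = np.asarray(prices, dtype=np.float64)
    return np.diff(np.log(prices)) * 100


# Hypothetical stand-in for the quandl download: a synthetic weekly price path
rng = np.random.default_rng(0)
prices = 500.0 * np.exp(np.cumsum(rng.normal(0.002, 0.03, size=400)))
y = weekly_log_returns_pct(prices)  # 1-D float array, like googl['dl_ac'].values * 100
```

The resulting y can be wrapped with aesara.shared(y) exactly like the real series.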

Second, I run this function definition and test, which was originally written for PyMC3/theano. I don't yet fully understand how the test works, as I am still working my way through understanding the model itself, but importantly, it passes!

aesara.config.compute_test_value = 'raise'  # raise if a test value cannot be computed

eta_ = tt.dmatrix("eta_")
eta0 = np.random.rand(100, 2)
eta_.tag.test_value = eta0

P = tt.dmatrix("P")
P0 = np.asarray([[.75, .25], [.25, .75]])
P.tag.test_value = P0
xi_ = tt.dscalar("xi_")
xi_.tag.test_value = .75
xi_out = tt.dscalar("xi_out")
xi_out.tag.test_value = 0
ft_out = tt.dscalar("ft_out")
ft_out.tag.test_value = 0

def ft_xit_dt(Eta, ft, Xi, P):
    Xi_ = tt.shape_padleft(Xi)
    xit0 = tt.stack([Xi_, 1 - Xi_], axis=1).T
    ft = tt.sum(tt.dot(xit0 * P, Eta))
    Xi1 = (P[0, 0] * Xi + P[1, 0] * (1 - Xi)) * Eta[0] / ft
    return [ft, Xi1]

([ft, xi], updates) = aesara.scan(ft_xit_dt,
                                  sequences=eta_,
                                  outputs_info=[ft_out, xi_out],
                                  non_sequences=P)

ft_xit_dt_ = aesara.function(inputs=[eta_, ft_out, xi_out, P], outputs=[ft, xi], updates=updates)

ft1, xi1 = ft_xit_dt_(eta0, 0, .75, P0)

ft2 = np.zeros(100)
xi2 = np.zeros(100)

ftfunc = lambda eta,xi: P0[0, 0]*xi*eta[0] +\
                        P0[0, 1]*xi*eta[1] +\
                        P0[1, 1]*(1 - xi)*eta[1] +\
                        P0[1, 0]*(1 - xi)*eta[0]
Eta = eta0[0]
Xi_ = np.asarray([.75])
ft2[0] = ftfunc(Eta, Xi_)
xi2[0] = (P0[0, 0] * Xi_ + P0[1, 0] * (1 - Xi_)) * Eta[0] / ft2[0]

for i in range(1, 100):
    Eta = eta0[i]
    Xi_ = xi2[i - 1]
    ft2[i] = ftfunc(Eta, Xi_)
    xi2[i] = (P0[0, 0] * Xi_ + P0[1, 0] * (1 - Xi_)) * Eta[0] / ft2[i]

np.testing.assert_almost_equal(ft1, ft2)
np.testing.assert_almost_equal(xi1, xi2)
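
For comparison with the JAX backend, here is a minimal sketch of the same recursion written directly with jax.lax.scan. It assumes float64 via the global jax_enable_x64 flag, and ft_xit_step / filter_lax_scan are hypothetical names. It carries Xi from step to step and stacks (ft, Xi) per step, which is roughly what a fixed-length (non-as_while) Scan has to become in JAX. Unlike ft_xit_dt, it does not thread the unused ft recurrence through the carry.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Global switch: use float64 like the `dmatrix` / `dscalar` inputs above
jax.config.update("jax_enable_x64", True)


def ft_xit_step(xi, eta_t, P):
    """One step of the recursion in `ft_xit_dt`; the carry is the previous Xi."""
    # Scale row 0 of P by Xi and row 1 by (1 - Xi), then weight by Eta
    weights = jnp.stack([xi, 1.0 - xi])[:, None] * P
    ft = jnp.sum(weights @ eta_t)
    xi_next = (P[0, 0] * xi + P[1, 0] * (1.0 - xi)) * eta_t[0] / ft
    return xi_next, (ft, xi_next)


def filter_lax_scan(eta, xi0, P):
    """Run the recursion over the rows of `eta`; returns stacked (ft, xi)."""
    eta = jnp.asarray(eta)
    P = jnp.asarray(P)
    xi0 = jnp.asarray(xi0, dtype=eta.dtype)
    _, (ft, xi) = jax.lax.scan(lambda carry, eta_t: ft_xit_step(carry, eta_t, P), xi0, eta)
    return ft, xi


# Example: inputs shaped like the Aesara test above
eta0 = np.random.default_rng(1).random((100, 2))
P0 = np.array([[.75, .25], [.25, .75]])
ft1, xi1 = filter_lax_scan(eta0, 0.75, P0)
```

The outputs should agree with the NumPy loop in the test above for the same eta0 and P0.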

Next, define the model:

with pm.Model() as m:

    # Transition matrix
    p = pm.Beta('p', alpha=10., beta=2., shape=2)
    P = tt.diag(p)
    P = tt.set_subtensor(P[0, 1], 1 - p[0])
    P = tt.set_subtensor(P[1, 0], 1 - p[1])

    # eta
    alpha = pm.Normal('alpha', mu=0., sigma=.1, shape=2)
    sigma = pm.HalfCauchy('sigma', beta=1., shape=2)
    eta1 = tt.exp(pm.logp(pm.Normal.dist(mu=alpha[0], sigma=sigma[0]), yshared))

    y_tm1_init = pm.Normal('y_init', mu=0., sigma=.1)
    rho = pm.Bound('rho', pm.Normal.dist(mu=1.0, sigma=.1), lower=0., initval=1.0)
    eta2 = tt.zeros_like(eta1)
    eta2 = tt.set_subtensor(eta2[0], tt.exp(
        pm.logp(pm.Normal.dist(mu=alpha[1] + rho * y_tm1_init, sigma=sigma[1]), yshared[0])))
    eta2 = tt.set_subtensor(eta2[1:], tt.exp(
        pm.logp(pm.Normal.dist(mu=alpha[1] + rho * yshared[:-1], sigma=sigma[1]), yshared[1:])))

    eta = tt.stack([eta1, eta2], axis=1)

    xi_init = pm.Beta('xi_init', alpha=2., beta=2.)
    ft_out = aesara.shared(0.)  # placeholder initial value for the unused ft recurrence
    ([ft, xi], updates) = aesara.scan(ft_xit_dt,
                                      sequences=eta,
                                      outputs_info=[ft_out, xi_init],
                                      non_sequences=P)

    Xi = pm.Deterministic('Xi', xi)
    # likelihood `target += sum(log(f))`
    pm.Potential('likelihood', tt.sum(tt.log(ft)))
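
To sanity-check the model outside of Aesara and JAX, a minimal NumPy sketch, assuming fixed parameter values in place of the random variables, evaluates the same likelihood Potential. normal_pdf and regime_switching_loglik are hypothetical helpers. It builds P exactly as the tt.diag / tt.set_subtensor lines do, builds the two regime densities (regime 2 uses the lagged y, with y_init for the first step), and runs the ft_xit_dt recursion.

```python
import numpy as np


def normal_pdf(x, mu, sigma):
    """Normal density, i.e. exp(pm.logp(pm.Normal.dist(mu, sigma), x))."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))


def regime_switching_loglik(y, p, alpha, sigma, rho, y_init, xi_init):
    """Return (sum(log(ft)), xi) for fixed parameter values."""
    y = np.asarray(y, dtype=np.float64)
    # Transition matrix: diag(p) with the off-diagonals set to 1 - p
    P = np.array([[p[0], 1.0 - p[0]],
                  [1.0 - p[1], p[1]]])
    # Regime 1: i.i.d. Normal; regime 2: mean depends on the previous observation
    eta1 = normal_pdf(y, alpha[0], sigma[0])
    y_lag = np.concatenate([[y_init], y[:-1]])
    eta2 = normal_pdf(y, alpha[1] + rho * y_lag, sigma[1])
    eta = np.stack([eta1, eta2], axis=1)

    xi = xi_init
    log_ft = np.empty(len(y))
    xis = np.empty(len(y))
    for t, eta_t in enumerate(eta):
        # Same ft as tt.sum(tt.dot(xit0 * P, Eta)) in ft_xit_dt
        ft = xi * (P[0] @ eta_t) + (1.0 - xi) * (P[1] @ eta_t)
        xi = (P[0, 0] * xi + P[1, 0] * (1.0 - xi)) * eta_t[0] / ft
        log_ft[t] = np.log(ft)
        xis[t] = xi
    return log_ft.sum(), xis
```

A quick consistency check: because each row of P sums to one, if both regimes share the same density (equal alpha and sigma, rho = 0), every ft collapses to that density and the Potential reduces to an ordinary Normal log-likelihood.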

Last, sample (this is when the error is thrown):

from pymc import sampling_jax
with m:
    # trace = pm.sample(2000, tune=2000, target_accept=0.9)
    trace = sampling_jax.sample_numpyro_nuts(2000, tune=2000, target_accept=0.9)

Please let me know if there is any additional info I can provide. Thank you!

brandonwillard commented 2 years ago

Please see the issue template for the other required information.

Otherwise, this example is written in PyMC, so the issue should be posted in that repository first. From there the relevant PyMC code can be assessed and the connection with Aesara can be clarified. That work should result in an Aesara-based MWE that we can use.

sokol11 commented 2 years ago

Thanks @brandonwillard. Opened a pymc issue.

ricardoV94 commented 2 years ago

I found an MWE using at.jacobian:

import aesara
import aesara.tensor as at
import numpy as np

# Need a fixed-shape TensorVariable; otherwise JAX compilation fails for different reasons
x = at.TensorType("floatX", (2,))("x")
cost = x + 1
jac = at.jacobian(cost, x)
f = aesara.function([x], jac, mode="JAX")

f(np.zeros(2))
Traceback (most recent call last):
  File "/home/ricardo/Documents/Projects/aesara/aesara/link/utils.py", line 179, in streamline_default_f
    thunk()
  File "/home/ricardo/Documents/Projects/aesara/aesara/link/basic.py", line 665, in thunk
    outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/linear_util.py", line 263, in memoized_fun
    ans = call(fun, *args)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 653, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 665, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/tmp/tmphi57pfxf", line 4, in jax_funcified_fgraph
    auto_7361 = scan(auto_6537, auto_7245, auto_6537, auto_1051)
  File "/home/ricardo/Documents/Projects/aesara/aesara/link/jax/dispatch.py", line 420, in scan
    scan_args = ScanArgs(
TypeError: __init__() missing 1 required positional argument: 'as_while'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-466ae3aaa455>", line 9, in <module>
    f(np.zeros(2))
  File "/home/ricardo/Documents/Projects/aesara/aesara/compile/function/types.py", line 964, in __call__
    self.fn()
  File "/home/ricardo/Documents/Projects/aesara/aesara/link/utils.py", line 183, in streamline_default_f
    raise_with_op(fgraph, node, thunk)
  File "/home/ricardo/Documents/Projects/aesara/aesara/link/utils.py", line 517, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/ricardo/Documents/Projects/aesara/aesara/link/utils.py", line 179, in streamline_default_f
    thunk()
  File "/home/ricardo/Documents/Projects/aesara/aesara/link/basic.py", line 665, in thunk
    outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/linear_util.py", line 263, in memoized_fun
    ans = call(fun, *args)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 653, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 665, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/tmp/tmphi57pfxf", line 4, in jax_funcified_fgraph
    auto_7361 = scan(auto_6537, auto_7245, auto_6537, auto_1051)
  File "/home/ricardo/Documents/Projects/aesara/aesara/link/jax/dispatch.py", line 420, in scan
    scan_args = ScanArgs(
TypeError: __init__() missing 1 required positional argument: 'as_while'
Apply node that caused the error: for{cpu,scan_fn}(TensorConstant{2}, TensorConstant{[0 1]}, TensorConstant{2}, Elemwise{add,no_inplace}.0)
Toposort index: 1
Inputs types: [TensorType(int64, ()), TensorType(int64, (2,)), TensorType(int64, ()), TensorType(float64, (None,))]
Inputs shapes: [(2,)]
Inputs strides: [(8,)]
Inputs values: [array([0., 0.])]
Outputs clients: [['output']]
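
For reference, the Scan that at.jacobian builds in this MWE iterates over the output indices (the TensorConstant{[0 1]} input of the failing node) and computes one gradient per step. A hedged JAX sketch of the same computation with jax.lax.scan should return the 2-by-2 identity, matching jax.jacobian(cost). jacobian_via_scan is a hypothetical helper, not Aesara's dispatch code.

```python
import jax
import jax.numpy as jnp


def cost(x):
    # Same expression as the MWE: cost = x + 1
    return x + 1


def jacobian_via_scan(f, x):
    """Build d f(x) / d x one row at a time, looping over output indices."""
    n_out = f(x).shape[0]

    def row(carry, i):
        # Gradient of the i-th output component with respect to x
        return carry, jax.grad(lambda z: f(z)[i])(x)

    _, rows = jax.lax.scan(row, None, jnp.arange(n_out))
    return rows


jac = jacobian_via_scan(cost, jnp.zeros(2))  # 2-by-2 identity
```

In plain JAX, jax.jacobian (or jax.vmap over the gradient) does this without an explicit loop. The scan form is shown only because it has the same structure as the Aesara graph.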

ricardoV94 commented 2 years ago

It's not just the missing as_while argument. Once I started digging around with this example, I found that several other things in the JAX dispatch seem broken as well.
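
An as_while Scan (one whose step function also returns an until condition) has a data-dependent trip count, so it cannot be lowered to a fixed-length jax.lax.scan as-is. The sketch below is offered only as one possible approach and is not Aesara's implementation. It uses jax.lax.while_loop with an output buffer preallocated to the maximum number of steps and also returns how many steps actually ran. scan_until is a hypothetical helper, and it assumes step preserves the shape and dtype of its input.

```python
import jax
import jax.numpy as jnp


def scan_until(step, init, n_steps, stop):
    """Apply `step` up to `n_steps` times, stopping once `stop(x)` is True.

    `n_steps` must be a static Python int because JAX needs a fixed buffer
    shape. Returns the output buffer (zeros after the stopping point) and
    the number of steps actually taken.
    """
    init = jnp.asarray(init)
    buf = jnp.zeros((n_steps,) + init.shape, dtype=init.dtype)

    def cond(state):
        i, _, done, _ = state
        return (i < n_steps) & ~done

    def body(state):
        i, x, _, buf = state
        x_new = step(x)
        buf = buf.at[i].set(x_new)
        return i + 1, x_new, stop(x_new), buf

    init_state = (jnp.array(0, dtype=jnp.int32), init, jnp.array(False), buf)
    n, _, _, buf = jax.lax.while_loop(cond, body, init_state)
    return buf, n


# Double until the value exceeds 100, with at most 10 iterations
buf, n = scan_until(lambda x: x * 2.0, jnp.float32(1.0), 10, lambda x: x > 100.0)
# n == 7; buf[:7] holds 2, 4, ..., 128
```

The fixed-size buffer is the price of JAX's static shapes. A caller would slice buf[:n] outside of jit, or keep the padded buffer plus n.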