Open sokol11 opened 2 years ago
Hi @sokol11 can you provide a small reproducible code snippet that triggers this error?
Hi @ricardoV94, and thanks for the fast response. I'm trying to run this regime-switching model (previously written for PyMC3 / theano), which I tweaked a little to comply with the new syntax. The error is triggered right after the .numpyro_nuts(...) call, after jax compiles. Here is the code, sorry it is not very short, but it is a rather complicated model:
First import libraries and load data:
import aesara
import aesara.tensor as tt
import numpy as np
import pymc as pm
# load data
import quandl
googl = quandl.get("WIKI/GOOGL", collapse="weekly")
googl = (googl.pipe(lambda x: x.assign(l_ac=np.log(x['Adj. Close'])))
         .pipe(lambda x: x.assign(dl_ac=np.hstack([np.nan, np.diff(x['l_ac'])])))
         .query('index > "2010-01-01"')
         )
y = googl['dl_ac'].values * 100 # this is just a 1-D array of floats
yshared = aesara.shared(y)
Second, I run this function definition and test, which was originally written for PyMC3/theano. I don't yet fully understand how the test works, as I am still working my way through understanding the model itself, but importantly, it passes!
aesara.config.test_values = 'raise'
eta_ = tt.dmatrix("eta_")
eta0 = np.random.rand(100, 2)
eta_.tag.test_value = eta0
P = tt.dmatrix("P")
P0 = np.asarray([[.75, .25], [.25, .75]])
P.tag.test_value = P0
xi_ = tt.dscalar("xi_")
xi_.tag.test_value = .75
xi_out = tt.dscalar("xi_out")
xi_out.tag.test_value = 0
ft_out = tt.dscalar("ft_out")
ft_out.tag.test_value = 0
def ft_xit_dt(Eta, ft, Xi, P):
    Xi_ = tt.shape_padleft(Xi)
    xit0 = tt.stack([Xi_, 1 - Xi_], axis=1).T
    ft = tt.sum(tt.dot(xit0 * P, Eta))
    Xi1 = (P[0, 0] * Xi + P[1, 0] * (1 - Xi)) * Eta[0] / ft
    return [ft, Xi1]
([ft, xi], updates) = aesara.scan(ft_xit_dt,
                                  sequences=eta_,
                                  outputs_info=[ft_out, xi_out],
                                  non_sequences=P)
ft_xit_dt_ = aesara.function(inputs=[eta_, ft_out, xi_out, P], outputs=[ft, xi], updates=updates)
ft1, xi1 = ft_xit_dt_(eta0, 0, .75, P0)
ft2 = np.zeros(100)
xi2 = np.zeros(100)
ftfunc = lambda eta, xi: P0[0, 0]*xi*eta[0] +\
                         P0[0, 1]*xi*eta[1] +\
                         P0[1, 1]*(1 - xi)*eta[1] +\
                         P0[1, 0]*(1 - xi)*eta[0]
Eta = eta0[0]
Xi_ = np.asarray([.75])
ft2[0] = ftfunc(Eta, Xi_)
xi2[0] = (P0[0, 0] * Xi_ + P0[1, 0] * (1 - Xi_)) * Eta[0] / ft2[0]
for i in range(1, 100):
    Eta = eta0[i]
    Xi_ = xi2[i - 1]
    ft2[i] = ftfunc(Eta, Xi_)
    xi2[i] = (P0[0, 0] * Xi_ + P0[1, 0] * (1 - Xi_)) * Eta[0] / ft2[i]
np.testing.assert_almost_equal(ft1, ft2)
np.testing.assert_almost_equal(xi1, xi2)
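For readers following the test above, the recursion that ft_xit_dt appears to implement can be written in a few lines of plain NumPy. This is only a sketch of my reading of the scan step, not the original author's code; filter_step and run_filter are hypothetical names introduced here, and the random inputs are illustrative.

```python
import numpy as np


def filter_step(eta, xi, P):
    """One step of the two-regime filter, as I read ft_xit_dt.

    eta : length-2 array of per-regime likelihoods at time t
    xi  : previous probability of being in regime 0
    P   : 2x2 transition matrix whose rows sum to 1
    """
    prev = np.array([xi, 1.0 - xi])           # previous regime probabilities
    joint = prev[:, None] * P * eta[None, :]  # joint[i, j] = prev_i * P_ij * eta_j
    ft = joint.sum()                          # marginal likelihood of y_t
    xi_new = joint[:, 0].sum() / ft           # updated probability of regime 0
    return ft, xi_new


def run_filter(eta_seq, xi0, P):
    """Apply filter_step along the first axis of eta_seq, as scan does."""
    fts, xis = [], []
    xi = xi0
    for eta in eta_seq:
        ft, xi = filter_step(eta, xi, P)
        fts.append(ft)
        xis.append(xi)
    return np.array(fts), np.array(xis)


rng = np.random.default_rng(0)
eta_seq = rng.random((100, 2))  # illustrative likelihoods, not real data
P0 = np.array([[0.75, 0.25], [0.25, 0.75]])
fts, xis = run_filter(eta_seq, 0.75, P0)
```

Summing joint over both axes gives the same four-term expression as ftfunc, and summing its first column gives the numerator of the xi update, so the two formulations should agree step by step.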
Next, define the model:
with pm.Model() as m:
    # Transition matrix
    p = pm.Beta('p', alpha=10., beta=2., shape=2)
    P = tt.diag(p)
    P = tt.set_subtensor(P[0, 1], 1 - p[0])
    P = tt.set_subtensor(P[1, 0], 1 - p[1])
    # eta
    alpha = pm.Normal('alpha', mu=0., sd=.1, shape=2)
    sigma = pm.HalfCauchy('sigma', beta=1., shape=2)
    eta1 = tt.exp(pm.logp(pm.Normal.dist(mu=alpha[0], sd=sigma[0]), yshared))
    y_tm1_init = pm.Normal('y_init', mu=0., sd=.1)
    rho = pm.Bound('rho', pm.Normal.dist(mu=1.0, sd=.1), lower=0., initval=1.0)
    eta2 = tt.zeros_like(eta1)
    eta2 = tt.set_subtensor(eta2[0], tt.exp(
        pm.logp(pm.Normal.dist(mu=alpha[1] + rho * y_tm1_init, sd=sigma[1]), yshared[0])))
    eta2 = tt.set_subtensor(eta2[1:], tt.exp(
        pm.logp(pm.Normal.dist(mu=alpha[1] + rho * yshared[:-1], sd=sigma[1]), yshared[1:])))
    eta = tt.stack([eta1, eta2], axis=1)
    xi_init = pm.Beta('xi_init', alpha=2., beta=2.)
    ft_out = aesara.shared(0.)  # placeholder
    ([ft, xi], updates) = aesara.scan(ft_xit_dt,
                                      sequences=eta,
                                      outputs_info=[ft_out, xi_init],
                                      non_sequences=P)
    Xi = pm.Deterministic('Xi', xi)
    # likelihood `target += sum(log(f))`
    pm.Potential('likelihood', tt.sum(tt.log(ft)))
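The tt.diag and set_subtensor lines at the top of the model build a 2x2 transition matrix from the two "stay" probabilities in p. A NumPy sketch of the same construction, with made-up values for p (they are not fitted values from this model):

```python
import numpy as np

p = np.array([0.9, 0.8])  # illustrative stay probabilities (assumed values)
P = np.diag(p)            # P[i, i] = probability of staying in regime i
P[0, 1] = 1 - p[0]        # probability of moving from regime 0 to regime 1
P[1, 0] = 1 - p[1]        # probability of moving from regime 1 to regime 0
row_sums = P.sum(axis=1)  # each row should sum to 1
```

Each row is a probability distribution over the next regime, which is why only the two diagonal entries need priors.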
Last, sample (which is when the error is thrown):
from pymc import sampling_jax
with m:
    # trace = pm.sample(2000, tune=2000, target_accept=0.9)
    trace = sampling_jax.sample_numpyro_nuts(2000, tune=2000, target_accept=0.9)
Please let me know if there is any additional info I can provide. Thank you!
Please see the issue template for the other required information.
Otherwise, this example is written in PyMC, so the issue should be posted in that repository first. From there the relevant PyMC code can be assessed and the connection with Aesara can be clarified. That work should result in an Aesara-based MWE that we can use.
Thanks @brandonwillard. Opened a pymc issue.
I found an MWE using at.jacobian:
import aesara
import aesara.tensor as at
import numpy as np
# Need a fixed-shape TensorVariable, otherwise JAX compilation fails for different reasons
x = at.TensorType("floatX", (2,))("x")
cost = x + 1
jac = at.jacobian(cost, x)
f = aesara.function([x], jac, mode="JAX")
f(np.zeros(2))
Traceback (most recent call last):
File "/home/ricardo/Documents/Projects/aesara/aesara/link/utils.py", line 179, in streamline_default_f
thunk()
File "/home/ricardo/Documents/Projects/aesara/aesara/link/basic.py", line 665, in thunk
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/_src/api.py", line 416, in cache_miss
out_flat = xla.xla_call(
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/linear_util.py", line 263, in memoized_fun
ans = call(fun, *args)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 653, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 665, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/tmphi57pfxf", line 4, in jax_funcified_fgraph
auto_7361 = scan(auto_6537, auto_7245, auto_6537, auto_1051)
File "/home/ricardo/Documents/Projects/aesara/aesara/link/jax/dispatch.py", line 420, in scan
scan_args = ScanArgs(
TypeError: __init__() missing 1 required positional argument: 'as_while'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-466ae3aaa455>", line 9, in <module>
f(np.zeros(2))
File "/home/ricardo/Documents/Projects/aesara/aesara/compile/function/types.py", line 964, in __call__
self.fn()
File "/home/ricardo/Documents/Projects/aesara/aesara/link/utils.py", line 183, in streamline_default_f
raise_with_op(fgraph, node, thunk)
File "/home/ricardo/Documents/Projects/aesara/aesara/link/utils.py", line 517, in raise_with_op
raise exc_value.with_traceback(exc_trace)
File "/home/ricardo/Documents/Projects/aesara/aesara/link/utils.py", line 179, in streamline_default_f
thunk()
File "/home/ricardo/Documents/Projects/aesara/aesara/link/basic.py", line 665, in thunk
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/_src/api.py", line 416, in cache_miss
out_flat = xla.xla_call(
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/linear_util.py", line 263, in memoized_fun
ans = call(fun, *args)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 653, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/xla.py", line 665, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/ricardo/Documents/Projects/aesara/venv/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/tmphi57pfxf", line 4, in jax_funcified_fgraph
auto_7361 = scan(auto_6537, auto_7245, auto_6537, auto_1051)
File "/home/ricardo/Documents/Projects/aesara/aesara/link/jax/dispatch.py", line 420, in scan
scan_args = ScanArgs(
TypeError: __init__() missing 1 required positional argument: 'as_while'
Apply node that caused the error: for{cpu,scan_fn}(TensorConstant{2}, TensorConstant{[0 1]}, TensorConstant{2}, Elemwise{add,no_inplace}.0)
Toposort index: 1
Inputs types: [TensorType(int64, ()), TensorType(int64, (2,)), TensorType(int64, ()), TensorType(float64, (None,))]
Inputs shapes: [(2,)]
Inputs strides: [(8,)]
Inputs values: [array([0., 0.])]
Outputs clients: [['output']]
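The final TypeError is the message Python raises when a constructor has a required argument that the call site does not pass. The sketch below reproduces the same kind of message with a hypothetical stand-in class; FakeScanArgs and its parameters are invented for illustration and are not Aesara's actual ScanArgs signature.

```python
class FakeScanArgs:
    """Hypothetical stand-in; not Aesara's real ScanArgs."""

    def __init__(self, outer_inputs, inner_inputs, as_while):
        self.outer_inputs = outer_inputs
        self.inner_inputs = inner_inputs
        self.as_while = as_while


try:
    # Call site that does not pass the required `as_while` argument.
    FakeScanArgs([1, 2], [3, 4])
    message = None
except TypeError as exc:
    message = str(exc)

# Passing the argument explicitly avoids this particular error.
ok = FakeScanArgs([1, 2], [3, 4], as_while=False)
```

This only illustrates how the message arises; it says nothing about what value the JAX dispatch code should pass.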
It's not just the missing as_while argument. Once I started digging around with this example, several other things in the JAX dispatch seemed broken.
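For context on why such a small example reaches the scan dispatch at all: the Apply node in the error message is a for{cpu,scan_fn}, which suggests the Jacobian graph loops over the output elements. The sketch below mimics that row-by-row structure in NumPy. It swaps in central finite differences for symbolic gradients, so it is an illustration of the loop shape only, and jacobian_by_rows is a hypothetical helper.

```python
import numpy as np


def cost(x):
    # Same expression as in the MWE above.
    return x + 1


def jacobian_by_rows(fn, x, eps=1e-6):
    """Central-difference Jacobian built one output row at a time."""
    x = np.asarray(x, dtype=float)
    n_out = fn(x).shape[0]
    rows = []
    for i in range(n_out):  # this loop plays the role of the scan
        row = np.empty_like(x)
        for j in range(x.shape[0]):
            step = np.zeros_like(x)
            step[j] = eps
            row[j] = (fn(x + step)[i] - fn(x - step)[i]) / (2 * eps)
        rows.append(row)
    return np.stack(rows)


jac = jacobian_by_rows(cost, np.zeros(2))
```

For cost = x + 1 the result should be numerically close to the 2x2 identity matrix, which is what a working JAX-compiled function would be expected to return for this input.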
Getting this traceback when trying to run sampling_jax.numpyro_nuts(...) (see below). I installed aesara and pymc from development branch source. This is a Colab machine with Python 3.7. Any help would be greatly appreciated!