jax-ml / bayeux

State of the art inference for your bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0

Error for Bambi/PyMC model when using TFP #29

Closed: zwelitunyiswa closed this issue 8 months ago

zwelitunyiswa commented 8 months ago

When I run the code below, I get this error:

TypeError: float() argument must be a string or a real number, not 'ShapedArray'
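For context: `ShapedArray` is JAX's abstract array type, the placeholder JAX uses for values while a function is being traced. The traceback below shows TFP's `tf.while_loop` lowering to `jax.lax.scan`, whose body is always traced. Python's `float()` cannot turn an abstract value into a number, because no concrete number exists yet. The sketch below assumes only `jax` is installed and reproduces this general failure mode in a toy scan. JAX reports it as `ConcretizationTypeError`, a `TypeError` subclass, rather than the exact message above. It makes no claim about where the conversion happens inside bayeux, TFP, or PyTensor; `bad_step`, `good_step`, and `run` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def bad_step(carry, _):
    # Inside lax.scan, `carry` is an abstract tracer, so float() has no
    # concrete value to return and JAX raises a TypeError subclass.
    return carry + float(carry), None


def good_step(carry, _):
    # Staying in jnp operations keeps the step traceable.
    return carry * 2.0, None


def run(step):
    # Scan three times over no inputs, starting from a float32 scalar.
    final, _ = jax.lax.scan(step, jnp.float32(1.0), xs=None, length=3)
    return final


try:
    run(bad_step)
except TypeError as err:
    print("bad_step failed with", type(err).__name__)

print("good_step result:", run(good_step))
```

The usual remedy for this class of error is to keep the value as a JAX array (or mark it static) instead of converting it with `float()`.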

import bayeux as bx
import bambi as bmb
import pymc as pm
import pandas as pd
import jax
import arviz as az

dist = pm.Normal.dist(mu=100, sigma=30)

draws = pm.draw(dist, draws=1_000, random_seed=1000)

df = pd.DataFrame(data=draws, columns=['heights'])

formula = bmb.Formula('heights ~ 1')

model = bmb.Model(formula=formula, family='gaussian', data=df)

model.build()

bx_model = bx.Model.from_pymc(model.backend.model)

idata = bx_model.mcmc.tfp_nuts(seed=jax.random.key(0))

az.summary(idata)
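One way to narrow this down is to take PyMC and PyTensor out of the loop: write the intercept-only Gaussian model (`heights ~ 1`) as a log density directly in JAX and hand that to bayeux. A minimal sketch, under stated assumptions: the data are regenerated with NumPy, so they resemble the `pm.draw` sample but are not the same draws; the priors are hypothetical placeholders, not the ones Bambi picks; and `sigma` is sampled on the log scale with the matching Jacobian term. The names `log_density`, `intercept`, `log_sigma`, and `test_point` are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy import stats

# Similar data to the issue (Normal(100, 30), 1,000 draws), not identical draws.
heights = jnp.asarray(np.random.default_rng(1000).normal(100.0, 30.0, size=1_000))


def log_density(params):
    """Unnormalized log posterior for heights ~ Normal(intercept, sigma).

    Priors are hypothetical placeholders, not Bambi's defaults.
    """
    intercept = params["intercept"]
    log_sigma = params["log_sigma"]
    sigma = jnp.exp(log_sigma)
    log_prior = (
        stats.norm.logpdf(intercept, 100.0, 50.0)
        # Half-normal(50) on sigma, up to a constant, plus the log-Jacobian
        # of sigma = exp(log_sigma).
        + stats.norm.logpdf(sigma, 0.0, 50.0)
        + log_sigma
    )
    log_lik = jnp.sum(stats.norm.logpdf(heights, intercept, sigma))
    return log_prior + log_lik


# Quick check that the density and its gradient are finite near the truth.
test_point = {"intercept": 100.0, "log_sigma": jnp.log(30.0)}
value, grad = jax.value_and_grad(log_density)(test_point)
print(value, grad)
```

If wrapping this as `bx.Model(log_density=log_density, test_point=test_point)` and calling `.mcmc.tfp_nuts(seed=jax.random.key(0))` samples cleanly, that would suggest the failure comes from the PyMC-to-JAX conversion rather than from the TFP sampler itself. This is a suggested diagnostic, not a confirmed fix.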

The traceback is as follows:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 22
     18 model.build()
     20 bx_model = bx.Model.from_pymc(model.backend.model)
---> 22 idata = bx_model.mcmc.tfp_nuts(seed=jax.random.key(0))
     24 az.summary(idata)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/bayeux/_src/mcmc/tfp.py:205, in _TFPBase.__call__(self, seed, **kwargs)
    194 initial_running_variance = [
    195     tfp.experimental.stats.sample_stats.RunningVariance.from_stats(
    196         num_samples=jnp.array(1, part.dtype),
    197         mean=jnp.zeros_like(part),
    198         variance=jnp.ones_like(part))
    199     for part in initial_transformed_position]
    201 # The public API expects a JointDistribution. Much of the above is adapted
    202 # from the source code for
    203 # `tfp.experimental.mcmc.windowed_adaptive_{nuts|hmc}`, but handling a raw
    204 # log density, and doing the structure flattening with `jax.tree_utils`.
--> 205 draws, trace = tfp.experimental.mcmc.windowed_sampling._do_sampling(
    206     kind=self.algorithm,
    207     proposal_kernel_kwargs=proposal_kernel_kwargs,
    208     dual_averaging_kwargs=dual_averaging_kwargs,
    209     num_draws=extra_parameters["num_draws"],
    210     num_burnin_steps=extra_parameters["num_adaptation_steps"],
    211     initial_position=initial_transformed_position,
    212     initial_running_variance=initial_running_variance,
    213     bijector=None,
    214     trace_fn=_TRACE_FNS[self.algorithm],
    215     return_final_kernel_results=False,
    216     chain_axis_names=None,
    217     shard_axis_names=None,
    218     seed=sample_key)
    220 draws = self.transform_fn(jax.tree_util.tree_unflatten(treedef, draws))
    221 if extra_parameters["return_pytree"]:

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:551, in _do_sampling(kind, proposal_kernel_kwargs, dual_averaging_kwargs, num_draws, num_burnin_steps, initial_position, initial_running_variance, trace_fn, bijector, return_final_kernel_results, seed, chain_axis_names, shard_axis_names)
    543 """Sample from base HMC kernel."""
    544 kernel = make_windowed_adapt_kernel(
    545     kind=kind,
    546     proposal_kernel_kwargs=proposal_kernel_kwargs,
   (...)
    549     chain_axis_names=chain_axis_names,
    550     shard_axis_names=shard_axis_names)
--> 551 return sample.sample_chain(
    552     num_draws,
    553     initial_position,
    554     kernel=kernel,
    555     num_burnin_steps=num_burnin_steps,
    556     # pylint: disable=g-long-lambda
    557     trace_fn=lambda state, pkr: trace_fn(
    558         state, bijector, pkr.step <= dual_averaging_kwargs[
    559             'num_adaptation_steps'], pkr.inner_results.inner_results.
    560         inner_results),
    561     # pylint: enable=g-long-lambda
    562     return_final_kernel_results=return_final_kernel_results,
    563     seed=seed)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/mcmc/sample.py:359, in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, seed, name)
    352   seed, next_state, current_kernel_results = loop_util.smart_for_loop(
    353       loop_num_iter=num_steps,
    354       body_fn=_seeded_one_step,
    355       initial_loop_vars=list(seed_state_and_results),
    356       parallel_iterations=parallel_iterations)
    357   return seed, next_state, current_kernel_results
--> 359 (_, _, final_kernel_results), (all_states, trace) = loop_util.trace_scan(
    360     loop_fn=_trace_scan_fn,
    361     initial_state=(seed, current_state, previous_kernel_results),
    362     elems=tf.one_hot(
    363         indices=0,
    364         depth=num_results,
    365         on_value=1 + num_burnin_steps,
    366         off_value=1 + num_steps_between_results,
    367         dtype=tf.int32),
    368     # pylint: disable=g-long-lambda
    369     trace_fn=lambda seed_state_and_results: (
    370         seed_state_and_results[1], trace_fn(*seed_state_and_results[1:])),
    371     # pylint: enable=g-long-lambda
    372     parallel_iterations=parallel_iterations)
    374 if return_final_kernel_results:
    375   return CheckpointableStatesAndTrace(
    376       all_states=all_states,
    377       trace=trace,
    378       final_kernel_results=final_kernel_results)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/loop_util.py:232, in trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn, static_trace_allocation_size, condition_fn, parallel_iterations, name)
    224   trace_arrays, num_steps_traced = ps.cond(
    225       trace_criterion_fn(state) if trace_criterion_fn else True,
    226       lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
    227                num_steps_traced + 1),
    228       lambda: (trace_arrays, num_steps_traced))
    230   return i + 1, state, num_steps_traced, trace_arrays
--> 232 _, final_state, _, trace_arrays = tf.while_loop(
    233     cond=condition_fn if condition_fn is not None else lambda *_: True,
    234     body=_body,
    235     loop_vars=(0, initial_state, 0, trace_arrays),
    236     maximum_iterations=length,
    237     parallel_iterations=parallel_iterations)
    239 # unflatten
    240 stacked_trace = tf.nest.pack_sequence_as(
    241     initial_trace, [ta.stack() for ta in trace_arrays],
    242     expand_composites=True)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py:102, in _while_loop_jax(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
     99       args = pack_body(body(*args))
    100     return args, ()
--> 102   loop_vars, _ = lax.scan(
    103       override_body_fn, loop_vars, xs=None, length=maximum_iterations)
    104   return loop_vars
    105 else:

    [... skipping hidden 9 frame]

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py:99, in _while_loop_jax.<locals>.override_body_fn(args, _)
     96   args = lax.cond(c, args, lambda args: pack_body(body(*args)), args,
     97                   lambda args: args)
     98 elif sc:
---> 99   args = pack_body(body(*args))
    100 return args, ()

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/loop_util.py:222, in trace_scan.<locals>._body(i, state, num_steps_traced, trace_arrays)
    220 def _body(i, state, num_steps_traced, trace_arrays):
    221   elem = elems_array.read(i)
--> 222   state = loop_fn(state, elem)
    224   trace_arrays, num_steps_traced = ps.cond(
    225       trace_criterion_fn(state) if trace_criterion_fn else True,
    226       lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
    227                num_steps_traced + 1),
    228       lambda: (trace_arrays, num_steps_traced))
    230   return i + 1, state, num_steps_traced, trace_arrays

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/mcmc/sample.py:352, in sample_chain.<locals>._trace_scan_fn(seed_state_and_results, num_steps)
    351 def _trace_scan_fn(seed_state_and_results, num_steps):
--> 352   seed, next_state, current_kernel_results = loop_util.smart_for_loop(
    353       loop_num_iter=num_steps,
    354       body_fn=_seeded_one_step,
    355       initial_loop_vars=list(seed_state_and_results),
    356       parallel_iterations=parallel_iterations)
    357   return seed, next_state, current_kernel_results

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/loop_util.py:111, in smart_for_loop(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, unroll_threshold, name)
    101 if (loop_num_iter_ is None
    102     or tf.executing_eagerly()
    103     # large values for loop_num_iter_ will cause ridiculously slow
   (...)
    108   # Cast to int32 to run the comparison against i in host memory,
    109   # where while/LoopCond needs it.
    110   loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
--> 111   return tf.while_loop(
    112       cond=lambda i, *args: i < loop_num_iter,
    113       body=lambda i, *args: [i + 1] + list(body_fn(*args)),
    114       loop_vars=[np.int32(0)] + initial_loop_vars,
    115       parallel_iterations=parallel_iterations
    116   )[1:]
    117 result = initial_loop_vars
    118 for _ in range(loop_num_iter_):

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py:90, in _while_loop_jax(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
     88   def override_cond_fn(args):
     89     return cond(*args)
---> 90   return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
     91 elif back_prop:
     92   def override_body_fn(args, _):

    [... skipping hidden 4 frame]

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/linalg.py:102, in register_pytrees.<locals>.register.<locals>.unflatten(info, xs)
    100 keys, metadata = info
    101 parameters = dict(list(zip(keys, xs)), **metadata)
--> 102 return cls(**parameters)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py:171, in LinearOperatorDiag.__init__(self, diag, is_non_singular, is_self_adjoint, is_positive_definite, is_square, name)
    161 parameters = dict(
    162     diag=diag,
    163     is_non_singular=is_non_singular,
   (...)
    167     name=name
    168 )
    170 with ops.name_scope(name, values=[diag]):
--> 171   self._diag = linear_operator_util.convert_nonref_to_tensor(
    172       diag, name="diag")
    173   self._check_diag(self._diag)
    175   # Check and auto-set hints.

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_util.py:134, in convert_nonref_to_tensor(value, dtype, dtype_hint, name)
    130     raise TypeError(
    131         f"Argument `value` must be of dtype `{dtype_name(dtype_base)}` "
    132         f"Received: `{dtype_name(value_dtype_base)}`.")
    133   return value
--> 134 return ops.convert_to_tensor(
    135     value, dtype=dtype, dtype_hint=dtype_hint, name=name
    136 )

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:167, in _convert_to_tensor(value, dtype, dtype_hint, name)
    164     pass
    166 if ret is None:
--> 167   ret = conversion_func(value, dtype=dtype)
    168 return ret

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:243, in _default_convert_to_tensor(value, dtype)
    240 # If no dtype is provided, we try the inferred dtype and fallback to int64 or
    241 # float32 depending on the type of conversion error we see.
    242 try:
--> 243   return _default_convert_to_tensor_with_dtype(value, inferred_dtype)
    244 except _Int64ToInt32Error as e:
    245   return np.array(value, dtype=np.int64)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:286, in _default_convert_to_tensor_with_dtype(value, dtype, error_if_mismatch)
    283 is_arraylike = hasattr(value, 'dtype')
    284 if is_arraylike:
    285   # Duck-typed for `onp.array`/`oonp.generic`
--> 286   arr = np.array(value)
    287   if dtype is not None:
    288     # arr.astype(None) forces conversion to float64
    289     return arr.astype(dtype)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2158, in array(object, dtype, copy, order, ndmin)
   2151 out: ArrayLike
   2153 if all(not isinstance(leaf, Array) for leaf in leaves):
   2154   # TODO(jakevdp): falling back to numpy here fails to overflow for lists
   2155   # containing large integers; see discussion in
   2156   # https://github.com/google/jax/pull/6047. More correct would be to call
   2157   # coerce_to_array on each leaf, but this may have performance implications.
-> 2158   out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)  # type: ignore[arg-type]
   2159 elif isinstance(object, Array):
   2160   assert object.aval is not None

TypeError: float() argument must be a string or a real number, not 'ShapedArray'
ColCarroll commented 8 months ago

Gosh, you're not going to believe this, but the problem is jax.random.key(0) instead of jax.random.PRNGKey(0). The easiest thing for you for now is to switch to the older jax.random.PRNGKey style of key. The underlying bug is fixed on tfp-nightly and will be in the next stable release (here is the tricky fix from @SiegeLordEx).
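A minimal sketch of the difference between the two key types, assuming JAX 0.4.16 or newer (where jax.random.key and jax.dtypes.prng_key exist). jax.random.PRNGKey(0) returns a legacy raw key, a plain uint32 array of shape (2,) under the default threefry PRNG, while jax.random.key(0) returns a typed key array of shape () with a special key dtype. jax.random.key_data unwraps a typed key into its raw uint32 data. The helper to_legacy_key is hypothetical and is not part of bayeux, TFP, or JAX.

```python
import jax
import jax.numpy as jnp


def to_legacy_key(key):
    """Hypothetical helper: return a raw uint32 key for code that expects one.

    Typed keys (from jax.random.key) carry an extended PRNG key dtype;
    legacy keys (from jax.random.PRNGKey) are plain uint32 arrays.
    """
    if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
        # Unwrap the typed key into its underlying uint32 key data.
        return jax.random.key_data(key)
    # Already a raw key: pass it through unchanged.
    return key


typed_key = jax.random.key(0)        # new-style typed key, scalar shape
legacy_key = jax.random.PRNGKey(0)   # legacy raw key

print(typed_key.shape)                      # ()
print(legacy_key.shape, legacy_key.dtype)   # (2,) uint32

# Under the default threefry PRNG both spellings hold the same key material.
print(bool(jnp.array_equal(to_legacy_key(typed_key), legacy_key)))  # True
```

Under this sketch, passing seed=to_legacy_key(jax.random.key(0)) to bx_model.mcmc.tfp_nuts should behave like passing seed=jax.random.PRNGKey(0), since both produce the same raw key array; that equivalence is an assumption and was not tested in this thread.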

zwelitunyiswa commented 8 months ago

Yes, changing it to "jax.random.PRNGKey(0)" works! Thank you!
