amifalk closed this issue 3 months ago
Hi @amifalk, those autoguides are not designed to be composed with vmap after construction, because they need initialization (to inspect the model and generate things like prototype_trace). Something like this will work:
def init(...):
    guide = AutoDelta(...)
    return guide.init(...)

init_state = jax.vmap(init)(...)
I think there's still an issue here. When svi.init is called, it initializes both the model and the guide for the first time, which should set up the prototype trace. Batched model fitting works with AutoGuides if I vmap the SVI methods in all cases except when there is both a deterministic site and the guide is based on a blocked model.
The suggested approach yields the same error as before (though AutoGuides are not registered as pytrees so they cannot be returned after calling vmap).
def guide_init(rng_seed):
    guide = AutoDelta(block(seed(model, rng_seed=rng_seed), hide=['b']))
    seed(guide, rng_seed=rng_seed)()
    return

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init)(keys)  # this works

def guide_init_deterministic(rng_seed):
    guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=rng_seed), hide=['b']))
    seed(guide, rng_seed=rng_seed)()
    return

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init_deterministic)(keys)  # tracer error
@fehiepsi I've traced the source to this while loop. If I set _DISABLE_CONTROL_FLOW_PRIM = True, vmapping the svi.init method works. However, vmapping the guide initialization yields a new error in the while loop:
This BatchTracer with object id 140305711093520 was created on line:
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/infer/util.py:357:15 (find_valid_initial_params.<locals>.cond_fn)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
If we use a Python while loop, then the condition needs to be a Python value like True or False. Having a JAX object there won't work. What is your use case, by the way?
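To make the point above concrete, here is a minimal sketch in plain JAX (it does not use NumPyro, and the function names are made up for illustration). Under vmap, the loop condition becomes a tracer, so a Python while loop cannot convert it to a bool, whereas jax.lax.while_loop accepts a traced condition. The except clause catches jax.errors.ConcretizationTypeError, which is assumed here to be the parent class of the TracerBoolConversionError shown in the traceback.

```python
import jax
import jax.numpy as jnp

def count_up_python(start):
    # Python-level loop: `x < 10` must be converted to a Python bool,
    # which is impossible when x is a tracer (e.g. under vmap or jit).
    x = start
    while x < 10:
        x = x + 1
    return x

def count_up_lax(start):
    # Structured control flow: the condition may be a traced value.
    return jax.lax.while_loop(lambda x: x < 10, lambda x: x + 1, start)

starts = jnp.array([0, 5, 12])

try:
    jax.vmap(count_up_python)(starts)
    python_loop_failed = False
except jax.errors.ConcretizationTypeError:
    python_loop_failed = True

# Each batch element keeps incrementing until its own condition is False.
batched = jax.vmap(count_up_lax)(starts)
```

This is presumably why switching NumPyro to Python control flow only moves the error: the condition in cond_fn is still a batched JAX value.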
I have a blocked model with a deterministic site that I'm trying to perform some simulation studies on. I want to see how variations in the structure of the dataset / model hyperparameters affect the performance, and I also want to be able to select the best result over multiple initializations. It's very slow to do this sequentially (for a small grid of hyperparams it took around 40 minutes), but after vmapping/pmapping with GPU I can get the entire grid to run in parallel. In my case it reduced the fitting time to 7 seconds.
Unfortunately, if I try to vmap the blocked model with deterministic sites present, it throws this error, so I have to instead recompute the deterministic sites at the end of model fitting.
In my case, I need to block the model to define an AutoGuide that is compatible with enumeration (blocking out the enumerated sites), but this would likely also be a problem for people using AutoGuideList.
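The batching pattern described above (fit from many initializations in parallel, then keep the best run) can be sketched with plain JAX. This is only a toy stand-in, not the actual model: loss_fn is a made-up objective playing the role of a negative ELBO, and fit is a hypothetical gradient-descent routine rather than NumPyro's SVI.

```python
import jax
import jax.numpy as jnp
from jax import random

def loss_fn(x):
    # Toy non-convex objective (assumption: stands in for a negative ELBO).
    return jnp.sin(3.0 * x) + 0.1 * x**2

def fit(key, step_size=0.01, num_steps=200):
    # Random initialization drawn from the key, then plain gradient descent.
    x0 = 3.0 * random.normal(key)

    def step(x, _):
        return x - step_size * jax.grad(loss_fn)(x), None

    x, _ = jax.lax.scan(step, x0, None, length=num_steps)
    return x, loss_fn(x)

# One key per initialization; vmap runs all fits as a single batched program.
keys = random.split(random.PRNGKey(0), 8)
xs, losses = jax.vmap(fit)(keys)

# Keep the run with the lowest final loss.
best = jnp.argmin(losses)
best_x, best_loss = xs[best], losses[best]
```

The same shape of program (a function of a key, vmapped over a batch of keys, followed by an argmin over the final losses) is what the SVI version is trying to achieve.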
I think you can do something like
def run_svi(...):
    svi = ...
    svi_result = svi.run(...)
    return svi_result

svi_results = vmap(run_svi)(...)
Unfortunately this still seems to throw the same tracer error.
def run_svi(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model, rng_seed=0), hide=['b']))
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    return svi.run(key, 100, progress_bar=False)

def run_svi_deterministic(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=0), hide=['b']))
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    return svi.run(key, 100, progress_bar=False)

keys = random.split(random.PRNGKey(0), 2)
jax.vmap(run_svi)(keys)  # works
jax.vmap(run_svi_deterministic)(keys)  # tracer error from the while loop in find_valid_initial_params
Can we make it so that AutoGuides only collect non-enumerated model sample sites? This wouldn't fix the problem for all blocked models, but it would make collecting deterministic sites possible under batched svi for my use-case.
I think this would only have to be a one-liner change here-ish where we just ignore sample sites in the prototype trace that have site['infer'].get('enumerate') == True. That would also make the syntax for defining AutoGuides for enumerated models much simpler, e.g. just AutoGuide(model), instead of
AutoGuide(block(seed(model, rng_seed=0), hide=["enumerated_site_1", "enumerated_site_2", ...]))
Thanks @amifalk! There is indeed leakage here with the seed handler. I haven't been able to figure out why yet. Posting here for reference
import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))

def run(key):
    return numpyro.infer.util.initialize_model(key, numpyro.handlers.seed(model, rng_seed=0))[0]

with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))
@fehiepsi With that example, I was able to narrow the source of the bug further - thanks! The while loop of _find_valid_params closes over the seeded model, but it also traces the model during its calls to potential_fn. I think the fact that the trace is seeing a rng key from the global call to seed is causing the error. Here's a minimal example.
import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))

def run(key):
    seeded = numpyro.handlers.seed(model, rng_seed=0)

    def cond_fn(state):
        i, num = state
        return i < 10

    def body_fn(state):
        i, num = state
        numpyro.handlers.trace(seeded).get_trace()  # this references the global rng values in a jitted context
        # equivalently num = numpyro.handlers.trace(seeded).get_trace()['a']['value'] will raise an error
        return (i + 1, num)

    return jax.lax.while_loop(cond_fn, body_fn, (0, 0))

with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))
You can also verify this by replacing potential_fn in numpyro.infer.util.find_valid_initial_params with a placeholder that just returns a constant number.
I think we figured it out. Thanks for the examples!
seed(model) is an instance of the seed class, which has mutable state. A fix for it is to wrap the seeded model in a function like
def seeded_model(*args, **kwargs):
    return seed(model, rng_seed=random.PRNGKey(0))(*args, **kwargs)
This way, each time we call the model, a new instance of the seed handler is created. Could you check if it works for your use case? I'll think about a long-term solution (maybe improving the docstring for this).
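The mutable-state problem described above can be illustrated without NumPyro. The sketch below uses a hypothetical ToySeed class (not NumPyro's seed handler) that, like the handler as described in this thread, stores a PRNG key and overwrites it with a split key on every call. When a shared instance is called inside jit, the stored key is replaced by a tracer that outlives the traced function. Creating a fresh instance per call keeps all state local to the trace.

```python
import jax
import jax.core
from jax import random

class ToySeed:
    """Hypothetical stand-in for a seed handler: holds a key, splits it per call."""

    def __init__(self, rng_seed):
        self.key = random.PRNGKey(rng_seed)

    def __call__(self):
        # Mutates instance state: the stored key is replaced on every call.
        self.key, subkey = random.split(self.key)
        return random.normal(subkey)

shared = ToySeed(0)

@jax.jit
def draw_shared():
    # Closes over a long-lived instance, so its state is mutated during tracing.
    return shared()

@jax.jit
def draw_fresh():
    # A new instance per call: no state survives outside the traced function.
    return ToySeed(0)()

draw_shared()
# After tracing, the shared instance holds a tracer instead of a concrete key.
leaked = isinstance(shared.key, jax.core.Tracer)

first, second = draw_fresh(), draw_fresh()
```

In this toy, leaked ends up True, while draw_fresh returns the same value on every call because each call starts from a fresh, concrete key. That is the same design choice as the seeded_model wrapper: construct the stateful object inside the function being traced.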
Yes, this fixed it! Not sure if there's any interest in adding to NumPyro, but here's the pattern for batching SVI: https://gist.github.com/amifalk/eb377a243b046105dc00beda79441b22
I came across a very strange bug while trying to vmap the SVI class (in order to parallelize model training across multiple initializations + across different datasets of the same shape). A tracer error occurs, but only if the AutoGuide has a site blocked out and there is also a deterministic site in the model. I wonder if this is related to #1657?