pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

tracer error in blocked AutoGuide #1753

amifalk closed this issue 3 months ago

amifalk commented 4 months ago

I came across a very strange bug while trying to vmap SVI methods (to parallelize model training across multiple initializations and across different datasets of the same shape).

A tracer error occurs, but only if the AutoGuide has a site blocked out and the model also has a deterministic site. I wonder if this is related to #1657?

import jax
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import block, seed
from numpyro.infer import SVI, TraceEnum_ELBO
from numpyro.infer.autoguide import AutoDelta

def model():
    a = numpyro.sample('a', dist.Normal(0, 1))
    b = numpyro.sample('b', dist.Normal(0, 1))

# -- this works --
keys = random.split(random.PRNGKey(0), 2)

optimizer = numpyro.optim.Adam(step_size=.01)
guide = AutoDelta(block(seed(model, rng_seed=0), hide=['b']))
svi = SVI(model, guide, optimizer, loss=TraceEnum_ELBO())

mapped_state = jax.vmap(svi.init)(keys)

def model_w_deterministic():
    a = numpyro.sample('a', dist.Normal(0, 1))
    b = numpyro.sample('b', dist.Normal(0, 1))

    numpyro.deterministic('test', a)

# -- this fails --
optimizer = numpyro.optim.Adam(step_size=.01)
guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=0), hide=['b']))
svi = SVI(model_w_deterministic, guide, optimizer, loss=TraceEnum_ELBO())

mapped_state = jax.vmap(svi.init)(keys)

The failing jax.vmap(svi.init) call raises the following (excerpt of the UnexpectedTracerError message):

JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was body_fn at /home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/infer/util.py:358 traced for while_loop.
------------------------------
The leaked intermediate value was created on line /home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/handlers.py:745:8 (seed.process_message).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/primitives.py:105:19 (Messenger.__call__)
<ipython-input-9-38f6b9474998>:17:8 (model_w_deterministic)
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/primitives.py:222:10 (sample)
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/primitives.py:47:8 (apply_stack)
/home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/handlers.py:745:8 (seed.process_message)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
fehiepsi commented 4 months ago

Hi @amifalk, those autoguides are not designed to be composed with vmap after construction, because they need to be initialized (to inspect the model and generate things like the prototype trace). Something like this will work:

def init(...):
    guide = AutoDelta(...)
    return guide.init(...)

init_state = jax.vmap(init)(...)
amifalk commented 4 months ago

I think there's still an issue here. When svi.init is called, it initializes both the model and the guide for the first time, which should set up the prototype trace. Batched model fitting with AutoGuides works when I vmap the SVI methods, in every case except when the model has a deterministic site and the guide is built from a blocked model.

The suggested approach yields the same error as before (although AutoGuides are not registered as pytrees, so they cannot be returned from the vmapped function):

def guide_init(rng_seed):
    # build the guide inside the vmapped function, then run it once under seed to initialize it
    guide = AutoDelta(block(seed(model, rng_seed=rng_seed), hide=['b']))
    seed(guide, rng_seed=rng_seed)()

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init)(keys) # this works

def guide_init_deterministic(rng_seed):
    # same as above, but the model has a deterministic site
    guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=rng_seed), hide=['b']))
    seed(guide, rng_seed=rng_seed)()

keys = random.split(random.PRNGKey(0))
jax.vmap(guide_init_deterministic)(keys) # tracer error
amifalk commented 4 months ago

@fehiepsi I've traced the source to the while loop in find_valid_initial_params (numpyro/infer/util.py). If I set numpyro.util._DISABLE_CONTROL_FLOW_PRIM = True, vmapping the svi.init method works. However, vmapping the guide initialization then yields a new error in the while loop:

This BatchTracer with object id 140305711093520 was created on line:
  /home/ami/venvs/jax/lib/python3.11/site-packages/numpyro/infer/util.py:357:15 (find_valid_initial_params.<locals>.cond_fn)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
fehiepsi commented 4 months ago

If we use a Python while loop, the condition needs to be a Python value like True or False; a JAX tracer there won't work. What is your use case, by the way?
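To illustrate the distinction, here is a minimal, NumPyro-independent sketch. The functions `python_loop` and `lax_loop` are hypothetical. Under `jax.vmap`, the loop bound becomes a batched tracer, so a Python `while` cannot convert the condition to a bool. `jax.lax.while_loop`, in contrast, keeps the condition traced, and the batched result matches running each element separately.

```python
import jax
import jax.numpy as jnp


def python_loop(x):
    # A Python `while` must turn `i < x` into a concrete True/False;
    # under vmap, `x` is a BatchTracer, so this raises.
    i = 0
    while i < x:
        i = i + 1
    return i


def lax_loop(x):
    # lax.while_loop traces cond_fn/body_fn once and keeps the
    # condition as a traced (here batched) boolean.
    def cond_fn(i):
        return i < x

    def body_fn(i):
        return i + 1

    return jax.lax.while_loop(cond_fn, body_fn, 0)


xs = jnp.array([1, 3, 5])

try:
    jax.vmap(python_loop)(xs)
    python_loop_ok = True
except jax.errors.ConcretizationTypeError as err:
    # In recent JAX the concrete type is TracerBoolConversionError,
    # a subclass of ConcretizationTypeError.
    python_loop_ok = False
    print(type(err).__name__)

counts = jax.vmap(lax_loop)(xs)
print(counts)  # [1 3 5]
```

This is also why setting `_DISABLE_CONTROL_FLOW_PRIM` only moves the problem: the loop condition in `find_valid_initial_params` depends on traced values.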

amifalk commented 4 months ago

I have a blocked model with a deterministic site that I'm running simulation studies on. I want to see how variations in the structure of the dataset and in the model hyperparameters affect performance, and I also want to select the best result over multiple initializations. Doing this sequentially is very slow (a small hyperparameter grid took around 40 minutes), but after vmapping/pmapping on a GPU the entire grid runs in parallel; in my case that reduced the fitting time to 7 seconds.

Unfortunately, vmapping the blocked model with deterministic sites present throws this error, so instead I have to recompute the deterministic sites after model fitting.

In my case, I need to block the model to define an AutoGuide that is compatible with enumeration (blocking out the enumerated sites), but this would likely also affect people using AutoGuideList, which is typically assembled from guides over blocked models.

fehiepsi commented 4 months ago

I think you can do something like:

def run_svi(...):
    svi = ...
    svi_result = svi.run(...)
    return svi_result

svi_results = vmap(run_svi)(...)
amifalk commented 4 months ago

Unfortunately this still seems to throw the same tracer error.

from numpyro.infer import Trace_ELBO

def run_svi(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model, rng_seed=0), hide=['b']))
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    return svi.run(key, 100, progress_bar=False)

def run_svi_deterministic(key):
    optimizer = numpyro.optim.Adam(step_size=.01)
    guide = AutoDelta(block(seed(model_w_deterministic, rng_seed=0), hide=['b']))
    svi = SVI(model_w_deterministic, guide, optimizer, loss=Trace_ELBO())

    return svi.run(key, 100, progress_bar=False)

keys = random.split(random.PRNGKey(0), 2)

jax.vmap(run_svi)(keys) # works
jax.vmap(run_svi_deterministic)(keys) # tracer error from the while loop in find_valid_initial_params
amifalk commented 4 months ago

Can we make AutoGuides collect only non-enumerated model sample sites? This wouldn't fix the problem for all blocked models, but it would make collecting deterministic sites possible under batched SVI for my use case.

I think this would only need a one-line change in the code that builds the prototype trace: ignore sample sites that have enumeration enabled, i.e. where site['infer'].get('enumerate') is set (config_enumerate sets it to 'parallel'). That would also make defining AutoGuides for enumerated models much simpler, e.g. just AutoGuide(model) instead of AutoGuide(block(seed(model, rng_seed=0), hide=["enumerated_site_1", "enumerated_site_2", ...])).

fehiepsi commented 4 months ago

Thanks @amifalk! There is indeed leakage here with the seed handler, though I haven't figured out why yet. Posting a reproduction here for reference:

import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))    

def run(key):
    return numpyro.infer.util.initialize_model(key, numpyro.handlers.seed(model, rng_seed=0))[0]

with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))
amifalk commented 4 months ago

@fehiepsi With that example, I was able to narrow down the source of the bug further, thanks! The while loop in _find_valid_params closes over the seeded model, but it also traces the model in its calls to potential_fn. I think the error is caused by the trace seeing an RNG key that comes from the seed handler created outside the loop. Here's a minimal example:

import numpyro
import numpyro.distributions as dist
import jax

def model():
    return numpyro.sample('a', dist.Normal(0, 1))    

def run(key):
    seeded = numpyro.handlers.seed(model, rng_seed=0)

    def cond_fn(state):
        i, num = state
        return i < 10 

    def body_fn(state):
        i, num = state

        numpyro.handlers.trace(seeded).get_trace() # this references the global rng values in a jitted context
        # equivalently num = numpyro.handlers.trace(seeded).get_trace()['a']['value'] will raise an error        
        return (i + 1, num)

    return jax.lax.while_loop(cond_fn, body_fn, (0, 0))

with jax.checking_leaks():
    jax.jit(run)(jax.random.PRNGKey(0))

You can also verify this by replacing potential_fn in numpyro.infer.util.find_valid_initial_params with a placeholder that just returns a constant number.
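The mechanism described above does not depend on NumPyro. Here is a minimal sketch that assumes only JAX. The `Counter` class is a hypothetical stand-in for a handler that keeps mutable state, the way a seed handler keeps its RNG key. A method that overwrites an attribute while being traced stores a tracer on the object, and any later use of that attribute raises `UnexpectedTracerError`.

```python
import jax
import jax.numpy as jnp


class Counter:
    """Hypothetical stateful object, analogous to a handler that splits
    and stores a new rng key every time it is used."""

    def __init__(self):
        self.state = jnp.zeros(())

    def step(self):
        # Overwrites instance state; inside jit this stores a tracer.
        self.state = self.state + 1.0
        return self.state


counter = Counter()


@jax.jit
def traced_step(x):
    return counter.step() + x


traced_step(1.0)  # tracing leaves a tracer in counter.state

try:
    counter.step()  # later use of the leaked value
    leak_detected = False
except jax.errors.UnexpectedTracerError:
    leak_detected = True

print(leak_detected)
```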

fehiepsi commented 4 months ago

I think we figured it out. Thanks for the examples!

seed(model) is an instance of the seed handler class, which has mutable state (it splits and overwrites its stored RNG key at every sample site). A fix is to wrap the seeded model in a function, like:

def seeded_model(*args, **kwargs):
    return seed(model, rng_seed=random.PRNGKey(0))(*args, **kwargs)

This way, a new instance of the seed handler is created each time the model is called. Could you check whether this works for your use case? I'll think about a long-term solution (maybe improving the docstring for this).

amifalk commented 4 months ago

Yes, this fixed it! Not sure if there's any interest in adding this to NumPyro, but here's the pattern for batching SVI: https://gist.github.com/amifalk/eb377a243b046105dc00beda79441b22