pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

possible to use numpyro as a modeling language only? #546

Closed: justindomke closed this issue 4 years ago

justindomke commented 4 years ago

Hi! I'm very interested in this project. However, my potential use is a bit different from what seems typical with numpyro, and I was wondering if there was a "recommended" way to go about it.

In short, what I'd like to do is use numpyro as a modeling language only, and implement my own inference algorithms. Essentially, I'd like to define a model in numpyro, then specify values for some subset of the variables in the model, then get a log_prob(*vars) function that would evaluate the log probability, given a configuration of all of the other variables. That's it! I'd like to do everything else myself.

I'm aware of the log_density function in the utilities section, but it is challenging to figure out how this might be used. If there were any examples, that would be helpful.

If I'm being greedy, even better would be a function log_prob(*unconstrained_vars) that evaluates the log probability after all the variables have been transformed to an unconstrained space, along with a constrain(*unconstrained_vars) function that would transform them back.

I appreciate any help!

fehiepsi commented 4 years ago

I think all of those (log_prob and constrain) are exactly what initialize_model does for you. You can find its usage in this example.

justindomke commented 4 years ago

Thank you! Indeed, that seems to do it! In case this is of interest to anyone else, here's what I think is a minimal working example of this:

import numpyro
import numpyro.distributions as dist
import jax
import jax.random

def model(c):
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Gamma(1.0))
    numpyro.sample('c', dist.Normal(a, b), obs=c)

c = 1.0

rng_key = jax.random.PRNGKey(0)
# initialize_model returns (1) initial parameter values in unconstrained space,
# (2) a potential function (the negative log joint density, evaluated in
# unconstrained space), and (3) a function that maps unconstrained parameters
# back to the constrained space of the model.
init_params, potential_fn, constrain_fn = numpyro.infer.util.initialize_model(rng_key, model, model_args=(c,))

print('init_params:  ', init_params)
print('constrain_fn: ', constrain_fn(init_params))
print('potential_fn: ', potential_fn(init_params))
print('grad:         ', jax.grad(potential_fn)(init_params))

This produces:

init_params:   {'a': DeviceArray(1.9340096, dtype=float32), 'b': DeviceArray(0.8950882, dtype=float32)}
constrain_fn:  {'a': DeviceArray(1.9340096, dtype=float32), 'b': DeviceArray(2.4475517, dtype=float32)}
potential_fn:  6.228439
grad:          {'a': DeviceArray(2.0899243, dtype=float32), 'b': DeviceArray(2.301926, dtype=float32)}
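
As a sanity check, the potential above can be reconstructed from numpyro.infer.util.log_density plus the Jacobian correction for the transform of b (a minimal sketch, assuming numpyro maps the positive-valued b through an exp bijection, so the Jacobian term is just b's unconstrained value):

from numpyro.infer.util import log_density

# potential(u) = -(log joint at constrain(u) + log|det Jacobian of constrain|)
constrained = constrain_fn(init_params)
log_joint, _ = log_density(model, (c,), {}, constrained)
log_det = init_params['b']  # log|d exp(u)/du| = u for the exp transform
print(-(log_joint + log_det))  # ~6.228439, matching potential_fn above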

justindomke commented 4 years ago

Actually, could I be even more greedy?

Is there any way to create a function that's like potential_fn except that it also takes an rng_key argument and performs data subsampling? (Presumably subsampling only over plates, as determined by their subsample_size arguments.)

As an example of something that doesn't work, I tried the following:

import numpyro
import numpyro.distributions as dist
import jax
import jax.random

def model(c):
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Gamma(1.0))
    with numpyro.plate('N', 2, subsample_size=1):
        numpyro.sample('c', dist.Normal(a, b), obs=c)

c = jax.numpy.array([1.0,2.0])

rng_key = jax.random.PRNGKey(0)
init_params, potential_fn, constrain_fn = numpyro.infer.util.initialize_model(rng_key, model, model_args=(c,))

print('init_params:  ', init_params)
print('constrain_fn: ', constrain_fn(init_params))
print('potential_fn: ', potential_fn(init_params))
print('grad:         ', jax.grad(potential_fn)(init_params))

My hope was that a random subsample of c would be used in each call to potential_fn, but that doesn't seem to be what is happening. Is this the correct usage of plate?

fehiepsi commented 4 years ago

It looks like get_potential_fn with dynamic_args=True is what you are looking for:

potential_fn_gen, _ = get_potential_fn(PRNGKey(0), model, dynamic_args=True, ...)
def potential_fn(rng, params):
    # use rng to get subsample of c
    return potential_fn_gen(c_subsample)(params)
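
For concreteness, a minimal hedged completion of that sketch might look like the following (c_full and batch_size are placeholder names for the full dataset and the minibatch size, which should match the plate's subsample_size):

def potential_fn(rng, params):
    # draw a random minibatch of indices from the full dataset
    idx = jax.random.permutation(rng, c_full.shape[0])[:batch_size]
    return potential_fn_gen(c_full[idx])(params)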

justindomke commented 4 years ago

Thanks so much for your help! Just in case anyone else would like to see a full example of this, here's a version that uses initialize_model with 2 data points and no data subsampling:

import numpyro
import numpyro.distributions as dist
import jax
import jax.random

def model(c):
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Gamma(1.0))
    with numpyro.plate('N', 2):
        numpyro.sample('c', dist.Normal(a, b), obs=c)

c = jax.numpy.array([1.0,2.0])

rng_key = jax.random.PRNGKey(0)
init_params, potential_fn, constrain_fn = numpyro.infer.util.initialize_model(rng_key, model, model_args=(c,))

print('init_params:  ', init_params)
print('constrain_fn: ', constrain_fn(init_params))
print('potential_fn: ', potential_fn(init_params))
print('grad:         ', jax.grad(potential_fn)(init_params))

The results are:

init_params:   {'a': DeviceArray(1.9340096, dtype=float32), 'b': DeviceArray(0.8950882, dtype=float32)}
constrain_fn:  {'a': DeviceArray(1.9340096, dtype=float32), 'b': DeviceArray(2.4475517, dtype=float32)}
potential_fn:  8.0428295
grad:          {'a': DeviceArray(2.0789087, dtype=float32), 'b': DeviceArray(3.301199, dtype=float32)}
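
As a quick consistency check against the single-datum example above: the full-data potential should exceed it by exactly the negative log-likelihood of the second observation under the same constrained parameters, and indeed it does (a hedged check using the printed values):

extra = -dist.Normal(1.9340096, 2.4475517).log_prob(2.0)
print(6.228439 + extra)  # ~8.04283, matching potential_fn above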

Now, here's a version that uses get_potential_fn with subsampling:

import numpyro
import numpyro.distributions as dist
import jax
import jax.random

def model(c):
    a = numpyro.sample('a', dist.Normal(0.0, 1.0))
    b = numpyro.sample('b', dist.Gamma(1.0))
    with numpyro.plate('N', 2, subsample_size=1):
        numpyro.sample('c', dist.Normal(a, b), obs=c)

c = jax.numpy.array([1.0,2.0])

rng_key = jax.random.PRNGKey(1)
# with dynamic_args=True, the returned potential_fn and constrain_fn are
# generators: calling them on model arguments returns the actual function
potential_fn, constrain_fn = numpyro.infer.util.get_potential_fn(rng_key, model, dynamic_args=True, model_args=(c,))
# unconstrained initial values, copied from the first example for comparability
init_params = {'a': jax.numpy.array(1.9340096), 'b': jax.numpy.array(0.8950882)}

print('init_params:  ', init_params)
print('constrain_fn: ', constrain_fn(c[0])(init_params))
print('potential_fn: ', potential_fn(c[0])(init_params))
print('potential_fn: ', potential_fn(c[1])(init_params))
print('grad[0]:      ', jax.grad(lambda params : potential_fn(c[0])(params))(init_params))
print('grad[1]:      ', jax.grad(lambda params : potential_fn(c[1])(params))(init_params))

With the following results:

init_params:   {'a': DeviceArray(1.9340096, dtype=float32), 'b': DeviceArray(0.8950882, dtype=float32)}
constrain_fn:  {'a': DeviceArray(1.9340096, dtype=float32), 'b': DeviceArray(2.4475517, dtype=float32)}
potential_fn:  8.115278
potential_fn:  7.970379
grad[0]:       {'a': DeviceArray(2.2458394, dtype=float32), 'b': DeviceArray(3.1563, dtype=float32)}
grad[1]:       {'a': DeviceArray(1.9119779, dtype=float32), 'b': DeviceArray(3.4460979, dtype=float32)}

Since (8.115278 + 7.970379)/2 = 8.0428285, this matches the full-data potential of 8.0428295 up to floating-point error, which is exactly what we want. The gradients average out the same way: (2.2458394 + 1.9119779)/2 = 2.0789087 for a and (3.1563 + 3.4460979)/2 = 3.301199 for b, matching the full-data gradients.
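
Putting the pieces together, one could wrap this into a single stochastic potential that takes an rng_key, as originally requested (a minimal sketch; the uniform index draw is an assumption about how to subsample here):

def stochastic_potential(rng_key, params):
    # pick one of the two observations uniformly at random per call
    i = jax.random.randint(rng_key, (), 0, 2)
    return potential_fn(c[i])(params)

print(stochastic_potential(jax.random.PRNGKey(2), init_params))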

fehiepsi commented 4 years ago

@justindomke We have a forum where tips like these would be more accessible to users. Thanks for your clear report! We haven't focused much on SVI (and hence on subsampling) yet, so it is awesome that the MCMC utilities work out of the box for you. :D