pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Sample from distribution without storing #1695

Closed: danjenson closed this issue 1 month ago

danjenson commented 9 months ago

I am currently working on a project where we embed a VAE-decoder inside a model. Accordingly, we need to sample zs from a multivariate normal distribution, but we are not interested in the posterior of the zs. Here is an example model:

def model(y=None):
    var = numpyro.sample("variance", dist.HalfNormal())
    ls = numpyro.sample("lengthscale", dist.HalfNormal())
    z = numpyro.sample("z", dist.MultivariateNormal(jnp.zeros(2500), jnp.eye(2500)))  # <- want to sample but not store
    y_hat = numpyro.deterministic("y_hat", vae.decode(jnp.array([*z, ls, var])))
    sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))
    numpyro.sample("obs", dist.Normal(y_hat[mask], sigma), obs=y)

Currently, we are running inference on 50x50 grids with a z dimension of 2500 (one z per point in the grid), which means a standard model saves 2500 zs per step. We never use these zs and would like to prevent storing them to save memory and computation. We would greatly appreciate any advice!

fehiepsi commented 9 months ago

I don't think we store the latent values. Could you elaborate?

danjenson commented 8 months ago

Sorry for the late reply. I may be misunderstanding, but the problem is that we are sampling latent variables that are nuisance parameters, so we don't need estimates of their posteriors. Is numpyro.sample still the correct construct for latent nuisance parameters, or is there a lighter-weight sampling procedure, e.g. a pure JAX method, that might be more appropriate?

fehiepsi commented 8 months ago

Are you using MCMC? There is collect_fields to filter out variables that are not required. If you are using SVI, then we don't store latent variables during training.

renecotyfanboy commented 8 months ago

I have the exact same issue! @fehiepsi Could you elaborate on the use of collect_fields? I can't find relevant entries in the docs.

fehiepsi commented 8 months ago

See e.g. this comment https://forum.pyro.ai/t/reducing-mcmc-memory-usage/5639/4?u=fehiepsi

danjenson commented 8 months ago

I wasn't able to discern from that comment how to use collect_fields. It isn't an argument to NUTS(...), MCMC(...), or mcmc.run(..., collect_fields=...). Where and how do you add collect_fields, and is it just a list of variable names you want to keep? I'm using numpyro==0.13.2. Thank you!

renecotyfanboy commented 8 months ago

If my understanding is correct, the only way is to run the MCMC step by step and manually trace the parameters of interest
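
For reference, a minimal sketch of that step-by-step pattern, assuming the model from earlier in this thread and chunking via mcmc.post_warmup_state / mcmc.last_state (sites_to_keep is a placeholder for the variables of interest):

import jax
from numpyro.infer import MCMC, NUTS

sites_to_keep = ("lengthscale", "variance", "sigma")  # placeholder: the non-nuisance sites

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=100)
rng_key = jax.random.PRNGKey(0)
collected = []
for i in range(10):  # 10 chunks of 100 samples each
    mcmc.run(rng_key)
    samples = mcmc.get_samples()
    # keep only the sites of interest; this chunk's z samples are dropped here
    collected.append({k: v for k, v in samples.items() if k in sites_to_keep})
    # resume from the last state so later chunks skip warmup
    mcmc.post_warmup_state = mcmc.last_state
    rng_key = mcmc.post_warmup_state.rng_key

Each chunk still materializes its z samples briefly, but peak memory stays bounded by the chunk size.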

fehiepsi commented 7 months ago

Sorry, my brain was not working when I sent the previous comment. The argument name is extra_fields, not collect_fields. There is a property named default_fields which determines which fields of the sampler state are stored. I think we can enable an API to allow doing

mcmc = MCMC(NUTS(model))
mcmc.sampler.default_fields = ("z.foo", "z.bar")

following the changes in the forum comment (linked in my last comment).
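
For context, the extra_fields argument that exists today adds sampler-state fields to the collected output rather than removing sites; a quick sketch:

import jax
from numpyro.infer import MCMC, NUTS

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
# extra_fields collects additional attributes of the HMC state alongside the samples
mcmc.run(jax.random.PRNGKey(0), extra_fields=("potential_energy",))
pe = mcmc.get_extra_fields()["potential_energy"]  # one entry per sample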

danjenson commented 7 months ago

I'm still a bit lost on how I might do this for nested arrays. For instance, I am running the following model on an "image" of satellite data and I've got 60 subgrids of size 50x50. I sample a random vector z of 512 values for each subgrid for each sample. So, if I'm doing 1000 samples, that is 60 * 512 * 1000 values stored. However, I don't care about the posteriors of these values -- they are simply used to seed a generative model that I have inserted as a deterministic transformation (simulator.decode in the model). What is the best way to ignore the posteriors of the 60 * 512 z values?

    def satellite_model(T=None):
        sigma_T = numpyro.sample("sigma_T", dist.HalfNormal(10))
        for b in range(num_subgrids):
            z = numpyro.sample(f"z[{b}]", dist.Normal(0, 1).expand([z_dim]))
            ls = numpyro.sample(f"ls[{b}]", dist.Beta(3, 6))
            var = numpyro.sample(f"var[{b}]", dist.LogNormal(0, 1))
            c = jnp.hstack([ls, var])
            mu = simulator.apply(
                {"params": params}, z, c, method=simulator.decode
            ).squeeze()
            numpyro.sample(
                f"T[{b}]",
                dist.Normal(mu[non_nan_idx[b]], sigma_T),
                obs=T[b][non_nan_idx[b]],
            )

martinjankowiak commented 7 months ago

if you have a model with density p(x, y) and y is a "nuisance" variable in the sense that you don't care about its posterior but you still want to integrate out the uncertainty associated with its unknown value, then it's still required to do inference over y, since different y slices of p(x, y) lead to different conditional posteriors over x, so there's no way around doing inference on y.

of course to save memory you needn't actually save all the y samples.

there's also another possibility in which you're not actually trying to do "proper inference": maybe y is fixed once at the beginning, or sampled from a fixed distribution at each step of inference. but that's not doing proper inference over x in the presence of uncertainty over y.
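
a minimal sketch of the "fixed once at the beginning" variant, assuming a latent site named "y" in the model; numpyro.handlers.condition pins the site to a value so the kernel no longer treats it as a latent:

import jax
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# draw one fixed value for the nuisance site up front (not proper inference over y)
y0 = dist.Normal(0, 1).sample(jax.random.PRNGKey(1))
fixed_model = handlers.condition(model, data={"y": y0})
mcmc = MCMC(NUTS(fixed_model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0))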

danjenson commented 7 months ago

We do need to do inference over z, especially since we are using HMC and it will be calculating gradients over z in the latent space, but we would prefer not to save all these samples, to save memory. Is there a way to do this?

fehiepsi commented 7 months ago

If you don't want to change the source code, then you can do

import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    numpyro.sample("x", dist.Normal(0, 1))
    numpyro.sample("y", dist.Normal(0, 1))

class CustomNUTS(NUTS):
    def postprocess_fn(self, args, kwargs):
        transform = super().postprocess_fn(args, kwargs)
        def new_transform(z):
            z = transform(z)
            z.pop("x")  # drop the nuisance site from the returned samples
            return z
        return new_transform

mcmc = MCMC(CustomNUTS(model), num_warmup=10, num_samples=20)
mcmc.run(jax.random.PRNGKey(0))
mcmc.get_samples().keys()
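
A hedged generalization of the same idea, in case several sites need dropping or another kernel is in play (make_filtered_kernel and drop_sites are names introduced here, not numpyro API):

import jax
from numpyro.infer import MCMC, NUTS

def make_filtered_kernel(kernel_cls, drop_sites):
    # wrap a kernel class so the listed sites never reach the collected samples
    class FilteredKernel(kernel_cls):
        def postprocess_fn(self, args, kwargs):
            transform = super().postprocess_fn(args, kwargs)
            def new_transform(z):
                z = transform(z)
                return {k: v for k, v in z.items() if k not in drop_sites}
            return new_transform
    return FilteredKernel

mcmc = MCMC(make_filtered_kernel(NUTS, {"x"})(model), num_warmup=10, num_samples=20)
mcmc.run(jax.random.PRNGKey(0))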

But it's easy to support this feature. As outlined above, we can expose a setter for default_fields so that mcmc.sampler.default_fields = ("z.foo", "z.bar") controls which fields get collected.

Let's keep this issue open in case a contributor wants to support this feature. You can use the above CustomNUTS in the mean time.

amifalk commented 4 months ago

I find I often have a pattern where my random variable is a nuisance variable, but some deterministic function of it is meaningful. In this case, the desired behavior is a function of the model rather than of the inference algorithm, so it's inconvenient to have to tamper with inference settings for every fit.

I would much prefer to have a flag in the numpyro.sample function to toggle whether or not a site is collected during mcmc.

a = numpyro.sample('a_', dist.MultivariateNormal(jnp.zeros(2500), jnp.eye(2500)), collect=False)
a = numpyro.deterministic('a', a*2) 

What do you think @fehiepsi ?

fehiepsi commented 4 months ago

Yes, we can add a field to the "infer" keyword. But this requires us to update all MCMC kernels. I feel that supporting mcmc.sampler.default_fields = ("z.a",) is simpler. What do you think?
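
To make the two options concrete (numpyro.sample does accept an infer dict today, but the "collect" key and the default_fields setter are the proposals under discussion, not existing numpyro API), something like:

# option 1: mark the site in the model itself via the "infer" keyword
z = numpyro.sample("z", dist.Normal(0, 1).expand([512]), infer={"collect": False})

# option 2: set the kernel's default fields after construction
# ("z" here is the latent dict in the sampler state, as in the examples above)
mcmc = MCMC(NUTS(model))
mcmc.sampler.default_fields = ("z.a",)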

amifalk commented 4 months ago

It looks like all the samplers create a trace on initialization, most via initialize_model. It should be easy to add a function to infer.util that takes the trace and returns the default fields. Even though we would have to update each one, I don't think it would add much complexity. We would just need to add a setter method for default_fields in the MCMCKernel superclass and add one line to each kernel.

def _init_state(self, ...):
    model_trace, ... = numpyro.infer.util.initialize_model(...)
    self.default_fields = numpyro.infer.util.get_default_fields(model_trace)

Is this solution ok for you? If so, I would be happy to draft up a PR.
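
For what it's worth, a sketch of what such a helper could look like, assuming the usual trace-site layout and the "collect" infer key proposed above (the "z." prefix matches the state-field naming in fehiepsi's examples):

def get_default_fields(model_trace):
    # hypothetical helper: collect latent sample sites unless marked infer={"collect": False}
    return tuple(
        f"z.{name}"
        for name, site in model_trace.items()
        if site["type"] == "sample"
        and not site.get("is_observed", False)
        and (site.get("infer") or {}).get("collect", True)
    )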

fehiepsi commented 4 months ago

Yup, I think the solution looks good. Users can use either infer or default_fields.