NIFTy-PPL / NIFTy

Probabilistic programming framework for signal inference algorithms that operate regardless of the underlying grids and their resolutions
https://ift.pages.mpcdf.de/nifty/index.html
GNU General Public License v3.0

Create method for drawing synthetic data from a model #8

Status: Open. Edenhofer opened this issue 6 months ago.

Edenhofer commented 6 months ago

One feature that users repeatedly implement on their own is drawing synthetic data from a likelihood. We could simplify this and save users time by providing a method on the likelihood that draws synthetic data realizations from it.
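A minimal sketch of what such a method could do, written in plain JAX and assuming additive Gaussian noise with a known standard deviation. The names `toy_forward` and `draw_synthetic_data` are hypothetical and not part of the NIFTy API. A Poisson likelihood would instead draw counts, e.g. with `jax.random.poisson`.

```python
import jax
import jax.numpy as jnp


def toy_forward(xi):
    # Hypothetical forward model standing in for a real signal + response
    return jnp.exp(xi)


def draw_synthetic_data(key, forward, xi_shape, noise_std):
    """Draw one synthetic data realization, assuming additive Gaussian noise.

    Returns the latent parameters, the noiseless signal, and the noisy data.
    """
    key_xi, key_noise = jax.random.split(key)
    # Latent parameters are standard-normal ("white") a priori
    xi = jax.random.normal(key_xi, xi_shape)
    signal = forward(xi)
    # Gaussian noise with a known, homogeneous standard deviation
    noise = noise_std * jax.random.normal(key_noise, signal.shape)
    return xi, signal, signal + noise


key = jax.random.PRNGKey(42)
xi, signal, data = draw_synthetic_data(key, toy_forward, (128,), 0.1)
```

Returning the latent `xi` and the noiseless `signal` alongside the data makes it easy to check a reconstruction against the ground truth it was generated from.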

homerjed commented 5 months ago

Hi Gordian,

I'm not sure whether this is the same issue, but I have a sampler that maps the latents xi to signals. I ran this problem setup about a year ago with the old (non-JAX) version of NIFTy, and it ran at about the same speed without any GPUs.

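```python
import jax
import jax.numpy as jnp
import nifty8.re as jft  # NIFTy.re, the JAX-based part of NIFTy

jax.config.update("jax_enable_x64", True)  # float64 requires x64 mode in JAX
# sample_fn (maps the latent xi to the signal) and xi_shape are defined elsewhere
signal = jft.WrappedCall(
    sample_fn,
    name="xi",
    shape=xi_shape,
    dtype=jnp.float64,
    white_init=True,
)
```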

I also run out of GPU memory when I use too many samples in the KL optimization for MGVI: I can only use 1 sample per iteration, whereas before I was using ~32. On top of that, my xi variables are ~20 times lower-dimensional with jft (= nifty.re) than in my old setup.
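As general JAX background (whether NIFTy.re batches its samples this way internally is not stated in this thread), the sketch below contrasts two ways of evaluating a per-sample quantity over 32 samples. `jax.vmap` runs all samples as one batched computation, so peak memory typically grows with the number of samples, while `jax.lax.map` loops over them sequentially inside the compiled program. `per_sample_loss` is a hypothetical toy stand-in for a per-sample KL term.

```python
import jax
import jax.numpy as jnp


def per_sample_loss(xi):
    # Hypothetical stand-in for the contribution of one sample to the KL
    return jnp.sum(jnp.exp(xi) ** 2)


samples = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))

# vmap: one batched computation, intermediates for all 32 samples are
# materialized together, which is fast but memory-hungry
loss_vmap = jnp.mean(jax.vmap(per_sample_loss)(samples))

# lax.map: a sequential loop over the leading axis inside the compiled
# program, which is slower but keeps only one sample's intermediates live
loss_map = jnp.mean(jax.lax.map(per_sample_loss, samples))
```

Both give the same value up to floating-point rounding; the choice only trades runtime against peak memory.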

I think I saw that my sampler was being JIT-compiled repeatedly. I tried turning off `kl_jit` and `residual_jit`, but I still get the memory issues.
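One way to confirm repeated compilation, sketched here with a hypothetical toy function `toy_sample_fn` rather than the actual sampler: Python side effects inside a `jax.jit`-wrapped function run only while JAX traces it, so a counter shows how often the function is (re)compiled rather than how often it is called.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def toy_sample_fn(xi):
    # Runs only during tracing, so this counts compilations, not calls
    global trace_count
    trace_count += 1
    return jnp.exp(xi)


toy_sample_fn(jnp.zeros(8))   # first call: traced and compiled
toy_sample_fn(jnp.ones(8))    # same shape and dtype: cached executable reused
toy_sample_fn(jnp.zeros(16))  # new input shape: traced and compiled again
print(trace_count)  # 2
```

A common source of repeated compilation is creating a fresh `jax.jit`-wrapped function (for instance a new lambda or closure) on every iteration, since the compilation cache is tied to the function object. Alternatively, setting the `jax_log_compiles` configuration option makes JAX log every compilation.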

Is this a similar problem? Do you have any insight into what I may be doing wrong? Thanks!

P.S. This seems to be part of what goes wrong when I use `nonlinear_resample`:

> This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

Edenhofer commented 5 months ago

Thanks for your interest in NIFTy :)

No, this is not the same issue. Please file a new issue for the performance problem you are encountering, together with a minimal reproducing example. Based on the limited code you provided, it looks like JAX struggles with your code because it contains large constants. For further reference, the "Control flow" section of the JAX documentation page "Common Gotchas in JAX" might be helpful.
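To illustrate what "large constants" can mean in practice, here is a hedged sketch. An array that a traced function captures from its enclosing scope is typically embedded in the compiled program as a constant (which XLA may then try to constant-fold), whereas an array passed as an argument stays a runtime input. `big_array`, `model_closure` and `model_arg` are hypothetical names, and whether this is the actual cause of the slowdown above is not established in the thread.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical large precomputed array, e.g. a response or kernel matrix
big_array = np.random.default_rng(0).normal(size=(1000, 1000))


def model_closure(x):
    # big_array is captured from the enclosing scope, so tracing bakes it
    # into the program as a constant
    return jnp.dot(big_array, x)


def model_arg(x, mat):
    # The array is passed explicitly and remains an input of the program
    return jnp.dot(mat, x)


x = jnp.ones(1000)
jaxpr_const = jax.make_jaxpr(model_closure)(x)
jaxpr_arg = jax.make_jaxpr(model_arg)(x, big_array)
# Number of constants captured in each traced program
print(len(jaxpr_const.consts), len(jaxpr_arg.consts))
```

Both functions compute the same result; passing the array as an argument keeps the compiled program small and lets JAX reuse it for different array values.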