pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Sample an array of variable size #1526

Closed: Sarenrenegade closed this issue 1 year ago

Sarenrenegade commented 1 year ago

Hello! I have a naive question: is there a way to infer both the size of an array and the elements in it? My goal is to recover past earthquakes on a fault from the concentration of 36Cl. To do so, I have to infer the number of earthquakes and then assign an exhumation height to each one. There is also a constraint on the elements of this array: their sum cannot exceed the fault scarp height. Below is code with a much simpler forward function that mimics what I need; here the goal is to recover the ruptures (segment lengths and slopes) of a piecewise-linear signal:

# Base libraries
import numpy as np
import matplotlib.pyplot as plt

# Libraries needed for the inversion
from jax import random, lax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp

def f(x, a):
    return a*x

def g(x, breach, slope):
    y = jnp.zeros((len(x)))

    for i in range(0, len(breach)):
        x_u = jnp.arange(int(jnp.sum(breach[0:i])), int(breach[i] + jnp.sum(breach[0:i])), dtype=float)
        y = y.at[int(jnp.sum(breach[0:i])):int(breach[i] + jnp.sum(breach[0:i]))].set(f(x_u, slope[i]))

    return y

ruptures = jnp.array([20, 20, 40, 20])
slope = jnp.array([2, 5, 10, 20])
x_values=jnp.arange(0, 100, dtype=float)
y_values=g(x_values, ruptures, slope)

plt.plot(x_values, y_values, '.')
plt.xlabel('abscissa')
plt.ylabel('ordinate')
plt.show()

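def inverse_ruptures(obs):
    # NOTE: nb_ruptures is a continuous sample used as a count. NUTS traces the model,
    # so int(nb_ruptures) and the while condition below raise errors there, and NUTS
    # also needs the same set of sample sites on every evaluation of the model.
    nb_ruptures = numpyro.sample('nb_ruptures', dist.Uniform(0, 50))
    len_segment = jnp.zeros((int(nb_ruptures)))
    alpha_coeff = jnp.zeros((int(nb_ruptures)))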

    for i in range(0, int(nb_ruptures)):

        len_segment = len_segment.at[int(i)].set(int(numpyro.sample('len_segment' + str(i)+str(0), dist.Uniform(0, 50))))
        n = 1
        while jnp.sum(len_segment)>=len(x_values):
            len_segment = len_segment.at[int(i)].set(int(numpyro.sample('len_segment' + str(i)+str(n), dist.Uniform(0, 50))))
            n=n+1
        alpha_coeff = alpha_coeff.at[int(i)].set(numpyro.sample('alpha_coeff' + str(i), dist.Uniform(0, 50)))
    print(len_segment)
    y_fit_numpyro = g(x_values, len_segment, alpha_coeff)
    return numpyro.sample('obs', dist.Normal(y_fit_numpyro, 0.5), obs=obs)

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
kernel = NUTS(inverse_ruptures)
mcmc = MCMC(kernel, num_warmup=500, num_samples=5000)
mcmc.run(rng_key, obs=y_values)
mcmc.print_summary()

posterior_samples = mcmc.get_samples()  # results as a dictionary

Thanks for your help

fehiepsi commented 1 year ago

You can get the size of an array with np.size(array). Could you elaborate what you mean by inferring "the elements in this array"?
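A small sketch of that suggestion (not from the thread): `np.size` also works on JAX arrays, and inside `jax.jit` an array's shape is static even though its values are traced. This is why, in a traced model, the length of a sampled array cannot itself depend on a sampled value. The function name `n_elements` is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

a = jnp.zeros((3, 4))
print(np.size(a))   # 12: total number of elements
print(a.shape)      # (3, 4)

@jax.jit
def n_elements(arr):
    # Inside jit the values of `arr` are abstract tracers, but its shape is a
    # concrete Python tuple, so shape-based logic is allowed here.
    return arr.size

print(n_elements(a))  # 12
```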

Could you move the question to our forum? https://forum.pyro.ai/ We only use github issues to track issues and feature requests.

Sarenrenegade commented 1 year ago

Thanks for the quick answer. The goal is to "retrieve" both the size of this array and every element in it from a dataset. Basically, I have a function f(a, x) = y, where both x and a are arrays, and I want to find a, which means finding the length of this array and every value inside it. I hope that is a bit clearer.

Sarenrenegade commented 1 year ago

Hi! I found what I needed in the notebooks you wrote (https://github.com/fehiepsi/rethinking-numpyro). I will close this issue and leave the answer here in case anyone else wants to use numpyro as an "inversion library":

import numpyro
import numpyro.distributions as dist

with numpyro.handlers.seed(rng_seed=2971):
    nb_ruptures = numpyro.sample('nb_ruptures', dist.Uniform(1, 10))
    a = numpyro.sample("a", dist.Normal(178, 20).expand([int(nb_ruptures)]))
    print(a, a.shape)

Thanks again