dirmeier / sbijax

Simulation-based inference in JAX
https://sbijax.rtfd.io
Apache License 2.0

Error in Permutation when using make_maf #41

Open Jice-Zeng opened 2 months ago

Jice-Zeng commented 2 months ago

Hi Simon, when I ran the following code:

    n_dim_data = 2
    n_layers, hidden_sizes = 5, (64, 64)
    neural_network = make_maf(
        n_dim_data,
        n_layers=5,
        n_layer_dimensions=[2, 2, 2, 2, 2],
        hidden_sizes=hidden_sizes,
    )
    fns = prior_fn, simulator_fn
    model = NLE(fns, neural_network)
    obs = jnp.array([-1.0, 1.0])
    data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
    params, losses = model.fit(
        jr.PRNGKey(1), data=data
    )
    inference_results, diagnostics = model.sample_posterior(
        jr.PRNGKey(2), params, obs
    )

It raises an error:

    TypeError                                 Traceback (most recent call last)
    Cell In[108], line 3
          1 obs = jnp.array([-1.0, 1.0])
          2 data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
    ----> 3 params, losses = model.fit(
          4     jr.PRNGKey(1), data=data
          5 )
          6 inference_results, diagnostics = model.sample_posterior(
          7     jr.PRNGKey(2), params, obs
          8 )

    File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:87, in NLE.fit(self, rng_key, data, optimizer, n_iter, batch_size, percentage_data_as_validation_set, n_early_stopping_patience, **kwargs)
         83 itr_key, rng_key = jr.split(rng_key)
         84 train_iter, val_iter = self.as_iterators(
         85     itr_key, data, batch_size, percentage_data_as_validation_set
         86 )
    ---> 87 params, losses = self._fit_model_single_round(
         88     seed=rng_key,
         89     train_iter=train_iter,
         90     val_iter=val_iter,
         91     optimizer=optimizer,
         92     n_iter=n_iter,
         93     n_early_stopping_patience=n_early_stopping_patience,
         94 )
         96 return params, losses

    File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:109, in NLE._fit_model_single_round(self, seed, train_iter, val_iter, optimizer, n_iter, n_early_stopping_patience)
         99 def _fit_model_single_round(
        100     self,
        101     seed,
       (...)
        106     n_early_stopping_patience,
        107 ):
        108     init_key, seed = jr.split(seed)
    --> 109     params = self._init_params(init_key, next(iter(train_iter)))
        110     state = optimizer.init(params)
        112     @jax.jit
        113     def step(params, state, batch):

    File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:176, in NLE._init_params(self, rng_key, init_data)
        175 def _init_params(self, rng_key, init_data):
    --> 176     params = self.model.init(
        177         rng_key, method="log_prob", y=init_data["y"], x=init_data["theta"]
        178     )
        179     return params

    File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:166, in without_state.<locals>.init_fn(*args, **kwargs)
        165 def init_fn(*args, **kwargs) -> hk.MutableParams:
    --> 166   params, state = f.init(*args, **kwargs)
        167   if state:
        168     raise base.NonEmptyStateError(
        169         "If your transformed function uses hk.{get,set}_state then use "
        170         "hk.transform_with_state.")

    File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:422, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
        420 with base.new_context(rng=rng) as ctx:
        421   try:
    --> 422     f(*args, **kwargs)
        423   except jax.errors.UnexpectedTracerError as e:
        424     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

    File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nn/make_flow.py:137, in _make_maf.<locals>._flow(method, **kwargs)
        132     raise ValueError(
        133         f"n_dimension at layer {i} is layer than the dimension of"
        134         f" the following layer {i + 1}"
        135     )
        136 layers.append(layer)
    --> 137 layers.append(Permutation(order, 1))
        138 chain = Chain(layers[:-1])
        140 base_distribution = distrax.Independent(
        141     distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
        142     1,
        143 )

    File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/distrax/_src/utils/jittable.py:32, in Jittable.__new__(failed resolving arguments)
         30 except ValueError:
         31   registered_cls = cls  # Already registered.
    ---> 32 return object.__new__(registered_cls)

TypeError: Can't instantiate abstract class Permutation without an implementation for abstract method 'forward_and_log_det'.
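For context, this TypeError is Python's generic response to instantiating a class that leaves an abstract method unimplemented. Below is a minimal sketch using only the standard library. The class names are hypothetical stand-ins, not the actual distrax or surjectors classes:

```python
from abc import ABC, abstractmethod


class Bijector(ABC):
    """Hypothetical stand-in for an abstract base such as distrax.Bijector."""

    @abstractmethod
    def forward_and_log_det(self, x):
        """Return the transformed value and the log-determinant."""


class BrokenPermutation(Bijector):
    """Defines only a differently named helper, so the class stays abstract."""

    def _forward_and_likelihood_contribution(self, z):
        return z, 0.0


class FixedPermutation(BrokenPermutation):
    """Implements the abstract method by delegating to the helper."""

    def forward_and_log_det(self, x):
        return self._forward_and_likelihood_contribution(x)


# Instantiating the class with the missing method fails at construction time.
try:
    BrokenPermutation()
    error_message = None
except TypeError as err:
    error_message = str(err)

print(error_message)

# Once the abstract method is defined, construction succeeds.
result = FixedPermutation().forward_and_log_det([1, 2, 3])
```

The exact wording of the message differs between Python versions, but it names the missing method in each case.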

I guess the error comes from the Permutation class in the surjectors library. When I checked Permutation, it does not implement the method 'forward_and_log_det'; it only defines '_forward_and_likelihood_contribution' and '_inverse_and_likelihood_contribution'. So I added two methods, 'forward_and_log_det' and 'inverse_and_log_det', to Permutation as shown below:

    import distrax
    from jax import numpy as jnp


    class Permutation(distrax.Bijector):
        """Permute the dimensions of a vector.

        Args:
            permutation: a vector of integer indexes representing the order of
                the elements
            event_ndims_in: number of input event dimensions

        Examples:
            >>> from surjectors import Permutation
            >>> from jax import numpy as jnp
            >>>
            >>> order = jnp.arange(10)
            >>> perm = Permutation(order, 1)
        """

        def __init__(self, permutation, event_ndims_in: int):
            super().__init__(event_ndims_in)
            self.permutation = permutation

        def forward_and_log_det(self, x):
            # Forward transformation and log determinant calculation
            z, log_det = self._forward_and_likelihood_contribution(x)
            return z, log_det

        def inverse_and_log_det(self, y):
            # Inverse transformation and log determinant calculation
            z, log_det = self._inverse_and_likelihood_contribution(y)
            return z, log_det

        def _forward_and_likelihood_contribution(self, z):
            # A permutation only reorders elements, so its contribution is zero
            return z[..., self.permutation], jnp.full(jnp.shape(z)[:-1], 0.0)

        def _inverse_and_likelihood_contribution(self, y):
            # Build the inverse index by scattering positions 0..size-1
            size = self.permutation.size
            permutation_inv = (
                jnp.zeros(size, dtype=jnp.result_type(int))
                .at[self.permutation]
                .set(jnp.arange(size))
            )
            return y[..., permutation_inv], jnp.full(jnp.shape(y)[:-1], 0.0)

With the modified Permutation class, make_maf now works. Could you please check whether my implementation is correct? Thanks!
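One way to sanity-check the logic is to confirm that the inverse index built by scattering undoes the forward permutation. A permutation matrix has determinant ±1, so the log absolute determinant should be 0 in both directions. Below is a minimal sketch with plain Python lists in place of jnp arrays; the helper names are hypothetical and not part of surjectors or distrax:

```python
def forward(z, permutation):
    """Reorder z so that output[i] = z[permutation[i]]; log|det J| is 0."""
    return [z[i] for i in permutation], 0.0


def invert_permutation(permutation):
    """Scatter positions so that inverse[permutation[i]] = i."""
    inverse = [0] * len(permutation)
    for position, index in enumerate(permutation):
        inverse[index] = position
    return inverse


def inverse(y, permutation):
    """Undo forward() by indexing with the inverse permutation."""
    return [y[i] for i in invert_permutation(permutation)], 0.0


permutation = [2, 0, 3, 1]
z = [10.0, 20.0, 30.0, 40.0]

y, forward_log_det = forward(z, permutation)
z_back, inverse_log_det = inverse(y, permutation)

print(y, z_back, forward_log_det, inverse_log_det)
```

If the round trip recovers z and both contributions are zero, the permutation logic is at least internally consistent. Whether the forward and inverse directions match what distrax expects is a separate question for the maintainer.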

dirmeier commented 2 months ago

Thanks for reporting. I'll fix this as soon as my time allows. As I was saying in the other thread, we did a major refactor in which I likely introduced bugs.