File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:166, in without_state.<locals>.init_fn(*args, **kwargs)
    165 def init_fn(*args, **kwargs) -> hk.MutableParams:
--> 166   params, state = f.init(*args, **kwargs)
    167   if state:
    168     raise base.NonEmptyStateError(
    169         "If your transformed function uses hk.{get,set}_state then use "
    170         "hk.transform_with_state.")

File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:422, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
    420 with base.new_context(rng=rng) as ctx:
    421   try:
--> 422     f(*args, **kwargs)
    423   except jax.errors.UnexpectedTracerError as e:
    424     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nn/make_flow.py:137, in _make_maf.<locals>._flow(method, **kwargs)
    132     raise ValueError(
    133         f"n_dimension at layer {i} is layer than the dimension of"
    134         f" the following layer {i + 1}"
    135     )
    136 layers.append(layer)
--> 137 layers.append(Permutation(order, 1))
    138 chain = Chain(layers[:-1])
    140 base_distribution = distrax.Independent(
    141     distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
    142     1,
    143 )

TypeError: Can't instantiate abstract class Permutation without an implementation for abstract method 'forward_and_log_det'.
I guess the error comes from the Permutation class in the surjectors library. When I checked Permutation, it does not define the method 'forward_and_log_det'; it only has '_forward_and_likelihood_contribution' and '_inverse_and_likelihood_contribution'. So I added two methods, 'forward_and_log_det' and 'inverse_and_log_det', to Permutation as shown below:
class Permutation(distrax.Bijector):
    """Permute the dimensions of a vector.

    Args:
        permutation: a vector of integer indexes representing the order of
            the elements
        event_ndims_in: number of input event dimensions

    Examples:
        >>> from surjectors import Permutation
        >>> from jax import numpy as jnp
        >>>
        >>> order = jnp.arange(10)
        >>> perm = Permutation(order, 1)
    """

    def __init__(self, permutation, event_ndims_in: int):
        super().__init__(event_ndims_in)
        self.permutation = permutation

    def forward_and_log_det(self, x):
        # Forward transformation and its log-determinant
        z, log_det = self._forward_and_likelihood_contribution(x)
        return z, log_det

    def inverse_and_log_det(self, y):
        # Inverse transformation and its log-determinant
        z, log_det = self._inverse_and_likelihood_contribution(y)
        return z, log_det

    def _forward_and_likelihood_contribution(self, z):
        # Permuting dimensions is volume-preserving, so the log-determinant is zero
        return z[..., self.permutation], jnp.full(jnp.shape(z)[:-1], 0.0)

    def _inverse_and_likelihood_contribution(self, y):
        # Build the inverse permutation by scattering arange(size) back
        size = self.permutation.size
        permutation_inv = (
            jnp.zeros(size, dtype=jnp.result_type(int))
            .at[self.permutation]
            .set(jnp.arange(size))
        )
        return y[..., permutation_inv], jnp.full(jnp.shape(y)[:-1], 0.0)
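For what it's worth, here is a small, hypothetical round-trip check (it only assumes the class above plus jax.numpy, nothing else from surjectors) that the patched class should pass: forward followed by inverse recovers the input, and both log-determinants are zero because a permutation is volume-preserving.

import jax.numpy as jnp

# Hypothetical sanity check for the patched Permutation class above.
order = jnp.array([2, 0, 1])
perm = Permutation(order, 1)

x = jnp.array([[10.0, 20.0, 30.0]])           # batch of one 3-d vector
y, fwd_ld = perm.forward_and_log_det(x)       # y == [[30., 10., 20.]]
x_back, inv_ld = perm.inverse_and_log_det(y)  # should recover x exactly

assert jnp.allclose(x_back, x)
assert jnp.allclose(fwd_ld, 0.0) and jnp.allclose(inv_ld, 0.0)  # zero log-det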
With these two methods added to Permutation, make_maf now works for me. Could you please check whether my implementation is right or wrong?
Thanks!
Thanks for reporting. I'll fix this as soon as my time allows. As I was saying in the other thread, we did a major refactor where I likely introduced bugs.
Hi Simon, when I ran the following code:

n_dim_data = 2
n_layers, hidden_sizes = 5, (64, 64)
neural_network = make_maf(
    n_dim_data,
    n_layers=5,
    n_layer_dimensions=[2, 2, 2, 2, 2],
    hidden_sizes=hidden_sizes,
)
fns = prior_fn, simulator_fn
model = NLE(fns, neural_network)

obs = jnp.array([-1.0, 1.0])
data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
params, losses = model.fit(jr.PRNGKey(1), data=data)
inference_results, diagnostics = model.sample_posterior(jr.PRNGKey(2), params, obs)
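(prior_fn and simulator_fn are defined elsewhere in my notebook. Purely for context, hypothetical toy definitions could look roughly like the sketch below; this is not my exact code, and the exact signatures sbijax expects may differ.)

import jax.numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd

def prior_fn():
    # Illustrative only: a standard-normal prior over a 2-d parameter "theta"
    return tfd.JointDistributionNamed(
        dict(theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))), batch_ndims=0
    )

def simulator_fn(seed, theta):
    # Illustrative only: observations are theta plus Gaussian noise
    return tfd.Normal(theta["theta"], 0.1).sample(seed=seed)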
it raises the following error:

TypeError                                 Traceback (most recent call last)
Cell In[108], line 3
      1 obs = jnp.array([-1.0, 1.0])
      2 data, _ = model.simulate_data(jr.PRNGKey(0), n_simulations=10_000)
----> 3 params, losses = model.fit(
      4     jr.PRNGKey(1), data=data
      5 )
      6 inference_results, diagnostics = model.sample_posterior(
      7     jr.PRNGKey(2), params, obs
      8 )
File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:87, in NLE.fit(self, rng_key, data, optimizer, n_iter, batch_size, percentage_data_as_validation_set, n_early_stopping_patience, **kwargs)
     83 itr_key, rng_key = jr.split(rng_key)
     84 train_iter, val_iter = self.as_iterators(
     85     itr_key, data, batch_size, percentage_data_as_validation_set
     86 )
---> 87 params, losses = self._fit_model_single_round(
     88     seed=rng_key,
     89     train_iter=train_iter,
     90     val_iter=val_iter,
     91     optimizer=optimizer,
     92     n_iter=n_iter,
     93     n_early_stopping_patience=n_early_stopping_patience,
     94 )
     96 return params, losses

File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:109, in NLE._fit_model_single_round(self, seed, train_iter, val_iter, optimizer, n_iter, n_early_stopping_patience)
     99 def _fit_model_single_round(
    100     self,
    101     seed,
   (...)
    106     n_early_stopping_patience,
    107 ):
    108     init_key, seed = jr.split(seed)
--> 109     params = self._init_params(init_key, next(iter(train_iter)))
    110     state = optimizer.init(params)
    112     @jax.jit
    113     def step(params, state, batch):

File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nle.py:176, in NLE._init_params(self, rng_key, init_data)
    175 def _init_params(self, rng_key, init_data):
--> 176     params = self.model.init(
    177         rng_key, method="log_prob", y=init_data["y"], x=init_data["theta"]
    178     )
    179     return params

File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:166, in without_state.<locals>.init_fn(*args, **kwargs)
    165 def init_fn(*args, **kwargs) -> hk.MutableParams:
--> 166   params, state = f.init(*args, **kwargs)
    167   if state:
    168     raise base.NonEmptyStateError(
    169         "If your transformed function uses hk.{get,set}_state then use "
    170         "hk.transform_with_state.")

File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/haiku/_src/transform.py:422, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
    420 with base.new_context(rng=rng) as ctx:
    421   try:
--> 422     f(*args, **kwargs)
    423   except jax.errors.UnexpectedTracerError as e:
    424     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

File ~/Library/CloudStorage/OneDrive-PNNL/Desktop/projects/Surjective/sbijax/_src/nn/make_flow.py:137, in _make_maf.<locals>._flow(method, **kwargs)
    132     raise ValueError(
    133         f"n_dimension at layer {i} is layer than the dimension of"
    134         f" the following layer {i + 1}"
    135     )
    136 layers.append(layer)
--> 137 layers.append(Permutation(order, 1))
    138 chain = Chain(layers[:-1])
    140 base_distribution = distrax.Independent(
    141     distrax.Normal(jnp.zeros(n_dimension), jnp.ones(n_dimension)),
    142     1,
    143 )

File ~/anaconda3/envs/surjection/lib/python3.12/site-packages/distrax/_src/utils/jittable.py:32, in Jittable.__new__(failed resolving arguments)
     30 except ValueError:
     31     registered_cls = cls  # Already registered.
---> 32 return object.__new__(registered_cls)

TypeError: Can't instantiate abstract class Permutation without an implementation for abstract method 'forward_and_log_det'.