Closed EdwardRaff closed 3 years ago
I realized why this doesn't work and is dumb on my part only after submitting the report... However, maybe it would help to add an assert to the constructor to warn about the max range of the mean parameter to help avoid dumb people like me.
I think we can either swap the lines
75 self._dirichlet = Dirichlet(
---> 76 jnp.stack([concentration1, concentration0], axis=-1)
77 )
78 super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
so that the validation for Beta is triggered before the Dirichlet. That way, you can be informed that the Beta distribution got invalid parameters. What do you think? If it is reasonable, do you want to submit the enhancement?
I'm still working on creating a minimum viable code to reproduce
I think this reproduces the issue:
import numpyro
import jax
import numpyro.distributions as dist
def model():
numpyro.sample('foo', dist.BetaProportion(1., 1.))
sampler = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(sampler, num_chains=1, num_samples=1, num_warmup=0)
mcmc.run(jax.random.PRNGKey(1))
I think we can either swap the lines [...] so that the validation for Beta is triggered before the Dirichlet.
i.e.,
diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 7dfae50..343fbb3 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -72,10 +72,10 @@ class Beta(Distribution):
)
concentration1 = jnp.broadcast_to(concentration1, batch_shape)
concentration0 = jnp.broadcast_to(concentration0, batch_shape)
+ super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
self._dirichlet = Dirichlet(
jnp.stack([concentration1, concentration0], axis=-1)
)
- super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
def sample(self, key, sample_shape=()):
assert is_prng_key(key)
?
I tried that, but still get the same error from the snippet above:
ValueError Traceback (most recent call last)
<ipython-input-1-1260359545ae> in <module>
8 sampler = numpyro.infer.NUTS(model)
9 mcmc = numpyro.infer.MCMC(sampler, num_chains=1, num_samples=1, num_warmup=0)
---> 10 mcmc.run(jax.random.PRNGKey(1))
~/numpyro-dev/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
596 map_args = (rng_key, init_state, init_params)
597 if self.num_chains == 1:
--> 598 states_flat, last_state = partial_map_fn(map_args)
599 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
600 else:
~/numpyro-dev/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
384 rng_key, init_state, init_params = init
385 if init_state is None:
--> 386 init_state = self.sampler.init(
387 rng_key,
388 self.num_warmup,
~/numpyro-dev/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
694 vmap(random.split)(rng_key), 0, 1
695 )
--> 696 init_params = self._init_state(
697 rng_key_init_model, model_args, model_kwargs, init_params
698 )
~/numpyro-dev/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
640 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
641 if self._model is not None:
--> 642 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
643 rng_key,
644 self._model,
~/numpyro-dev/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
654 with numpyro.validation_enabled(), trace() as tr:
655 # validate parameters
--> 656 substituted_model(*model_args, **model_kwargs)
657 # validate values
658 for site in tr.values():
~/numpyro-dev/numpyro/primitives.py in __call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
89
~/numpyro-dev/numpyro/primitives.py in __call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
89
<ipython-input-1-1260359545ae> in model()
4
5 def model():
----> 6 numpyro.sample('foo', dist.BetaProportion(1., 1.))
7
8 sampler = numpyro.infer.NUTS(model)
~/numpyro-dev/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
92 if result is not None:
93 return result
---> 94 return super().__call__(*args, **kwargs)
95
96 @property
~/numpyro-dev/numpyro/distributions/continuous.py in __init__(self, mean, concentration, validate_args)
1574 concentration, lax.broadcast_shapes(jnp.shape(concentration))
1575 )
-> 1576 super().__init__(
1577 mean * concentration,
1578 (1.0 - mean) * concentration,
~/numpyro-dev/numpyro/distributions/continuous.py in __init__(self, concentration1, concentration0, validate_args)
74 concentration0 = jnp.broadcast_to(concentration0, batch_shape)
75 super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
---> 76 self._dirichlet = Dirichlet(
77 jnp.stack([concentration1, concentration0], axis=-1)
78 )
~/numpyro-dev/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
92 if result is not None:
93 return result
---> 94 return super().__call__(*args, **kwargs)
95
96 @property
~/numpyro-dev/numpyro/distributions/continuous.py in __init__(self, concentration, validate_args)
154 self.concentration = concentration
155 batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
--> 156 super(Dirichlet, self).__init__(
157 batch_shape=batch_shape,
158 event_shape=event_shape,
~/numpyro-dev/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
174 if not_jax_tracer(is_valid):
175 if not np.all(is_valid):
--> 176 raise ValueError(
177 "{} distribution got invalid {} parameter.".format(
178 self.__class__.__name__, param
ValueError: Dirichlet distribution got invalid concentration parameter.
I'll have another look later
Sorry, you are right. ~How about setting validate_args to False in the construction of _dirichlet?~ I think we need self.mean = mean in the construction of BetaProportion, so that it can be validated.
OK, got it - will open a PR soon
I'm still working on creating a minimal reproducible example that doesn't require sharing my data, but when using the BetaProportion distribution I've obtained the below error: