pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0
2.32k stars 246 forks source link

Bug: Dirichlet distribution got invalid concentration parameter for BetaProportion #1206

Closed EdwardRaff closed 3 years ago

EdwardRaff commented 3 years ago

I'm still working on creating a minimum viable code to reproduce that doesn't require sharing my data, but when using the BetaProportion distribution I've obtained the below error:

ValueError                                Traceback (most recent call last)
<ipython-input-51-4f5c4ea19e62> in <module>()
     22     progress_bar=True,
     23   )
---> 24   mcmc.run(rng_key, 7, reproduced, year_published, code_available, theory, empirical, tfpg, types, topic, citation_count)
     25 

/usr/local/lib/python3.7/dist-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    570         map_args = (rng_key, init_state, init_params)
    571         if self.num_chains == 1:
--> 572             states_flat, last_state = partial_map_fn(map_args)
    573             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    574         else:

/usr/local/lib/python3.7/dist-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    363                 init_params,
    364                 model_args=args,
--> 365                 model_kwargs=kwargs,
    366             )
    367         sample_fn, postprocess_fn = self._get_cached_fns()

/usr/local/lib/python3.7/dist-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    695             )
    696         init_params = self._init_state(
--> 697             rng_key_init_model, model_args, model_kwargs, init_params
    698         )
    699         if self._potential_fn and init_params is None:

/usr/local/lib/python3.7/dist-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    647                 model_args=model_args,
    648                 model_kwargs=model_kwargs,
--> 649                 forward_mode_differentiation=self._forward_mode_differentiation,
    650             )
    651             if self._init_fn is None:

/usr/local/lib/python3.7/dist-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    639             with numpyro.validation_enabled(), trace() as tr:
    640                 # validate parameters
--> 641                 substituted_model(*model_args, **model_kwargs)
    642                 # validate values
    643                 for site in tr.values():

/usr/local/lib/python3.7/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     85             return self
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 
     89 

/usr/local/lib/python3.7/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     85             return self
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 
     89 

<ipython-input-49-e65da1086e6f> in model(max_clusters, reproduced, year_published, code_available, theory, empirical, tfpg, types, topic, citation_count)
     19   with numpyro.plate('All Years', max_year-min_year):
     20     with numpyro.plate('Topic Area', n_topics):# shape: (n_topics, max_year-min_year)
---> 21       topic_year_m = sample("Topic-Year Influence", dist.BetaProportion(mean=1, concentration=1))
     22 
     23   print(topic_year_m.shape)

/usr/local/lib/python3.7/dist-packages/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
     92             if result is not None:
     93                 return result
---> 94         return super().__call__(*args, **kwargs)
     95 
     96     @property

/usr/local/lib/python3.7/dist-packages/numpyro/distributions/continuous.py in __init__(self, mean, concentration, validate_args)
   1579             mean * concentration,
   1580             (1.0 - mean) * concentration,
-> 1581             validate_args=validate_args,
   1582         )

/usr/local/lib/python3.7/dist-packages/numpyro/distributions/continuous.py in __init__(self, concentration1, concentration0, validate_args)
     74         concentration0 = jnp.broadcast_to(concentration0, batch_shape)
     75         self._dirichlet = Dirichlet(
---> 76             jnp.stack([concentration1, concentration0], axis=-1)
     77         )
     78         super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

/usr/local/lib/python3.7/dist-packages/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
     92             if result is not None:
     93                 return result
---> 94         return super().__call__(*args, **kwargs)
     95 
     96     @property

/usr/local/lib/python3.7/dist-packages/numpyro/distributions/continuous.py in __init__(self, concentration, validate_args)
    157             batch_shape=batch_shape,
    158             event_shape=event_shape,
--> 159             validate_args=validate_args,
    160         )
    161 

/usr/local/lib/python3.7/dist-packages/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
    176                         raise ValueError(
    177                             "{} distribution got invalid {} parameter.".format(
--> 178                                 self.__class__.__name__, param
    179                             )
    180                         )

ValueError: Dirichlet distribution got invalid concentration parameter.
EdwardRaff commented 3 years ago

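I realized why this doesn't work only after submitting the report, and it was a mistake on my part: with mean=1, the second concentration passed to Beta, (1 - mean) * concentration, is 0, which is not a valid concentration. However, it might help to add an assert to the constructor that warns about the valid range of the mean parameter, to help others avoid the same mistake.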

fehiepsi commented 3 years ago

I think we can swap the lines

     75         self._dirichlet = Dirichlet(
---> 76             jnp.stack([concentration1, concentration0], axis=-1)
     77         )
     78         super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

so that the validation for Beta is triggered before the Dirichlet. That way, you can be informed that Beta distribution got invalid parameters. What do you think? If it is reasonable, do you want to submit the enhancement?

MarcoGorelli commented 3 years ago

I'm still working on creating a minimum viable code to reproduce

I think this reproduces the issue:

import numpyro
import jax
import numpyro.distributions as dist

def model():
    numpyro.sample('foo', dist.BetaProportion(1., 1.))

sampler = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(sampler, num_chains=1, num_samples=1, num_warmup=0)
mcmc.run(jax.random.PRNGKey(1))
MarcoGorelli commented 3 years ago

I think we can swap the lines [...] so that the validation for Beta is triggered before the Dirichlet.

i.e.,

diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 7dfae50..343fbb3 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -72,10 +72,10 @@ class Beta(Distribution):
         )
         concentration1 = jnp.broadcast_to(concentration1, batch_shape)
         concentration0 = jnp.broadcast_to(concentration0, batch_shape)
+        super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
         self._dirichlet = Dirichlet(
             jnp.stack([concentration1, concentration0], axis=-1)
         )
-        super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

     def sample(self, key, sample_shape=()):
         assert is_prng_key(key)

?

I tried that, but still get the same error from the snippet above:

ValueError                                Traceback (most recent call last)
<ipython-input-1-1260359545ae> in <module>
      8 sampler = numpyro.infer.NUTS(model)
      9 mcmc = numpyro.infer.MCMC(sampler, num_chains=1, num_samples=1, num_warmup=0)
---> 10 mcmc.run(jax.random.PRNGKey(1))

~/numpyro-dev/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    596         map_args = (rng_key, init_state, init_params)
    597         if self.num_chains == 1:
--> 598             states_flat, last_state = partial_map_fn(map_args)
    599             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    600         else:

~/numpyro-dev/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    384         rng_key, init_state, init_params = init
    385         if init_state is None:
--> 386             init_state = self.sampler.init(
    387                 rng_key,
    388                 self.num_warmup,

~/numpyro-dev/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    694                 vmap(random.split)(rng_key), 0, 1
    695             )
--> 696         init_params = self._init_state(
    697             rng_key_init_model, model_args, model_kwargs, init_params
    698         )

~/numpyro-dev/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    640     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    641         if self._model is not None:
--> 642             init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    643                 rng_key,
    644                 self._model,

~/numpyro-dev/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    654             with numpyro.validation_enabled(), trace() as tr:
    655                 # validate parameters
--> 656                 substituted_model(*model_args, **model_kwargs)
    657                 # validate values
    658                 for site in tr.values():

~/numpyro-dev/numpyro/primitives.py in __call__(self, *args, **kwargs)
     85             return self
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 
     89 

~/numpyro-dev/numpyro/primitives.py in __call__(self, *args, **kwargs)
     85             return self
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 
     89 

<ipython-input-1-1260359545ae> in model()
      4 
      5 def model():
----> 6     numpyro.sample('foo', dist.BetaProportion(1., 1.))
      7 
      8 sampler = numpyro.infer.NUTS(model)

~/numpyro-dev/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
     92             if result is not None:
     93                 return result
---> 94         return super().__call__(*args, **kwargs)
     95 
     96     @property

~/numpyro-dev/numpyro/distributions/continuous.py in __init__(self, mean, concentration, validate_args)
   1574             concentration, lax.broadcast_shapes(jnp.shape(concentration))
   1575         )
-> 1576         super().__init__(
   1577             mean * concentration,
   1578             (1.0 - mean) * concentration,

~/numpyro-dev/numpyro/distributions/continuous.py in __init__(self, concentration1, concentration0, validate_args)
     74         concentration0 = jnp.broadcast_to(concentration0, batch_shape)
     75         super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
---> 76         self._dirichlet = Dirichlet(
     77             jnp.stack([concentration1, concentration0], axis=-1)
     78         )

~/numpyro-dev/numpyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
     92             if result is not None:
     93                 return result
---> 94         return super().__call__(*args, **kwargs)
     95 
     96     @property

~/numpyro-dev/numpyro/distributions/continuous.py in __init__(self, concentration, validate_args)
    154         self.concentration = concentration
    155         batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
--> 156         super(Dirichlet, self).__init__(
    157             batch_shape=batch_shape,
    158             event_shape=event_shape,

~/numpyro-dev/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
    174                 if not_jax_tracer(is_valid):
    175                     if not np.all(is_valid):
--> 176                         raise ValueError(
    177                             "{} distribution got invalid {} parameter.".format(
    178                                 self.__class__.__name__, param

ValueError: Dirichlet distribution got invalid concentration parameter.

I'll have another look later

fehiepsi commented 3 years ago

Sorry, you are right. ~How about setting validate_args to False in the construction of _dirichlet?~ I think we need to set self.mean = mean in the constructor of BetaProportion, so that it can be validated.

MarcoGorelli commented 3 years ago

OK, got it - will open a PR soon