pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

FR ContinuousBinomial, ContinuousBetaBinomial distributions #2407

Closed. fritzo closed this issue 4 years ago.

fritzo commented 4 years ago

This issue proposes two distributions that could enable HMC and reparameterized SVI inference for models with count-valued latents, such as epidemiological compartment models (see design doc).

Models such as SIR and SEIR typically use Binomial or overdispersed BetaBinomial distributions for transitions (or approximate these as Censored(Poisson) or overdispersed Censored(NegativeBinomial)), together with Binomial likelihoods (e.g. see this prototype). In these models the S, E, I variables are integers. This issue proposes to replace the integer total_count with a positive real number, interpreted as a two-component mixture over the floor and ceiling of total_count. The semantics of the transition and likelihood distributions differ: a transition maps real->real, while a likelihood maps real->integer.
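For concreteness, here is a minimal standalone sketch of the proposed mixture semantics, using only torch.distributions and illustrative numbers: a real total_count of 3.7 denotes a mixture of Binomial(3, p) with weight 0.3 and Binomial(4, p) with weight 0.7, and the resulting pmf still normalizes.

```python
import torch
import torch.distributions as dist

# Illustrative numbers: total_count = 3.7 denotes a 0.3/0.7 mixture
# over Binomial(3, p) and Binomial(4, p).
total_count, probs = 3.7, torch.tensor(0.5)
lb, ub = 3.0, 4.0                                # floor and ceiling
w_lb, w_ub = ub - total_count, total_count - lb  # mixture weights 0.3, 0.7

values = torch.arange(5.0)
lp_lb = dist.Binomial(lb, probs).log_prob(values.clamp(max=lb))
lp_lb = lp_lb.masked_fill(values > lb, float("-inf"))  # 4 is impossible under floor
lp_ub = dist.Binomial(ub, probs).log_prob(values)

pmf = w_lb * lp_lb.exp() + w_ub * lp_ub.exp()
print(pmf.sum())  # ~1.0: the mixture normalizes over {0, ..., 4}
```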

ContinuousBinomial likelihood

This is a simple two-component mixture model. Here's a sketch:

```python
class ContinuousBinomial(TorchDistribution):
    arg_constraints = {"total_count": constraints.positive,
                       "probs": constraints.unit_interval}
    support = constraints.nonnegative_integer  # Note lack of upper bound.
    ...

    def sample(self, sample_shape=torch.Size()):
        lb = self.total_count.floor()
        # Round total_count up with probability equal to its fractional part.
        bern = (self.total_count - lb).expand(sample_shape + lb.shape).bernoulli()
        total_count = lb + bern
        return dist.Binomial(total_count, self.probs).sample()

    def log_prob(self, value):
        # The first part is a two-component mixture over floor and ceiling counts.
        lb = self.total_count.floor()
        ub = lb + 1
        total_count = torch.stack([lb, ub], dim=-1)
        logits = torch.stack([ub - self.total_count,
                              self.total_count - lb], dim=-1).log()
        log_prob = dist.Binomial(total_count, self.probs.unsqueeze(-1),
                                 validate_args=False).log_prob(value.unsqueeze(-1))
        # Unbounded support is required by actual applications.
        log_prob = log_prob.masked_fill(value.unsqueeze(-1) > total_count,
                                        float("-inf"))
        return (logits + log_prob).logsumexp(dim=-1)  # mixture
```
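To illustrate the real->integer likelihood direction, here is a hypothetical usage sketch. The model, site names, prior, and probs value are made up for illustration, and it assumes the ContinuousBinomial sketch above is in scope.

```python
import pyro
import pyro.distributions as dist

def model(new_infections):
    # Hypothetical real-valued latent flow (e.g. a relaxed S->I transition count).
    S2I = pyro.sample("S2I", dist.LogNormal(3.0, 1.0))
    # Integer-valued observation conditioned on a real total_count;
    # ContinuousBinomial refers to the sketch above.
    pyro.sample("obs", ContinuousBinomial(S2I, 0.5), obs=new_infections)
```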

ContinuousBetaBinomial with reparameterized rsample

This uses the same mixture model for its density but also implements .rsample(), which is more complex. Here's a sketch:

```python
class ContinuousBetaBinomial(TorchDistribution):
    arg_constraints = {"total_count": constraints.positive,
                       "concentration0": constraints.positive,
                       "concentration1": constraints.positive}
    support = constraints.positive
    has_rsample = True
    ...

    def rsample(self, sample_shape=torch.Size()):
        # Reparameterized sample of the Beta-distributed success probability.
        probs = dist.Beta(self.concentration1, self.concentration0).rsample(sample_shape)
        lb = self.total_count.floor()
        lb_dist = dist.Binomial(lb, probs)
        lb_sample = lb_dist.sample()
        lb_rsample = ...  # TODO make this differentiable, e.g. using DiCE?
        return probs + lb_rsample

    def log_prob(self, value):
        ...  # TODO two-component mixture, as in ContinuousBinomial.log_prob
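For the lb_rsample TODO, one candidate is a DiCE-style surrogate: an expression whose value equals the non-differentiable sample but whose gradient is the score-function estimator. A minimal sketch of that one line, offered as one possible (high-variance, single-sample) choice rather than a settled design:

```python
# Hypothetical fill-in for the lb_rsample TODO, using a DiCE-style
# "magic box": the value equals lb_sample, but the gradient w.r.t.
# probs is the (high-variance) score-function estimator.
log_q = lb_dist.log_prob(lb_sample)
magic_box = (log_q - log_q.detach()).exp()  # evaluates to 1; grad is d log q
lb_rsample = lb_sample.detach() * magic_box
```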

Questions

Tasks

fritzo commented 4 years ago

These will benefit from https://github.com/pytorch/pytorch/issues/20343 and https://github.com/pytorch/pytorch/pull/31278

jamestwebber commented 4 years ago

> These will benefit from pytorch/pytorch#20343 and pytorch/pytorch#31278

If it ever gets merged :(

fritzo commented 4 years ago

Closing in favor of an asymptotically exact method (#2410).