pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Constraints for multiple intervals #1829

Closed: Qazalbash closed this 1 month ago

Qazalbash commented 2 months ago

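I have implemented three different constraints involving multiple closed intervals: membership in a union of intervals, membership in an intersection of intervals, and a per-component constraint where each $x_i$ has its own interval.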

$$ x \in \bigcup_{i=1}^{n} [a_i, b_i] \implies \bigvee_{i=1}^{n} (x \in [a_i, b_i]) $$

$$ x \in \bigcap_{i=1}^{n} [a_i, b_i] \implies \bigwedge_{i=1}^{n} (x \in [a_i, b_i]) $$

$$ \bigwedge_{i=1}^{n} (x_i \in [a_i, b_i]) $$
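As a rough illustration of how the three conditions differ, here is a minimal sketch of the membership tests in JAX. The function names `in_union`, `in_intersection`, and `in_product` are hypothetical and not part of NumPyro's API; `lower` and `upper` are assumed to hold the endpoints $a_i$ and $b_i$ along the last axis.

```python
import jax.numpy as jnp


def in_union(x, lower, upper):
    """True where x lies in at least one [lower[i], upper[i]] (logical OR)."""
    # Add a trailing axis so each x is compared against every interval.
    x = jnp.asarray(x)[..., None]
    return jnp.any((x >= lower) & (x <= upper), axis=-1)


def in_intersection(x, lower, upper):
    """True where x lies in every [lower[i], upper[i]] (logical AND)."""
    x = jnp.asarray(x)[..., None]
    return jnp.all((x >= lower) & (x <= upper), axis=-1)


def in_product(x, lower, upper):
    """True where each component x[..., i] lies in its own [lower[i], upper[i]]."""
    x = jnp.asarray(x)
    return jnp.all((x >= lower) & (x <= upper), axis=-1)


# Hypothetical intervals [0, 1] and [2, 3].
lower = jnp.array([0.0, 2.0])
upper = jnp.array([1.0, 3.0])
print(in_union(jnp.array([0.5, 1.5, 2.5]), lower, upper))
print(in_product(jnp.array([0.5, 2.5]), lower, upper))
```

Because the intersection of closed intervals is either empty or the single interval $[\max_i a_i, \min_i b_i]$, `in_intersection` is mostly useful as a check rather than as a separate constraint.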

fehiepsi commented 1 month ago

Hi @Qazalbash, do you have any concrete examples of those constraints?

Qazalbash commented 1 month ago

For the union of intervals, an example. I now realize that we can avoid the intersection of intervals, since it reduces to a single interval. For the unique-interval case, I was writing a transformation defined as

$$ f:(x,y)\to (x',y') $$

where $0 < y \leq x$, and

$$ x'=\frac{(xy)^{3/5}}{(x+y)^{1/5}},\qquad y'=\frac{xy}{(x+y)^{2}} $$

and it turns out that the bounds on each are

$$ 0 < x'<\infty \qquad 0 < y' \leq \frac{1}{4} $$
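The bound on $y'$ follows from $(x+y)^2 - 4xy = (x-y)^2 \geq 0$, so $y' \leq 1/4$ with equality exactly when $x = y$. For $x'$, positivity is clear for $x, y > 0$, and since $x'(tx, ty) = t\,x'(x, y)$ it grows without bound as $(x, y)$ is scaled up. The sketch below checks both facts numerically; the function name `f` and the sample points are chosen only for illustration.

```python
import jax.numpy as jnp


def f(x, y):
    """Map (x, y) with 0 < y <= x to (x', y') using the formulas above."""
    s = x + y  # sum x + y
    p = x * y  # product x * y
    x_prime = p ** (3.0 / 5.0) / s ** (1.0 / 5.0)
    y_prime = p / s**2
    return x_prime, y_prime


# Hypothetical sample points satisfying 0 < y <= x.
x = jnp.array([1.0, 2.0, 10.0])
y = jnp.array([1.0, 1.0, 0.5])
x_prime, y_prime = f(x, y)

# y' attains its maximum 1/4 at x == y (first point).
print(y_prime)
# Doubling (x, y) doubles x', so x' has no finite upper bound.
print(f(2 * x, 2 * y)[0] / x_prime)
```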

I had to deal with a similar situation, where each element along the event dimension (event_dim=1) represents a different quantity with its own bounds, and I had to write a separate constraint for them.

fehiepsi commented 1 month ago

For the unique-interval case, you can use `independent(interval(lower_bound, upper_bound))`. You are right that the intersection case is just another interval. The trickiest one is the union of intervals: I think you can assume the intervals do not overlap and map interval $i$ to $(i/n, (i+1)/n)$, where $n$ is the number of intervals. I guess it is better to avoid dealing with such a complicated constraint.
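The mapping suggested above for a union of intervals can be sketched as a piecewise-affine bijection between the union and the unit interval. This is a minimal sketch, not a NumPyro transform: it assumes the intervals are sorted, non-overlapping, and of positive length, uses 0-based interval indices, and the helper names `union_to_unit` and `unit_to_union` are hypothetical.

```python
import jax.numpy as jnp


def union_to_unit(x, lower, upper):
    """Send x in [lower[i], upper[i]] to (i + t) / n, with t the position of x within that interval."""
    n = lower.shape[0]
    # Index of the interval containing x: the last i with lower[i] <= x.
    i = jnp.clip(jnp.searchsorted(lower, x, side="right") - 1, 0, n - 1)
    frac = (x - lower[i]) / (upper[i] - lower[i])
    return (i + frac) / n


def unit_to_union(u, lower, upper):
    """Inverse map: (i + t) / n back to lower[i] + t * (upper[i] - lower[i])."""
    n = lower.shape[0]
    scaled = u * n
    i = jnp.clip(jnp.floor(scaled).astype(jnp.int32), 0, n - 1)
    frac = scaled - i
    return lower[i] + frac * (upper[i] - lower[i])


# Hypothetical intervals [0, 1], [2, 3], [5, 9].
lower = jnp.array([0.0, 2.0, 5.0])
upper = jnp.array([1.0, 3.0, 9.0])
x = jnp.array([0.5, 2.5, 7.0])
u = union_to_unit(x, lower, upper)
print(u, unit_to_union(u, lower, upper))
```

On piece $i$ the inverse map is affine with slope $n(b_i - a_i)$, so its log-Jacobian is the constant $\log(n(b_i - a_i))$. Endpoints shared between pieces (e.g. $b_i$ and $a_{i+1}$ both mapping to $(i+1)/n$) form a measure-zero set, which is why the open intervals $(i/n, (i+1)/n)$ suffice.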

Qazalbash commented 1 month ago

I get your point. I still don't understand `independent(interval(lower_bound, upper_bound))`, though. Can you explain it using the use case I mentioned? After that, I will close the PR!

fehiepsi commented 1 month ago

In that example, I guess you can use the constraint:

```python
import numpy as np
from numpyro.distributions import constraints

constraints.independent(constraints.interval(np.zeros(2), np.array([1e6, 0.25])), 1)
```

Qazalbash commented 1 month ago

Thank you!