JuliaStats / Distributions.jl

A Julia package for probability distributions and associated functions.

Degenerate Distribution interface is probably necessary #1880

Open quildtide opened 1 month ago

quildtide commented 1 month ago

Long rambling introduction

Distributions.jl currently allows many weird edge cases that result in Dirac-type degenerate point-mass distributions.

There are numerous bugs related to many of these cases, e.g.

julia> extrema(Bernoulli(1.0))
(false, true)

julia> extrema(Bernoulli(0.0))
(false, true)

julia> extrema(Beta(1, Inf))
(0.0, 1.0)

Some time ago, I looked into fixing the edge cases where Beta can turn degenerate. The problem is that many methods which are currently trivially simple (like mean(d::Beta) = ((α, β) = params(d); α / (α + β))) become much longer, like:

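function mean(d::Beta{T})::float(T) where T
    (α, β) = params(d)
    # α = Inf gives a point mass at 1, unless β = Inf too (treated here as 0.5)
    if isinf(α)
        return isinf(β) ? 0.5 : 1.0
    end
    # β = Inf with finite α falls through to α / Inf == 0.0
    return α / (α + β)
end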

Even worse is when methods with currently-constant outputs stop being constants as a result. For example, minimum(Normal(0,0)) returns -Inf; it dispatches to a method defined by a macro, @distr_support Normal -Inf Inf. Effectively, minimum(::Normal) always returns -Inf, and the compiler is able to make all sorts of optimizations because of this.

For example, compare:

julia> @code_lowered maximum(Bernoulli(0.0))
CodeInfo(
1 ─     return true
)

julia> @code_lowered maximum(Binomial(1, 0.0))
CodeInfo(
1 ─ %1 = Base.getproperty(d, :n)
└──      return %1
)

Fixing these methods to recognize that Bernoulli(0.0) and Binomial(1, 0.0) have a maximum at 0 would prevent these kinds of compiler optimizations.
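To make the trade-off concrete, here is a minimal sketch using a toy stand-in type. ToyBernoulli, maximum_const, maximum_checked, and minimum_checked are hypothetical names for illustration only and are not part of Distributions.jl. The constant version ignores the parameter entirely, while the degenerate-aware version has to read p at run time:

```julia
# Toy stand-in for a Bernoulli distribution (NOT the Distributions.jl type).
struct ToyBernoulli
    p::Float64
end

# Current-style behavior: the upper bound is a constant, independent of p.
maximum_const(::ToyBernoulli) = true

# Degenerate-aware behavior: the bounds depend on the runtime value of p.
# p == 0 is a point mass at false; p == 1 is a point mass at true.
maximum_checked(d::ToyBernoulli) = d.p > 0
minimum_checked(d::ToyBernoulli) = d.p >= 1

for p in (0.0, 0.5, 1.0)
    d = ToyBernoulli(p)
    println("p = ", p, ": const max = ", maximum_const(d),
            ", checked extrema = ", (minimum_checked(d), maximum_checked(d)))
end
```

The checked methods are still cheap, but their result is no longer known from the type alone, which is the cost being discussed here.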

The Problem

Distributions.jl currently "supports" a large number of edge cases that result in degenerate distributions. Many methods return incorrect results on these degenerate cases, BUT fixing these rare edge cases would have a tangible performance impact on normal usage for people who aren't using the degenerate distributions.

For some functions like rand, fixing edge cases so rand(d::Normal) and friends don't cause infinite loops seems like a totally reasonable thing to do, since rand is already such a costly function. For smaller functions like minimum, I think there is a legitimate discussion to be had.

Potential Solutions

Honestly, I have no clue. I would have submitted a pull request first if I had any strong opinion on how these issues should be resolved.

There is a part of me that thinks that most Distributions should stop supporting their degenerate cases, and we should implement some kind of wrapper like Degenerable{D} = Union{D, Dirac} where D <: Distribution. In this case, we would probably want to give deprecation warnings on things like Normal(0, 0) and note that accuracy is not guaranteed.
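A rough sketch of what the Union-alias idea could look like, written against toy types so it runs on its own. ToyDistribution, ToyDirac, ToyNormal, and make_normal are hypothetical stand-ins, and nothing here reflects how Distributions.jl would actually implement it:

```julia
# Toy hierarchy standing in for Distribution, Dirac, and Normal.
abstract type ToyDistribution end

struct ToyDirac <: ToyDistribution
    value::Float64
end

struct ToyNormal <: ToyDistribution
    μ::Float64
    σ::Float64
end

# Either a "proper" distribution of type D or a point mass.
const Degenerable{D<:ToyDistribution} = Union{D, ToyDirac}

# Hypothetical constructor: route the degenerate case to ToyDirac up front.
make_normal(μ, σ) = iszero(σ) ? ToyDirac(μ) : ToyNormal(μ, σ)

# Each concrete type keeps a simple, branch-free method.
Base.minimum(::ToyNormal) = -Inf
Base.minimum(d::ToyDirac) = d.value

d = make_normal(0, 0)
println(d isa Degenerable{ToyNormal}, " ", minimum(d))
println(minimum(make_normal(0, 1)))
```

In this sketch the branch moves into construction, so minimum(::ToyNormal) can stay constant, at the price of make_normal returning a small Union rather than a single concrete type.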

Alternatively, it might also make sense to have default behavior support degenerate cases and give correct results, even when this yields performance penalties. Prior behavior (with its performance improvements) could be reenabled via alternative methods like minimum(d::Distribution, ::NotDegenerate) with a singleton type NotDegenerate or something. Alternatively, it should theoretically be possible to implement NotDegenerate{D} where D <: Distribution as a wrapper instead.
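The singleton-argument variant could be sketched roughly as follows. SketchNormal is a toy stand-in for Normal, and the two-argument minimum method is only the shape proposed above, not an existing Distributions.jl API:

```julia
# Singleton type used purely for dispatch, as proposed above.
struct NotDegenerate end

# Toy stand-in for Normal (NOT the Distributions.jl type).
struct SketchNormal
    μ::Float64
    σ::Float64
end

# Default: correct for the degenerate case, at the cost of a runtime branch.
Base.minimum(d::SketchNormal) = iszero(d.σ) ? d.μ : -Inf

# Opt-in fast path: the caller promises σ > 0, so the result is a constant.
Base.minimum(::SketchNormal, ::NotDegenerate) = -Inf

println(minimum(SketchNormal(0, 0)))                    # degenerate-aware default
println(minimum(SketchNormal(0, 0), NotDegenerate()))   # fast path, wrong if σ == 0
println(minimum(SketchNormal(0, 1)))
```

The fast path silently returns the non-degenerate answer even for σ == 0, which is why it would need to be an explicit opt-in.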

quildtide commented 1 month ago

Upon further reflection, I believe that the least-breaking change is to:

This way, the default behavior becomes what users think the current behavior is, while users who really want performance and can guarantee that they aren't using degenerate cases can explicitly ask for current-ish behavior.