JuliaStats / StatsFuns.jl

Mathematical functions related to statistics.

TDist - cdf and quantile function - Auto-Differentiability #152

Open paschermayr opened 1 year ago

paschermayr commented 1 year ago

Hi there,

Thank you for all your work! I have seen that recently, pull request https://github.com/JuliaStats/StatsFuns.jl/pull/147 was closed in favor of https://github.com/JuliaStats/StatsFuns.jl/pull/149.

I believe the former pull request tried to make the cdf and quantile functions of TDist auto-differentiable, but the latter superseded it without resolving this. MWE from a fresh project with all up-to-date libraries:

using Distributions, DistributionsAD, StatsBase
using ForwardDiff, ReverseDiff
using StatsFuns

# Returns a closure θ -> log-likelihood of `data` under TDist(θ[1]).
function mytargetfunction(data::AbstractVector)
    function obtaingradient(θ::AbstractVector{R}) where {R<:Real}
        nu = θ[1]
        distr = TDist(nu)
        # Round trip through cdf and quantile; both depend on nu.
        data_uniform = [cdf(distr, data[iter]) for iter in eachindex(data)]
        data_real = [quantile(distr, data_uniform[iter]) for iter in eachindex(data_uniform)]
        return sum( logpdf(distr, data_real[iter]) for iter in eachindex(data_real) )
    end
end

#working
ν = [3.0]
data = randn(1000)
target = mytargetfunction(data)
target(ν)
#not working
ForwardDiff.gradient(target, ν) #MethodError: no method matching _beta_inc(::ForwardDiff.Dual
ReverseDiff.gradient(target, ν) #MethodError: no method matching _beta_inc(::ReverseDiff.TrackedReal

It seems that the beta_inc function comes from the SpecialFunctions.jl package and requires Float64 arguments rather than general Real arguments. Is there a reason for that? Should I open an issue there as well?

andreasnoack commented 1 year ago

We should probably have a general derivative rule for cdf defined somewhere. @devmotion any thoughts?

devmotion commented 1 year ago

Hmm, for ChainRules-compatible AD systems we can add the missing rules just in StatsFuns (or, of course, in SpecialFunctions directly if the rule is missing there). I think we might want to make ChainRulesCore a weak dependency in the future anyway on Julia >= 1.9, and then the number of definitions should not matter for loading and compilation times if users do not use ChainRules.

We could also add definitions for ForwardDiff and ReverseDiff by making them weak dependencies. That could fix the issue at least on Julia >= 1.9. Maybe even better would be to make DiffRules a weak dependency, since ForwardDiff and ReverseDiff use it, rather than ChainRules, to define rules automatically. (There are approaches to bridge them with ChainRules, but they would be type piracy here: https://github.com/ThummeTo/ForwardDiffChainRules.jl and https://juliadiff.org/ReverseDiff.jl/dev/api/#ChainRules-integration.) However, the current design of DiffRules does not allow reliably adding new rules in other packages: they might not be picked up by e.g. ForwardDiff and ReverseDiff, since those packages define their differentiation rules only once, when they are loaded, based on the rules available at that point. @KristofferC was looking into some of the issues with the current design of DiffRules: https://github.com/JuliaDiff/DiffRules.jl/issues/90

So at least on Julia >= 1.9, maybe the best short-term solution would be to add weak dependencies on ReverseDiff and ForwardDiff, and define rules for them explicitly. And to add missing ChainRules definitions.
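
To illustrate what an explicit ForwardDiff rule could look like, here is a minimal sketch for the easy case only, the derivative with respect to x. The wrapper name tcdf_x is hypothetical and not part of StatsFuns; it is used here to avoid type piracy. In a package, the Dual method would presumably live in an extension module loaded through the weak dependency. The sketch assumes ν is a plain floating-point number, so it does not cover the derivative with respect to ν that the MWE above needs.

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials
using StatsFuns: tdistcdf, tdistpdf

# Hypothetical wrapper (not part of StatsFuns); plain reals fall through unchanged.
tcdf_x(ν::Real, x::Real) = tdistcdf(ν, x)

# Dual-number method: evaluate the cdf at the primal value and propagate
# the partials using d/dx cdf(x) = pdf(x). ν must be a plain float here.
function tcdf_x(ν::AbstractFloat, x::Dual{T}) where {T}
    xv = value(x)
    return Dual{T}(tdistcdf(ν, xv), tdistpdf(ν, xv) * partials(x))
end

# The derivative with respect to x should agree with the pdf.
g = ForwardDiff.derivative(x -> tcdf_x(3.0, x), 0.5)
```

Nested differentiation (a Dual whose value is itself a Dual) is not handled by this sketch.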

(As a side remark, #147 also used beta_inc and beta_inc_inv, so - without testing it - I would assume the same AD issues as reported above would show up there as well.)

andreasnoack commented 1 year ago

My point is that the derivative of x -> cdf(..., x) is readily available. However, supporting all the partial derivatives of betacdf will require more work.
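
For reference, the distinction can be written out as follows. The notation is assumed here rather than taken from the thread: F and f denote the cdf and pdf of the t distribution with ν degrees of freedom, and I_z(a, b) is the regularized incomplete beta function.

```latex
% Derivative with respect to the argument: just the density.
\frac{\partial}{\partial x} F(x;\nu) = f(x;\nu)

% The cdf itself, for x >= 0 (use F(x;\nu) = 1 - F(-x;\nu) for x < 0):
F(x;\nu) = 1 - \tfrac{1}{2}\, I_{z}\!\left(\tfrac{\nu}{2}, \tfrac{1}{2}\right),
\qquad z = \frac{\nu}{\nu + x^{2}}

% Derivative with respect to the parameter: needs the partials of I_z(a, b)
% in both z and a, because z depends on \nu as well.
\frac{\partial}{\partial \nu} F(x;\nu)
  = -\tfrac{1}{2}\left[
      \frac{\partial I_z}{\partial z}\,\frac{\partial z}{\partial \nu}
      + \tfrac{1}{2}\,\frac{\partial I_z}{\partial a}
    \right]_{a = \nu/2,\; b = 1/2}
```

The term ∂I_z/∂z is elementary (it is the beta density at z), while ∂I_z/∂a, as far as I know, has no elementary closed form, which is where the extra work comes in.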

devmotion commented 1 year ago

Yes, that's why I assumed you might want to add rules to StatsFuns (or, e.g., Distributions) instead of SpecialFunctions (even though in the example above you would need rules for theta -> cdf(dist(theta), x) as well). But to me it seemed that in cases where the derivatives are readily available the question is still how to deal with AD systems such as ForwardDiff and ReverseDiff that only support DiffRules but not ChainRules.