JuliaStats / StatsBase.jl

Basic statistics for Julia

Zygote rrule for percentile() #832

Open renatobellotti opened 2 years ago

renatobellotti commented 2 years ago

I would like to suggest a Zygote rrule for the StatsBase.jl function percentile(). I'm not sure whether this belongs in the Zygote repo or in this one, so please let me know if this is the wrong place.

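using StatsBase
using Zygote
using ChainRulesCore

# Covers only the default quantile definition (linear interpolation, alpha = beta = 1).
function ChainRulesCore.rrule(::typeof(StatsBase.percentile), x::AbstractArray{T, N}, p) where {T, N}
    percentiles = StatsBase.percentile(x, p)
    n = length(x)
    perm = sortperm(vec(x))  # linear indices of x in ascending order

    function percentile_pullback(percentiles_bar)
        f̄ = NoTangent()
        x̄ = zeros(float(T), size(x))
        # Each percentile interpolates between two neighbouring order statistics,
        # so its cotangent is split between those two entries of x.
        for (q, q̄) in zip(p, unthunk(percentiles_bar))
            h = (n - 1) * q / 100 + 1        # fractional position in the sorted data
            lo = clamp(floor(Int, h), 1, n)
            hi = min(lo + 1, n)
            γ = h - lo                       # interpolation weight of the upper neighbour
            x̄[perm[lo]] += (1 - γ) * q̄
            x̄[perm[hi]] += γ * q̄
        end
        p̄ = NoTangent()  # p is treated as non-differentiable here

        return f̄, x̄, p̄
    end

    return percentiles, percentile_pullback
end

x = [0., 1., 2., 3, 4, 5, 6]
@show g = Zygote.gradient(x -> StatsBase.percentile(x, 60), x)
@show J = Zygote.jacobian(x -> StatsBase.percentile(x, [50, 60]), x)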
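
One way to sanity-check a rule like this is to compare it against a finite-difference gradient. The sketch below is only an illustration: `fd_gradient` is a hypothetical helper, not part of StatsBase, ChainRules, or Zygote. It calls `quantile(v, 0.6)` from the Statistics standard library on the assumption that it matches `percentile(v, 60)` under the default definition.

```julia
using Statistics

# Hypothetical helper: central finite-difference gradient of a scalar function f at x.
function fd_gradient(f, x::AbstractVector{<:AbstractFloat}; step = 1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        xp = copy(x); xp[i] += step   # perturb entry i upwards
        xm = copy(x); xm[i] -= step   # perturb entry i downwards
        g[i] = (f(xp) - f(xm)) / (2 * step)
    end
    return g
end

x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
# Assumption: percentile(x, 60) corresponds to quantile(x, 0.6).
g = fd_gradient(v -> quantile(v, 0.6), x)
@show g
```

For this input the 60th percentile falls between the 4th and 5th sorted values, so the numerical gradient should be close to 0.4 and 0.6 at those positions and zero elsewhere.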

I'm happy to hear your thoughts!

nalimilan commented 2 years ago

I guess it would have to live in StatsBase. But you'll have to ping ChainRules or Zygote developers to review the PR as I'm not familiar with these.

I can only note that percentile is a convenience function that calls quantile, so both should probably be treated the same way. Also relevant: quantile supports multiple definitions (selected through keyword arguments), and I'm not sure how ChainRules handles that.
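
To illustrate the point about multiple definitions, here is a small sketch using `quantile` from the Statistics standard library. `my_percentile` is a hypothetical stand-in written for this example, not the StatsBase function itself, and forwarding keyword arguments through it is an assumption made for illustration. The `alpha` and `beta` keywords select the quantile definition.

```julia
using Statistics

# Hypothetical stand-in for a percentile wrapper: rescale p from [0, 100] to [0, 1]
# and forward any keyword arguments to quantile.
my_percentile(x, p; kwargs...) = quantile(x, p ./ 100; kwargs...)

x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

q_default = my_percentile(x, 60)                          # default: alpha = beta = 1
q_def5    = my_percentile(x, 60; alpha = 0.5, beta = 0.5) # a different definition

@show q_default q_def5
```

The two calls give different values for the same data (about 3.6 and 3.7), because the interpolation position depends on `alpha` and `beta`. A differentiation rule would presumably need to take those keywords into account as well.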

renatobellotti commented 1 year ago

Thanks for your answer. I was able to solve my problem without relying on percentile()/quantile(). I will leave this issue open in case somebody needs the rule in the future.