JuliaDiff / ChainRulesCore.jl

AD-backend-agnostic system for defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for the functions in your packages without depending on any particular AD system.

Invocation helpers #45

Open oxinabox opened 5 years ago

oxinabox commented 5 years ago

In the same spirit as #44, we might want some invocation helpers.

Something like:

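function grad(f, args...)
    _, pullback = rrule(f, args...)
    partials = pullback(One())  # seed with the instance `One()`, not the type `One`
    # Drop the first partial (the one for `f` itself) and materialize the rest.
    return extern.(partials[2:end])
end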

This is what I was using to test things with Zygote.
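A minimal, self-contained sketch of how such a `grad` helper behaves, written without ChainRulesCore so it runs on its own. It assumes the calling convention used in this thread: the rule returns `(y, pullback)`, and the pullback returns the cotangent for `f` itself first, followed by one entry per argument. `toy_rrule` is a hypothetical hand-written rule for `f(a, b) = a * sin(b)`; `nothing` stands in for `NO_FIELDS` and a plain `1.0` seed stands in for `One()`.

```julia
# Hypothetical toy function and hand-written reverse rule (not ChainRulesCore).
f(a, b) = a * sin(b)

function toy_rrule(::typeof(f), a, b)
    y = f(a, b)
    # Pullback returns (entry for `f` itself, ∂a, ∂b); `f` has no fields,
    # so its entry is `nothing` (standing in for NO_FIELDS).
    f_pullback(Δy) = (nothing, Δy * sin(b), Δy * a * cos(b))
    return y, f_pullback
end

# Invocation helper: seed with 1.0 (standing in for One()) and drop the entry for `f`.
function grad(rule, f, args...)
    _, pullback = rule(f, args...)
    partials = pullback(1.0)
    return Base.tail(partials)
end

∂a, ∂b = grad(toy_rrule, f, 2.0, 0.5)
```

`Base.tail(partials)` does the same job as `partials[2:end]` on a tuple.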

oxinabox commented 5 years ago

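Dumping this chunk of code here for now:

using ChainRulesCore  # frule, rrule, One, Zero, extern, NO_FIELDS (API of the time)
using Test            # @test
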
"""
    checked_fgrad(f, a, b, ...; Δ1=One())

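For a function `f(a, b, ...)`, returns `∂f_∂a, ∂f_∂b, ...` at the point `(a, b, ...)`.
Computed via the forward-mode `frule`.
Note: it is much more efficient to calculate this using the reverse-mode `rrule`,
since this needs one pushforward per input, whereas for scalar-valued `f` a single
pullback gives all the partials.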

This is calculated with perturbation (gradient chained in from below)
`Δ`, set to `Δ1` (default `One()`) for each input in turn,
with the others set to zero.

!!! note
    This skips the derivative with respect to `f`'s internals, merely testing
    that, for non-functors, perturbing `f` does not change the result.
"""
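function checked_fgrad(f, x... ; Δ1=One())
    res = frule(f, x...)
    @test res !== nothing  # ensure a rule is defined
    Y, pushforward = res
    @test Y == f(x...)     # primal output must match

    # One perturbation slot for `f` itself plus one per argument.
    perturbation_length = length(x) + 1
    if fieldcount(typeof(f)) === 0
        # Non-functor: perturbing `f` alone must give the same result
        # as perturbing nothing at all.
        nil_Δ = ntuple(_->Zero(), perturbation_length)
        nil_∂ = pushforward(nil_Δ...)

        Δ = onehot(perturbation_length, 1, Δ1)
        @test extern.(pushforward(Δ...)) ≈ extern.(nil_∂)
    end

    # One pushforward per argument, each seeded with the keyword `Δ1`.
    ∂s = map(2:perturbation_length) do hot_ind
        Δ = onehot(perturbation_length, hot_ind, Δ1)
        extern.(pushforward(Δ...))
    end
    return ∂s
end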

# Tuple of length `len` with `hotval` at position `hot_ind` and `coldval` elsewhere.
function onehot(len, hot_ind, hotval=One(), coldval=Zero())
    ntuple(len) do ii
        ii == hot_ind ? hotval : coldval
    end
end
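To illustrate why `checked_fgrad` needs one pushforward call per input, here is a self-contained sketch using a hypothetical hand-written forward rule `toy_frule` (not ChainRulesCore). It assumes the convention used above: the pushforward takes one perturbation for `f` itself followed by one per argument. Plain `1.0`/`0.0` stand in for `One()`/`Zero()`.

```julia
g(a, b) = a * sin(b)

# Hypothetical forward rule: returns (y, pushforward), where the pushforward
# takes (Δself, Δa, Δb). `g` has no fields, so Δself is ignored.
function toy_frule(::typeof(g), a, b)
    y = g(a, b)
    pushforward(Δself, Δa, Δb) = Δa * sin(b) + Δb * a * cos(b)
    return y, pushforward
end

# Same idea as `onehot` above: `hotval` at `hot_ind`, `coldval` elsewhere.
onehot(len, hot_ind, hotval=1.0, coldval=0.0) =
    ntuple(ii -> ii == hot_ind ? hotval : coldval, len)

function fgrad(rule, f, x...)
    _, pushforward = rule(f, x...)
    n = length(x) + 1  # one slot for `f` itself plus one per argument
    # One pushforward per argument: slot 1 belongs to `f`, so start at 2.
    return map(hot_ind -> pushforward(onehot(n, hot_ind)...), 2:n)
end

∂s = fgrad(toy_frule, g, 2.0, 0.5)
```

Each of the `length(x)` arguments costs one pushforward evaluation, which is the efficiency point made in the `checked_fgrad` docstring.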

"""
    checked_rgrad(f, a, b, ...; Δ=One())

For a function `f(a, b, ...)`, returns `∂f_∂a, ∂f_∂b, ...` at the point `(a, b, ...)`.
Computed via the reverse-mode `rrule`.

This is calculated with seed (gradient chained in from above) `Δ`,
defaulting to `One()`.

!!! note
    This skips the derivative with respect to `f`'s internals, merely testing
    that, for non-functors, it is `NO_FIELDS`.
"""
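function checked_rgrad(f, x... ; Δ=One())
    res = rrule(f, x...)
    @test res !== nothing  # ensure a rule is defined
    Y, pullback = res
    @test Y == f(x...)     # primal output must match
    ∂s = pullback(Δ)
    if fieldcount(typeof(f)) === 0
        # Non-functor: the cotangent for `f` itself must be NO_FIELDS.
        @test first(∂s) === NO_FIELDS
    end
    # Drop the entry for `f` and materialize the partials for the arguments.
    return extern.(collect(∂s[2:end]))
end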
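A self-contained sketch of the checks `checked_rgrad` performs, using the `Test` stdlib and a hypothetical hand-written rule `toy_rrule` for `h(a, b) = a * b` (not ChainRulesCore). `nothing` stands in for `NO_FIELDS` and `1.0` for `One()`.

```julia
using Test

h(a, b) = a * b

# Hypothetical reverse rule: the pullback returns (entry for `h`, ∂a, ∂b).
function toy_rrule(::typeof(h), a, b)
    pullback(Δ) = (nothing, Δ * b, Δ * a)  # `nothing` plays the role of NO_FIELDS
    return h(a, b), pullback
end

function rgrad(rule, f, x...; Δ=1.0)
    res = rule(f, x...)
    @test res !== nothing          # a rule exists
    Y, pullback = res
    @test Y == f(x...)             # primal output matches
    ∂s = pullback(Δ)
    if fieldcount(typeof(f)) == 0  # non-functor: no cotangent for `f`
        @test first(∂s) === nothing
    end
    return collect(∂s[2:end])      # partials for the arguments only
end

rgrad(toy_rrule, h, 3.0, 4.0)          # → [4.0, 3.0]
rgrad(toy_rrule, h, 3.0, 4.0; Δ=2.0)   # → [8.0, 6.0], the seed scales every partial
```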

Simeon Schaub 17 hours ago

pushforward = (da, db) -> sum(Base.tail(pullback(One())) .* (da, db))?
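For a scalar-valued function of scalar inputs, the suggestion above amounts to taking the gradient from one pullback call and dotting it with the input perturbations. A self-contained sketch with hypothetical names (`toy_rrule`, `pushforward_from_pullback`, not ChainRulesCore API), assuming the pullback's first entry belongs to the function itself, so it is dropped before multiplying.

```julia
# Hypothetical scalar function and its hand-written reverse rule.
k(a, b) = a^2 + 3b

function toy_rrule(::typeof(k), a, b)
    pullback(Δ) = (nothing, Δ * 2a, Δ * 3)  # (entry for `k`, ∂a, ∂b)
    return k(a, b), pullback
end

# Build a pushforward from one pullback call; only valid for scalar-valued `f`.
function pushforward_from_pullback(rule, f, x...)
    _, pullback = rule(f, x...)
    gradient_args = Base.tail(pullback(1.0))    # drop the entry for `f` itself
    return (dx...) -> sum(gradient_args .* dx)  # directional derivative: ∇f ⋅ dx
end

pf = pushforward_from_pullback(toy_rrule, k, 2.0, 5.0)
pf(1.0, 0.0)  # → 4.0, i.e. ∂k/∂a = 2a at a = 2
pf(0.0, 1.0)  # → 3.0, i.e. ∂k/∂b
```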

Lyndon White 17 hours ago

Makes sense. I should read more on differential geometry.

Simeon Schaub 17 hours ago

Although if you're dealing with arrays, things are more complicated, I think. There probably needs to be a transpose somewhere in there.
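One way to see where the transpose comes in, sketched for a linear map under the assumption that the pullback of `y = A * x` is `Δy -> A' * Δy` (a hypothetical standalone example, not ChainRulesCore code): pulling back each output basis vector yields one row of the Jacobian, so stacking those results gives the Jacobian's transpose, and it has to be transposed back before applying it to an input perturbation.

```julia
using LinearAlgebra

# Hypothetical linear function y = A * x with its reverse rule: the pullback
# maps an output cotangent Δy to the input cotangent A' * Δy.
A = [1.0 2.0 3.0;
     4.0 5.0 6.0]          # 2 outputs, 3 inputs
lin(x) = A * x
lin_pullback(Δy) = A' * Δy

# Pulling back the i-th output basis vector gives row i of the Jacobian J;
# stacking these as columns builds Jt = Jᵀ (n × m).
function pushforward_from_pullback(pullback, m)
    basis(i) = [j == i ? 1.0 : 0.0 for j in 1:m]
    Jt = reduce(hcat, [pullback(basis(i)) for i in 1:m])
    return dx -> transpose(Jt) * dx  # the transpose turns Jᵀ back into J
end

pf = pushforward_from_pullback(lin_pullback, 2)
dx = [1.0, 0.0, -1.0]
pf(dx)  # → [-2.0, -2.0], same as A * dx
```

Recovering a pushforward this way costs one pullback per output, so it is only cheap when there are few outputs.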