Open oxinabox opened 5 years ago
Dumping this chunk of code here for now:
"""
    checked_fgrad(f, a, b, ...; Δ1=One())

For a function `f(a, b, ...)`, returns `∂f_∂a, ∂f_∂b, ...` at the point `(a, b, ...)`.
Computed via the forward-mode `frule`.
Note: it is much more efficient to calculate this using the reverse-mode `rrule`.

This is calculated with the perturbation (gradient chained in from below)
`Δ` set to `Δ1` (default `One()`) for each input in turn,
with the others set to zero.

!!! note
    This skips the internal derivative, merely testing that, for
    non-functors, it does not change the value.
"""
function checked_fgrad(f, x...; Δ1=One())
    res = frule(f, x...)
    @test res !== nothing  # ensure rule defined
    Y, pushforward = res
    @test Y == f(x...)

    # One perturbation slot for `f` itself, plus one per input.
    perturbation_length = length(x) + 1

    if fieldcount(typeof(f)) === 0
        # `f` has no fields, so perturbing it must match perturbing nothing.
        nil_Δ = ntuple(_ -> Zero(), perturbation_length)
        nil_∂ = pushforward(nil_Δ...)
        Δ = onehot(perturbation_length, 1)
        @test extern.(pushforward(Δ...)) ≈ extern.(nil_∂)
    end

    # Perturb each input in turn by `Δ1`, holding the others at zero.
    ∂s = map(2:perturbation_length) do hot_ind
        Δ = onehot(perturbation_length, hot_ind, Δ1)
        extern.(pushforward(Δ...))
    end
    return ∂s
end
function onehot(len, hot_ind, hotval=One(), coldval=Zero())
    ntuple(len) do ii
        ii == hot_ind ? hotval : coldval
    end
end
"""
    checked_rgrad(f, a, b, ...; Δ=One())

For a function `f(a, b, ...)`, returns `∂f_∂a, ∂f_∂b, ...` at the point `(a, b, ...)`.
Computed via the reverse-mode `rrule`.

This is calculated with the seed (gradient chained in from above) `Δ`,
defaulting to `One()`.

!!! note
    This skips the internal derivative, merely testing that, for
    non-functors, it is `NO_FIELDS`.
"""
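function checked_rgrad(f, x...; Δ=One())
    res = rrule(f, x...)
    @test res !== nothing  # ensure rule defined
    Y, pullback = res
    @test Y == f(x...)

    ∂s = pullback(Δ)
    if fieldcount(typeof(f)) === 0
        # `f` has no fields, so its own partial must be `NO_FIELDS`.
        @test first(∂s) === NO_FIELDS
    end
    # Drop the partial with respect to `f` itself; return only the input partials.
    return extern.(collect(∂s[2:end]))
end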
Simeon Schaub 17 hours ago
`pushforward = (da, db) -> sum(pullback(One()) .* (da, db))`?
Lyndon White:ox: 17 hours ago
Makes sense. I should read more on differential geometry.
Simeon Schaub 17 hours ago
Although if you're dealing with arrays, things are more complicated I think. There probably needs to be a transpose somewhere in there.
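To spell out the relationship being suggested here, a rough sketch in math, ignoring the partial with respect to `f` itself and assuming a scalar-valued `f(a, b)` with scalar inputs. The symbols $\dot{a}, \dot{b}$ (standing in for `da, db`), $J$, $\dot{x}$, $\bar{y}$ and $e_i$ are notation introduced for this sketch, not names from the code above.

```latex
% Scalar case: the pullback seeded with 1 gives the gradient,
% and the pushforward is its dot product with the perturbation.
\mathrm{pullback}(1) = \left( \frac{\partial f}{\partial a},\ \frac{\partial f}{\partial b} \right),
\qquad
\mathrm{pushforward}(\dot{a}, \dot{b})
  = \frac{\partial f}{\partial a}\,\dot{a} + \frac{\partial f}{\partial b}\,\dot{b}

% Array case: f : R^n -> R^m with Jacobian J (an m-by-n matrix).
\mathrm{pushforward}(\dot{x}) = J\,\dot{x},
\qquad
\mathrm{pullback}(\bar{y}) = J^{\top}\,\bar{y},
\qquad
\langle \bar{y},\ J\,\dot{x} \rangle = \langle J^{\top}\bar{y},\ \dot{x} \rangle

% Recovering the pushforward from the pullback takes one pullback call
% per output component, seeded with the basis vector e_i:
(J\,\dot{x})_i = \langle \mathrm{pullback}(e_i),\ \dot{x} \rangle,
\qquad i = 1, \dots, m
```

The $J^{\top}$ in the array case is presumably the transpose mentioned above: a single `pullback(One())` call is only enough when the output is a scalar.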
In the same vein as #44, we might want some invocation helpers.
Something like
Which is what I was using for testing stuff with Zygote.