Closed oxinabox closed 3 years ago
To do the in-place part I would like to have https://github.com/JuliaDiff/ChainRulesCore.jl/issues/113#issuecomment-675010293, but I don't need it, since I can just overload `update!` for `InplaceableThunk`.
Probably what this should do is look at the method table and check whether the simple unionized overload would eclipse any existing method in the wrong way (what is that? I need to think carefully). If not, use that. If so, use the version that generates all the combinatoric overloads.
Though maybe that check would take longer than the extra processing time needed to generate and load all of them.
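A minimal sketch of the two strategies, in plain Julia. `Node`, `unbox`, `eclipses`, `describe` and `mul` are hypothetical stand-ins made up for illustration, not Nabla's actual definitions. It shows one way a unionized overload can eclipse the primal (the union collapses to the primal's own argument type) and what generating the combinatoric overloads could look like.

```julia
# Hypothetical stand-ins; Nabla's real Node carries tape information too.
struct Node{T}
    val::T
end
unbox(x::Node) = x.val
unbox(x) = x

# Candidate check: the unionized argument type eclipses the primal when the
# union collapses back to the primal's own argument type.
eclipses(T::Type) = Union{Node{<:T}, T} == T

@assert eclipses(Any)      # Union{Node{<:Any}, Any} is just Any
@assert !eclipses(Real)    # Union{Node{<:Real}, Real} stays a proper union

# What "eclipsing in the wrong way" looks like: the overload has the same
# signature as the primal, so it replaces the primal outright.
describe(x) = "primal"
describe(x::Union{Node{<:Any}, Any}) = "overload"
@assert describe(1) == "overload"
@assert length(methods(describe)) == 1

# Combinatoric alternative: one method per combination of plain/Node
# arguments, skipping the all-plain combination so the primal is untouched.
mul(x::Real, y::Real) = x * y
choices = (:Real, :(Node{<:Real}))
for (tx, ty) in Iterators.product(choices, choices)
    tx === :Real && ty === :Real && continue
    @eval mul(x::$tx, y::$ty) = Node(mul(unbox(x), unbox(y)))
end

@assert length(methods(mul)) == 4   # primal + 3 generated overloads
@assert mul(2, 3) == 6
@assert mul(Node(2), 3).val == 6
@assert mul(Node(2), Node(3)).val == 6
```

For a primal with n arguments the combinatoric approach generates 2^n - 1 methods, which is where the extra load time would come from.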
Effect on load time: this is definitely much slower to load. These timings are from a recent build of 1.6, but roughly match 1.5. With this PR:
julia> @time @time using Nabla
12.162363 seconds (17.81 M allocations: 1.070 GiB, 3.15% gc time, 46.90% compilation time)
12.218878 seconds (17.95 M allocations: 1.079 GiB, 3.13% gc time, 47.14% compilation time)
julia> length(methods(∇))
717
julia> using SnoopCompileCore
julia> invalidations = @snoopr begin
using Nabla
end;
julia> using SnoopCompile
julia> length(uinvalidated(invalidations))
3267
For comparison, before this PR:
julia> @time @time using Nabla
0.847823 seconds (1.01 M allocations: 66.097 MiB, 16.50% gc time, 61.89% compilation time)
0.891765 seconds (1.13 M allocations: 73.762 MiB, 15.69% gc time, 63.64% compilation time)
julia> length(methods(∇))
376
julia> using SnoopCompileCore
julia> invalidations = @snoopr begin
using Nabla
end;
julia> using SnoopCompile
julia> length(uinvalidated(invalidations))
92
@keno you mentioned being interested in this.
Dropped a bunch of rule generation for nondifferentiable things, in particular for the ones that were causing invalidations.
Right now we don't really generate good code for non-differentiable things anyway: we should generate totally different code that doesn't return a `Branch` but rather just returns the primal result.
Right now there is no easy way to identify them however: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/248
With those changes we are down to 168 invalidations (mostly not from Nabla) and startup time has improved to 9.8s, which is better than 12.2s but a far cry from the 0.8 seconds Nabla used to take.
While I remember, I should check that we are doing efficient things if the input is an `AbstractZero`.
I thought I was done, then I realized that I could stop it from erroring when new rules are added to ChainRules that are also still in Nabla, by making a list of all the rules that we still have and adding them to our block list.
Also, I realised the docs wouldn't build anymore, because we were using a version of Documenter that was so old that it wasn't compatible with Compat.jl 0.3 (which ChainRules uses).
Should be all sorted now
Supersedes #178
Follows https://www.juliadiff.org/ChainRulesCore.jl/dev/autodiff/operator_overloading.html (which will probably get updates during this based on practical learnings)
What this PR does:
- `Nabla.update!` is swapped out for the very similar `add!!`, so that it works for `InplaceableThunk`s.
- `ForwardDiff.derivative`, but for multi-argument things it still uses the Dual numbers directly.
- `node_type` transform that is like the `unionise_type` transform but without making the union.
- `preprocess` to not receive its inputs pre-unboxed, but have the default fallback unbox them and call process again.
- `VarArg{T, N} where N`
- Adds support for SpecialFunctions 0.10.
- Drops support for SpecialFunctions 0.9. `lgamma`/`loggamma`, as they don't both exist in nondeprecated form in the same version.

Things I suggest leaving for potential future PRs (but that reviewers might disagree with):
- `:no_N` in src/conde_tranformations/utils.jl, which doesn't seem to ever be hit.
- `Pair{Node}` return `Node{Pair}`, for consistency and so `diagm` will hit rules we define in ChainRules.jl.

Notes on implementation
The core logic is to use the Operator Overloading interface of ChainRules, which lets you register a hook that is triggered with a tuple type representing the signature of every primal function that ChainRulesCore has an overload of `rrule` for. This hook is the `generate_overload` function. It filters out a bunch of things. It then uses ExprTools to get an AST for a function definition that would be suitable for overloading the primal function (as an overloading-based AD like Nabla does).
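To make the "signature type in, method-definition AST out" step concrete, here is a rough sketch in plain Julia. It builds the AST by hand rather than through ExprTools, so it is a stand-in for the technique and not the code in this PR. `Node`, `definition_ast` and `double` are hypothetical names introduced only for this illustration.

```julia
# Hypothetical stand-in for Nabla's tracked-value type.
struct Node{T}
    val::T
end

# Build the AST of a method definition from a signature tuple type such as
# Tuple{typeof(f), Real}. Type objects are spliced directly into the AST.
function definition_ast(sig::Type{<:Tuple}, body)
    types = fieldtypes(sig)
    ftype, argtypes = types[1], types[2:end]
    # Arguments are named x1, x2, ... and annotated with their types.
    args = [Expr(:(::), Symbol(:x, i), T) for (i, T) in enumerate(argtypes)]
    # `(::typeof(f))(x1::T1, ...)` adds a method to `f` itself.
    call = Expr(:call, Expr(:(::), ftype), args...)
    return Expr(:function, call, Expr(:block, body))
end

double(x::Real) = 2x   # hypothetical primal function

# Signature with the argument swapped for the tracked type.
node_sig = Tuple{typeof(double), Node{<:Real}}

# Overload body: unwrap, call the primal, re-wrap.
eval(definition_ast(node_sig, :(Node(double(x1.val)))))

@assert double(Node(3)).val == 6
@assert double(3) == 6
@assert length(methods(double)) == 2
```

In the PR the generated body builds a `Branch` holding the pullback instead of simply re-wrapping the result; the sketch only shows how a definition can be derived from a signature type.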
From that it generates overloads for that primal, with each argument in turn swapped out for the matching node (this is why `node_type` was added to the code transformation functions). An earlier version used `unionise_type` instead of swapping it out, but for things with a primal type of `Any` (which shows up for `nondifferentiable_rule`), this just resulted in `Union{Node{Any}, Any}`, which simplifies to `Any`. That meant we were overwriting the original primal definition, which would break everything.

The key thing these generated primal overloads do is create a `Branch` that stores the pullback. We then generate a method for `preprocess` which invokes that pullback, computing the partials for all the arguments. And we generate a method for `∇` that just takes the partial computed by `preprocess` and returns the right one for the specified `Arg{N}`.

Things to do before Review
Things for reviewers to consider:
- `src/sensitivities/chainrules.jl` file
- `map`.
- `Pair{<:Node, <:Node}`, rather than `Node{<:Pair}`
- If an `rrule` is added to ChainRules for something Nabla has, it will cause Nabla to break due to ambiguity.
- `Arg{1}` and `Arg{2}` etc cases. That would remove ambiguities I think.
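For reviewers less familiar with how such ambiguities arise, here is a small plain-Julia illustration. `Node` and `scale` are hypothetical, and the two methods only mimic a hand-written rule and a generated overload; they are not taken from Nabla or ChainRules.

```julia
# Hypothetical stand-in for Nabla's tracked-value type.
struct Node{T}
    val::T
end

# One method is more specific in the first argument, the other in the second.
scale(x::Node{<:Real}, y) = "hand-written rule"
scale(x::Node, y::Real) = "generated overload"

# Neither method is more specific for this call, so dispatch fails.
is_ambiguous = try
    scale(Node(1.0), 2.0)
    false
catch err
    err isa MethodError
end
@assert is_ambiguous

# Defining the intersection of the two signatures removes the ambiguity.
scale(x::Node{<:Real}, y::Real) = "resolved"
@assert scale(Node(1.0), 2.0) == "resolved"
```

Blocking generation for signatures that Nabla still defines by hand, as described in the comments above, avoids having to write such tie-breaking methods.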