invenia / Nabla.jl

An operator-overloading, tape-based, reverse-mode AD

Use ChainRules #189

Closed oxinabox closed 3 years ago

oxinabox commented 4 years ago

Supersedes #178

Follows https://www.juliadiff.org/ChainRulesCore.jl/dev/autodiff/operator_overloading.html (which will probably get updates during this based on practical learnings)

What this PR does:

Things I suggest leaving for potential future PRs

(but that reviewers might disagree with)

Notes on implementation

The core logic uses the operator overloading interface of ChainRules, which lets you register a hook. The hook is triggered with a tuple type representing the signature of every primal function for which ChainRulesCore has an rrule method.

This hook is the generate_overload function.

This filters out a bunch of things.
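As a rough illustration of the hook-and-filter idea, here is a minimal sketch in plain Julia. It does not use ChainRulesCore: `register_hook!`, `announce_rule!`, `generate_overload_sketch`, and the `BLOCKED` list are all hypothetical stand-ins, and the two filters shown are assumptions about the kind of thing one might skip, not the PR's actual filter list.

```julia
# Toy stand-in for a rule registry; none of these names are ChainRulesCore's API.
const HOOKS = Function[]
const OVERLOADED = Any[]          # signatures we would generate overloads for
const BLOCKED = Any[typeof(+)]    # hypothetical: primals we deliberately skip

register_hook!(hook) = push!(HOOKS, hook)
announce_rule!(sig) = foreach(hook -> hook(sig), HOOKS)

# The hook receives a tuple type such as Tuple{typeof(sin), Number}.
function generate_overload_sketch(sig)
    params = collect(Base.unwrap_unionall(sig).parameters)
    primal_type, arg_types = params[1], params[2:end]
    primal_type in BLOCKED && return false   # filter: blocked primal
    isempty(arg_types) && return false       # filter: nothing to differentiate
    push!(OVERLOADED, sig)                   # real code would emit methods here
    return true
end

register_hook!(generate_overload_sketch)
announce_rule!(Tuple{typeof(sin), Number})
announce_rule!(Tuple{typeof(+), Number, Number})
```

After the two announcements only the `sin` signature is recorded; the `+` signature is dropped by the block list.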

It then uses ExprTools to get an AST for a function definition that would be suitable for overloading the primal function (as an overloading-based AD like Nabla does).

From that it generates overloads for that primal, with each argument in turn swapped out for the matching node type (this is why node_type was added to the code transformation functions).
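The argument-swapping step can be sketched in plain Julia as follows. This is a simplified illustration, not the PR's code: `Node` here is a toy wrapper, `unbox`, `mymul`, and `generate_node_overloads` are hypothetical names, and the generated methods simply rewrap the result instead of recording anything on a tape. It generates one method for every non-empty choice of argument positions to wrap.

```julia
struct Node{T}
    val::T
end

unbox(x) = x
unbox(x::Node) = x.val

# Hypothetical primal function with a concrete signature.
mymul(x::Float64, y::Float64) = x * y

function generate_node_overloads(f::Symbol, argtypes::Vector)
    n = length(argtypes)
    args = [Symbol(:x, i) for i in 1:n]
    calls = [:(unbox($a)) for a in args]
    # Each bit of `mask` says whether that argument position becomes a Node.
    for mask in 1:(2^n - 1)
        sig = map(1:n) do i
            T = isodd(mask >> (i - 1)) ? :(Node{<:$(argtypes[i])}) : argtypes[i]
            :($(args[i])::$T)
        end
        @eval $f($(sig...)) = Node($f($(calls...)))
    end
end

generate_node_overloads(:mymul, [:Float64, :Float64])

result = mymul(Node(2.0), 3.0)
```

Because the original `mymul(::Float64, ::Float64)` signature is never regenerated, the primal definition is left untouched and `mymul` ends up with four methods in total.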

An earlier version used unionise_type instead of swapping the argument out, but for things with a primal type of Any (which shows up for nondifferentiable_rule), this just resulted in Union{Node{Any}, Any}, which simplifies to Any. That meant we were overwriting the original primal definition, which would break everything.
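The collapse to Any is easy to reproduce in plain Julia. In this sketch `Node` is a toy wrapper, `unionise` is a simplified stand-in for the package's unionise_type, and `describe` is a hypothetical primal used only to show the overwrite.

```julia
struct Node{T}
    val::T
end

# Simplified stand-in for the package's unionise_type.
unionise(T) = Union{Node{T}, T}

collapses_to_any = unionise(Any) === Any                # a Union containing Any is just Any
stays_a_union = unionise(Float64) !== Float64           # concrete types keep the Union

# Hypothetical primal with an untyped (Any) argument.
describe(x) = "primal"

# A unionised overload has the same signature, so it replaces the primal method.
describe(x::unionise(Any)) = "overload"
```

After the second definition `describe` still has a single method, and calling it on any value returns "overload": the primal definition is gone.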

The key thing these generated primal overloads do is create a Branch that stores the pullback.

We then generate a method for preprocess which invokes that pullback, computing the partials for all the arguments. And we generate a method for ∇ that just takes the partials computed by preprocess and returns the right one for the specified Arg{N}.
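The Branch, preprocess, and Arg{N} flow can be sketched in plain Julia like this. The signatures are heavily simplified and are not Nabla's: `mul_rule` is a hypothetical hand-written rule in the style of an rrule (it omits the partial for the function itself), `tracked_mul` stands in for a generated primal overload, and naming the per-argument method `∇` is an assumption based on the `methods(∇)` calls later in this thread.

```julia
struct Arg{N} end

# Stores the primal value together with the pullback that produced it.
struct Branch{T, PB}
    val::T
    pullback::PB
end

# Hypothetical rule: returns the primal result and a pullback closure.
mul_rule(x, y) = (x * y, ybar -> (ybar * y, ybar * x))

# What a generated primal overload does: run the rule, keep the pullback.
function tracked_mul(x, y)
    val, pullback = mul_rule(x, y)
    return Branch(val, pullback)
end

# Invoke the pullback once, computing the partials for all arguments.
preprocess(branch::Branch, ybar) = branch.pullback(ybar)

# Pick out the partial for the N-th argument.
∇(::Type{Arg{N}}, partials::Tuple) where {N} = partials[N]

branch = tracked_mul(2.0, 3.0)
partials = preprocess(branch, 1.0)
dx = ∇(Arg{1}, partials)
dy = ∇(Arg{2}, partials)
```

The point of the split is that the pullback runs once in preprocess, and each Arg{N} lookup is then just an index into the already computed tuple.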

Things to do before Review

Things for reviewers to consider:

oxinabox commented 4 years ago

To do the in-place updates I would like to have https://github.com/JuliaDiff/ChainRulesCore.jl/issues/113#issuecomment-675010293

But I don't need it, since I can just overload update! for InplaceableThunk.

oxinabox commented 4 years ago

Needs https://github.com/invenia/ExprTools.jl/pull/12

oxinabox commented 4 years ago

Probably what this should do is look at the method table and check whether the simple unionized overload would eclipse any existing method in the wrong way (what is that? I need to think carefully). If not, use that. If so, use the version that generates all the combinatorial overloads.

Though maybe that check would take longer than the extra processing time to generate and load all of them.
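One possible shape for such a check, sketched in plain Julia under strong assumptions: it only detects the simplest bad case, where the unionised signature is identical to an existing method's signature and would therefore overwrite it. It does not attempt to define the more general "eclipsing in the wrong way". `Node`, `unionise`, `would_overwrite`, `ident`, and `half` are all hypothetical names.

```julia
struct Node{T}
    val::T
end

# Simplified unionised type for one argument.
unionise(T) = Union{T, Node{<:T}}

# True if defining the unionised overload would replace an existing method of `f`.
function would_overwrite(f, arg_types)
    unionised_sig = Tuple{typeof(f), map(unionise, arg_types)...}
    return any(m -> m.sig == unionised_sig, methods(f))
end

ident(x) = x                 # argument typed Any: the union collapses to Any
half(x::Float64) = x / 2     # concrete argument: the union stays distinct

clash_any = would_overwrite(ident, (Any,))
clash_concrete = would_overwrite(half, (Float64,))
```

Here `clash_any` is true and `clash_concrete` is false, so under this sketch only `ident` would need the combinatorial overloads.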

oxinabox commented 4 years ago

Effect on load time: this is definitely much slower to load. These timings are from a recent build of 1.6, but roughly match 1.5.

This PR

julia> @time @time using Nabla
 12.162363 seconds (17.81 M allocations: 1.070 GiB, 3.15% gc time, 46.90% compilation time)
 12.218878 seconds (17.95 M allocations: 1.079 GiB, 3.13% gc time, 47.14% compilation time)

julia> length(methods(∇))
717

julia> using SnoopCompileCore

julia> invalidations = @snoopr begin
       using Nabla
       end;

julia> using SnoopCompile

julia> length(uinvalidated(invalidations))
3267

Current release:

julia> @time @time using Nabla
  0.847823 seconds (1.01 M allocations: 66.097 MiB, 16.50% gc time, 61.89% compilation time)
  0.891765 seconds (1.13 M allocations: 73.762 MiB, 15.69% gc time, 63.64% compilation time)

julia> length(methods(∇))
376

julia> using SnoopCompileCore

julia> invalidations = @snoopr begin
       using Nabla
       end;

julia> using SnoopCompile

julia> length(uinvalidated(invalidations))
92

@keno you mentioned being interested in this.

oxinabox commented 3 years ago

Dropped a bunch of rule generation for nondifferentiable things, in particular for the ones that were causing invalidations.

Right now we don't really make good code for non-differentiable things anyway. We should generate totally different code that doesn't return a Branch but rather just returns the primal result. Right now there is no easy way to identify them, however: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/248

With those changes we are down to 168 invalidations (mostly not from Nabla) and startup time has improved to 9.8 s. That is better than 12.2 s but a far cry from the 0.8 seconds Nabla used to take.

oxinabox commented 3 years ago

While I remember, I should check that we are doing efficient things if the input is an AbstractZero.

oxinabox commented 3 years ago

I thought I was done, then I realized that I could stop it from erroring when new rules are added to ChainRules that are also still in Nabla, by making a list of all the rules that we still have and adding them to our block list.

I also realised the docs wouldn't build anymore, because we were using a version of Documenter that was so old that it wasn't compatible with Compat.jl 0.3 (which ChainRules uses).

Should be all sorted now.