compintell / Mooncake.jl

https://compintell.github.io/Mooncake.jl/
MIT License

Import more rules from `ChainRules`? #249

Open: yebai opened this issue 1 month ago

yebai commented 1 month ago

Tapir currently implements rules at a very low level. This design choice reduces the burden of writing and maintaining rules, because only a small number of primitives require manually written rules. The good news is that Tapir seems to perform excellently even though most of its rules are derived automatically by Tapir's autodiff transform (known as DerivedRule).

However, this design choice occasionally (it seems most cases are due to one common root cause, i.e. vectorisation of specific loops) gets in the way of Julia's and LLVM's compiler optimisation passes. Our CI benchmarks (e.g., sum, kron, kron_view_sum, and gp_lml) reflect this. Meanwhile, Zygote, which imports most (or all) rules from ChainRules, performs well in these test cases.

Given that ChainRules is pretty well tested, should we import a larger number of its rules into Tapir by default? If so, what specific rules should we import?

willtebbutt commented 1 month ago

I'm broadly in favour of selectively importing more rules. I'm working on the interface for importing rules from ChainRules.jl at the minute, but the broad picture will remain as it currently is: using methods of rrule to implement methods of rrule!! is fine, but we'll need to restrict the argument types to ones we are happy with, i.e. types for which we are confident the rule is correct. So, for example, if a method of rrule has the signature rrule(::typeof(*), ::AbstractMatrix, ::AbstractMatrix), the appropriate thing to do is not to define a method of Tapir.rrule!! with signature

Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{<:AbstractMatrix}, ::CoDual{<:AbstractMatrix})

Rather, we would define rules for concrete types, or (small) finite unions of concrete types, and call out to rrule inside. For example, we might define a method of rrule!! with argument types

Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{Matrix{P}}, ::CoDual{Matrix{P}}) where {P<:IEEEFloat}
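To make the restriction concrete, here is a minimal, self-contained sketch in plain Julia. It deliberately avoids Tapir's and ChainRules' real APIs: `SketchCoDual`, `cr_style_rrule` and `sketch_rrule!!` are hypothetical stand-ins for `Tapir.CoDual`, `ChainRules.rrule` and `Tapir.rrule!!`, and the pullback for `*` is written out by hand rather than imported. It shows the dispatch pattern only: the rule fires for `Matrix{P}` with `P<:Base.IEEEFloat` (Base's unexported alias for `Union{Float16,Float32,Float64}`), while any other `AbstractMatrix`, such as a `Diagonal`, falls through to a generic fallback, whose role in Tapir would be played by the automatically derived rule.

```julia
using LinearAlgebra  # matrix adjoint and Diagonal

# Hypothetical stand-in for Tapir's CoDual: a primal value plus a buffer
# that accumulates its reverse-mode gradient.
struct SketchCoDual{X,DX}
    x::X
    dx::DX
end

# Hypothetical stand-in for a ChainRules-style rrule for `*`: returns the
# primal output and a pullback that allocates fresh cotangent arrays.
function cr_style_rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    Y = A * B
    pullback(Ybar) = (nothing, Ybar * B', A' * Ybar)
    return Y, pullback
end

# Restricted rule: only dense matrices sharing one IEEE float element type.
# It calls out to the CR-style rule and accumulates the cotangents it
# returns into the gradient buffers that already exist.
function sketch_rrule!!(
    ::SketchCoDual{typeof(*)}, A::SketchCoDual{Matrix{P}}, B::SketchCoDual{Matrix{P}}
) where {P<:Base.IEEEFloat}
    Y, pb = cr_style_rrule(*, A.x, B.x)
    dY = zero(Y)  # gradient buffer for the output, seeded by the caller
    function rev!()
        _, Abar, Bbar = pb(dY)
        A.dx .+= Abar
        B.dx .+= Bbar
        return nothing
    end
    return SketchCoDual(Y, dY), rev!
end

# Fallback for every other signature (in Tapir: the derived rule).
sketch_rrule!!(f::SketchCoDual, args::SketchCoDual...) = :derived_rule
```

The design choice is that widening the signature is an explicit decision: supporting another concrete type means adding another method, rather than silently accepting every `AbstractMatrix` subtype. The sketch also shows the allocation cost of reusing a CR-style rule: `Abar` and `Bbar` are freshly allocated only to be added into buffers that already exist.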

> If so, what specific rules should we import?

I can think of two questions for any given signature:

  1. is a rule for this signature beneficial enough to warrant adding it?
  2. is making use of the rule in CR the best way to implement it?

On 1, we have to weigh the maintenance burden and the chance of making mistakes against the performance gain. On 2, for anything involving heap-allocated data, using a CR rrule will probably involve slightly more allocations than would be necessary if we wrote an rrule!! directly. Consequently, for really simple functions like sum, we might want to consider just writing our own rule.
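For `sum`, the allocation difference is easy to see in a plain-Julia sketch. All three functions are hypothetical illustrations, not Tapir's or ChainRules' actual rules, and the CR-style one ignores ChainRules' thunk machinery: `cr_style_sum_rrule` uses a pullback that returns a freshly allocated cotangent, `wrapped_sum_rrule!!` builds an rrule!!-style rule on top of it, and `direct_sum_rrule!!` adds the output cotangent straight into the gradient buffer the caller already owns.

```julia
# CR-style: the pullback materialises a new array for the cotangent of x.
function cr_style_sum_rrule(x::Vector{Float64})
    y = sum(x)
    sum_pullback(ybar::Float64) = (nothing, fill(ybar, length(x)))
    return y, sum_pullback
end

# Wrapping the CR-style rule: allocate a cotangent, then add it into dx.
function wrapped_sum_rrule!!(x::Vector{Float64}, dx::Vector{Float64})
    y, pb = cr_style_sum_rrule(x)
    function sum_reverse!(ybar::Float64)
        _, xbar = pb(ybar)  # allocates a length(x) vector
        dx .+= xbar
        return nothing
    end
    return y, sum_reverse!
end

# Direct rule: d(sum(x))/dx is 1 in every entry, so the reverse pass adds
# ybar to each element of the existing buffer in place; no new array.
function direct_sum_rrule!!(x::Vector{Float64}, dx::Vector{Float64})
    y = sum(x)
    function sum_reverse!(ybar::Float64)
        dx .+= ybar
        return nothing
    end
    return y, sum_reverse!
end
```

Both rules produce the same gradient; the direct one simply skips the intermediate vector, which is why a hand-written rule can be preferable for functions this simple.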

> (it seems most cases are due to one common root cause, i.e. vectorisation of specific loops) gets in the way of Julia's and LLVM's compiler optimisation passes

On a technical note, this isn't actually the whole story; it's just one facet. There are several more things that can be done at the Julia SSA IR level to improve performance in these cases that do not involve vectorisation (although they may affect how vectorisable the code is as a side effect). See https://github.com/compintell/Tapir.jl/issues/156