yebai opened this issue 1 month ago (status: open)
I'm broadly in favour of selectively importing more rules. I'm working on the interface for importing rules from ChainRules.jl at the minute, but the broad picture will remain as it currently is: using methods of rrule to implement methods of rrule!! is fine, but we'll need to restrict the argument types to ones that we are happy with. So, for example, if a method of rrule has the signature rrule(::typeof(*), ::AbstractMatrix, ::AbstractMatrix), the appropriate thing to do is not to define a method of Tapir.rrule!! with signature
Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{<:AbstractMatrix}, ::CoDual{<:AbstractMatrix})
Rather, we would define rules for concrete types, or (small) finite unions of concrete types, and call out to rrule inside them. For example, we might define a method of rrule!! with argument types
Tapir.rrule!!(::CoDual{typeof(*)}, ::CoDual{Matrix{P}}, ::CoDual{Matrix{P}}) where {P<:IEEEFloat}
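The dispatch pattern described above can be sketched in plain Julia. This is a minimal, self-contained illustration that does not use Tapir or ChainRules: ToyCoDual, toy_rrule!! and cr_style_rrule are hypothetical stand-ins for CoDual, rrule!! and a ChainRules rrule, and the pullback conventions are simplified. The point is only that the method restricted to Matrix{P} with P<:Base.IEEEFloat wins dispatch for dense float matrices, while anything else falls through to a generic fallback, which here stands in for Tapir's derived rules.

```julia
using LinearAlgebra  # adjoint (') on matrices; Diagonal in the usage example

# Hypothetical stand-in for Tapir's CoDual: a primal value paired with tangent storage.
struct ToyCoDual{P,T}
    primal::P
    tangent::T
end

# Hypothetical ChainRules-style rule for C = A * B: returns the output and a
# pullback mapping the output cotangent dC to (dA, dB).
function cr_style_rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    C = A * B
    pullback(dC) = (dC * B', A' * dC)
    return C, pullback
end

# Generic fallback: signatures we have not explicitly opted in to are left to
# the (hypothetical) derived-rule machinery, represented here by a tag.
toy_rrule!!(f::ToyCoDual, args::ToyCoDual...) = :derived

# Opt-in rule: only dense Matrix arguments sharing one IEEE float element type
# reach the imported ChainRules-style rule.
function toy_rrule!!(
    ::ToyCoDual{typeof(*)}, A::ToyCoDual{Matrix{P}}, B::ToyCoDual{Matrix{P}}
) where {P<:Base.IEEEFloat}
    C, pb = cr_style_rrule(*, A.primal, B.primal)
    function pullback!!(dC)
        dA, dB = pb(dC)
        A.tangent .+= dA  # accumulate into the existing tangent storage
        B.tangent .+= dB
        return nothing
    end
    return C, pullback!!
end
```

For example, calling toy_rrule!! on two Matrix{Float64} arguments returns the product and an accumulating pullback, whereas calling it on Diagonal or Matrix{Int} arguments returns :derived, so types the imported rule was never vetted for are not routed through it.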
> If so, what specific rules should we import?
I can think of two questions for any given signature:
On 1, we have to weigh the maintenance burden and the chance of making mistakes against performance. For 2: for anything involving heap-allocated data, making use of a CR rrule will probably involve slightly more allocations than would be necessary if we wrote an rrule!! directly. Consequently, for really simple functions like sum, we might want to consider just writing our own rule.
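To make the allocation point concrete, here is a hedged sketch in plain Julia with no Tapir or ChainRules dependency. cr_style_sum_rrule mimics a rule whose pullback returns a freshly allocated cotangent array; this is an assumption for illustration, not ChainRules' actual sum rule. inplace_sum_rrule!! is a hypothetical hand-written rule that accumulates into caller-provided gradient storage, and imported_sum_rrule!! wraps the CR-style rule behind the same in-place interface, roughly as an imported rule would be.

```julia
# Hypothetical ChainRules-style rule for sum: the pullback returns a new
# cotangent array for x, so every backward pass allocates length(x) elements.
function cr_style_sum_rrule(x::Vector{P}) where {P<:Base.IEEEFloat}
    y = sum(x)
    sum_pullback(dy) = fill(dy, size(x))  # d(sum(x))/dx_i = 1 for every i
    return y, sum_pullback
end

# Hypothetical hand-written rrule!!-style rule: the caller supplies gradient
# storage dx and the pullback accumulates into it, so no new array is created.
function inplace_sum_rrule!!(x::Vector{P}, dx::Vector{P}) where {P<:Base.IEEEFloat}
    y = sum(x)
    function sum_pullback!!(dy)
        dx .+= dy  # add dy to every entry of the gradient in place
        return nothing
    end
    return y, sum_pullback!!
end

# Importing the CR-style rule behind the same in-place interface: the
# intermediate array is still materialised before being added into dx.
function imported_sum_rrule!!(x::Vector{P}, dx::Vector{P}) where {P<:Base.IEEEFloat}
    y, pb = cr_style_sum_rrule(x)
    function sum_pullback!!(dy)
        dx .+= pb(dy)  # pb(dy) allocates a length(x) array first
        return nothing
    end
    return y, sum_pullback!!
end

# Bytes allocated by one backward pass, after a warm-up call to exclude compilation.
function backward_allocs(rule, x)
    dx = zero(x)
    _, pb!! = rule(x, dx)
    pb!!(one(eltype(x)))
    return @allocated pb!!(one(eltype(x)))
end
```

Both versions compute the same gradient. The imported version, however, has to build the intermediate array returned by the CR-style pullback, so backward_allocs should report allocations proportional to length(x) for it and none, or nearly none, for the hand-written rule; exact figures depend on the Julia version.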
> (it seems most cases are due to one common root cause, i.e. vectorisation of specific loops) gets in the way of Julia's and LLVM's compiler optimisation passes
On a technical note, this isn't actually the whole story; it's just one facet. There are several more things that can be done at the Julia SSA IR level to improve performance in these cases that do not involve vectorisation (although they may affect how vectorisable the code is as a side effect). See https://github.com/compintell/Tapir.jl/issues/156
Tapir implements rules at a very low level at the moment. This design choice reduces the burden of writing and maintaining rules, because only a small number of primitives require manually written rules. The good news is that Tapir seems to have excellent performance even though most of its rules are derived automatically by Tapir's autodiff transform (known as DerivedRule). However, this design choice occasionally (it seems most cases are due to one common root cause, i.e. vectorisation of specific loops) gets in the way of Julia's and LLVM's compiler optimisation passes. Our CI benchmarks (e.g., sum, kron, kron_view_sum, and gp_lml) reflect this. Meanwhile, Zygote, which imports most (or all) of its rules from ChainRules, performs well in these test cases.

Given that ChainRules is pretty well tested, should we import significantly more of its rules into Tapir by default? If so, what specific rules should we import?