compintell / Mooncake.jl

https://compintell.github.io/Mooncake.jl/
MIT License

Ongoing Limitations Issue #31

Open willtebbutt opened 1 year ago

willtebbutt commented 1 year ago

This issue is going to remain permanently open. It should not ever be closed. Its purpose is to track known test failure cases, and their causes.

Something can be added to this list only if there is a test in our test suite that cannot pass. That is, if you want to track something that doesn't work, a failing test case must be produced and merged into the test suite.

Distributions.jl

  1. Rmath.dnbeta, Rmath.dnchisq, Rmath.dnf, and Rmath.dnt all use ccalls for which no one has derived an rrule -- this is presumably because they involve infinite series, and it is unclear how to write down a rule for them.
  2. logpdf for Chernoff takes a really long time to run. It involves quadrature, so it presumably has a substantial inner loop, which is probably why it takes so long.
  3. logpdf for NormalInverseGamma involves the _besselk function, which contains a ccall for which we have no rrule. Unfortunately, the rrule for besselk returns a ChainRulesCore.NotImplemented w.r.t. its first argument. This means that we cannot simply wrap this rrule, because we would immediately increment!! the tangent, which will cause an error to be thrown by ChainRulesCore. In order to support this, we need to modify this framework to permit inactive arguments to rrule!!s, and to take advantage of this in practice.

Testing Array / AbstractArray functionality in Base and standard libraries

  1. \(::Symmetric, ::Matrix) calls the LAPACK routine sytrf via a ccall. This computes a Bunch-Kaufman factorisation of a symmetric matrix. Unfortunately, this factorisation is somewhat complicated, so writing a rule for it directly (which would be the ideal solution) looks likely to be somewhat involved. It might be that we need to write a rule at a slightly higher level of abstraction -- if we do this, we will need to be careful to restrict the permitted types, and figure out how to handle views.
  2. An invoke call in unique prevents unique from being differentiated, as invoke is not yet supported by Umlaut.
  3. The control flow in Base._unsafe_copyto! depends on the value of the pointers to the arrays passed in. This effectively means that functions which depend on it, such as complex(::Vector{Float64}), are non-deterministic, because the pointer to the memory allocated to store the result of this operation changes each time the function is run. Resolved by moving to the new way of tracing.
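
As a rough illustration of item 3, the sketch below shows how branching on pointer values makes control flow depend on where memory happens to be allocated. It is not the code of Base._unsafe_copyto!; copy_direction is a hypothetical helper introduced only for this example, and it assumes nothing beyond the pointer comparison described above.

```julia
# Hypothetical helper (not part of Base): chooses a branch by comparing the
# addresses of two arrays, loosely mirroring the kind of pointer-dependent
# control flow described for Base._unsafe_copyto!.
function copy_direction(dest::Vector{Float64}, src::Vector{Float64})
    # GC.@preserve keeps both arrays alive while their raw pointers are in use.
    GC.@preserve dest src begin
        pointer(dest) < pointer(src) ? :forward : :backward
    end
end

src = rand(4)
dest = similar(src)                     # fresh allocation; its address can change between runs
direction = copy_direction(dest, src)   # may be :forward or :backward
```

Because the branch taken depends on allocation addresses, a trace recorded on one run need not match the path taken on the next, which is the non-determinism the item describes.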

Misc Rules

Signatures of primals which need rules, but don't yet have them. Other packages (e.g. ChainRules) may have implementations.

  1. Tuple{typeof(Base.FastMath.pow_fast), Float64, Int}
  2. Tuple{typeof(Base.FastMath.rem_fast), Float64, Float64}

yebai commented 4 months ago

It appears these Rmath-related issues apply to all AD backends. https://github.com/TuringLang/DistributionsAD.jl/blob/df19bf48c380a749b5c71e6a63d24b3335a6e65c/test/ad/distributions.jl#L162

Cc @mhauru

acertain commented 3 months ago

Seems like Tapir doesn't support try/catch (I got an ERROR: Tapir.UnhandledLanguageFeatureException("Encountered UpsilonNode: ϒ (%153)")), maybe this should be mentioned in the docs?

willtebbutt commented 3 months ago

Yeah, this should probably be mentioned. Are you able to provide a MWE for this error? I'm aware that we don't handle UpsilonNodes, but I've not encountered them in many of the programmes that we've tried to differentiate.

edit: if I'm remembering correctly (we should definitely document this more carefully), you only really get UpsilonNodes if you're trying to use try / catch blocks to produce control flow. If you just use a try / catch block to catch an error and, e.g., throw a different one with more info about what's gone on internally, then you should be fine.

willtebbutt commented 3 months ago

My point being that, while we don't have complete support for all try / catch blocks, it's not that you can't use them at all.
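
A minimal sketch of the two usage patterns being contrasted, in plain Julia. The function names sqrt_or_zero and checked_sqrt are hypothetical and chosen only for illustration; whether either one actually lowers to UpsilonNodes on a given Julia version has not been checked here.

```julia
# Pattern A: try/catch used to produce control flow. The value returned
# depends on whether an exception was thrown.
function sqrt_or_zero(x::Float64)
    try
        return sqrt(x)    # throws a DomainError when x < 0
    catch
        return 0.0        # fallback reached only via the exception
    end
end

# Pattern B: try/catch used only to catch an error and throw a different
# one that carries more context.
function checked_sqrt(x::Float64)
    try
        return sqrt(x)
    catch err
        throw(ArgumentError("checked_sqrt failed for x = $x ($(typeof(err)))"))
    end
end
```

Pattern A is the kind of usage described as producing control flow; Pattern B is the catch-and-rethrow usage that is expected to be fine.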

acertain commented 3 months ago

MWE: @warn with string interpolation (Julia's logging macros expand to code that contains a try/catch block):

using Tapir, DifferentiationInterface

function f(x)
  @warn "x=$x"
  return x*x
end

DifferentiationInterface.gradient(f, AutoTapir(), 2)