tpapp / LogDensityProblemsAD.jl

AD backends for LogDensityProblems.jl.
MIT License

Support input-dependent ForwardDiff tags #18

Closed: devmotion closed this pull request 1 year ago

devmotion commented 1 year ago

By default, ForwardDiff tags depend on the type of the input (see, e.g., https://github.com/JuliaDiff/ForwardDiff.jl/blob/e3670ce9055c66863f655d2bac2d6615c165d838/src/config.jl#L120 and https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/). Custom tags typically depend on it as well (see, e.g., again https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ and https://github.com/TuringLang/Turing.jl/blob/e32bb71accb4d71dfc5f0377984a00aa5d643c3a/src/Turing.jl#L30-L36).
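To illustrate the default behaviour, here is a minimal sketch (assuming ForwardDiff is installed; f and tagtype are hypothetical names chosen for this example) that reads the tag type off a GradientConfig constructed without an explicit tag. The same function yields a different tag type for Float64 and Float32 inputs.

```julia
using ForwardDiff

# Hypothetical function to differentiate.
f(x) = sum(abs2, x)

# Hypothetical helper: the first type parameter of GradientConfig is its tag type.
tagtype(::ForwardDiff.GradientConfig{T}) where {T} = T

# Without an explicit tag, GradientConfig falls back to ForwardDiff.Tag(f, eltype(x)),
# so the tag type changes with the element type of the input.
T64 = tagtype(ForwardDiff.GradientConfig(f, [1.0, 2.0]))
T32 = tagtype(ForwardDiff.GradientConfig(f, Float32[1, 2]))

println(T64)
println(T32)
```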

Unfortunately, the current setup in LogDensityProblemsAD only allows specifying an exact ForwardDiff.Tag. Hence packages such as Turing have to compute the tag from the type of the input themselves and then forward it to ADgradient (see https://github.com/TuringLang/Turing.jl/blob/e32bb71accb4d71dfc5f0377984a00aa5d643c3a/src/essential/ad.jl#L96).

This PR makes the procedure more convenient for downstream packages: if the specified tag is not a ForwardDiff.Tag, gradient computations use ForwardDiff.Tag(tag, eltype(x)) instead, so the element type of the input x is taken into account automatically.
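The rule can be sketched as follows. This is a minimal illustration assuming ForwardDiff is installed, not the code of this PR; MyPkgTag, resolve_tag, gradient_with_tag and logdensity are hypothetical names. Dispatch on ForwardDiff.Tag decides whether the tag is used as-is or combined with eltype(x), and passing Val{false}() to ForwardDiff.gradient skips ForwardDiff's tag check, which would otherwise reject a tag that was not built from the differentiated function itself.

```julia
using ForwardDiff

# Hypothetical package-specific tag, standing in for a downstream package's own tag.
struct MyPkgTag end

# A ForwardDiff.Tag is used unchanged ...
resolve_tag(tag::ForwardDiff.Tag, x::AbstractVector) = tag
# ... any other object is combined with the element type of the input.
resolve_tag(tag, x::AbstractVector) = ForwardDiff.Tag(tag, eltype(x))

# Hypothetical helper: build a GradientConfig with the resolved tag and
# differentiate f at x. Val{false}() disables ForwardDiff's tag check, which
# would otherwise throw because the tag is not Tag{typeof(f), eltype(x)}.
function gradient_with_tag(f, x::AbstractVector, tag)
    cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(x), resolve_tag(tag, x))
    return ForwardDiff.gradient(f, x, cfg, Val{false}())
end

# Hypothetical log density of a standard normal, up to an additive constant.
logdensity(x) = -sum(abs2, x) / 2

# The same MyPkgTag() yields Tag{MyPkgTag, Float64} and Tag{MyPkgTag, Float32}.
g64 = gradient_with_tag(logdensity, [1.0, 2.0], MyPkgTag())
g32 = gradient_with_tag(logdensity, Float32[1, 2], MyPkgTag())
```

Disabling the check keeps the sketch short; a package with its own tag type could instead keep the check and add a ForwardDiff.checktag method that accepts its tag.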