Closed sunxd3 closed 4 months ago
cc @devmotion, @tpapp, @torfjelde, @yebai, @miguelbiron
It is unfortunate that the ADgradient constructor takes keywords, not structs, for the legacy interface, since if the gradient spec were all wrapped in a single container then we could just do
ADgradient(get_AD(ℓ), new_ℓ)
and just implement get_AD instead. If we wait for #29 then we could have that kind of API instead of replace_ℓ.
> If we wait for https://github.com/tpapp/LogDensityProblemsAD.jl/pull/29 then we could have that kind of API instead of replace_ℓ
It is cleaner. We can opt for it after the PR is merged.
But then we would have to change the interface again...
I am inclined to go with replace_ℓ for now, with a note saying that it is an experimental API at the moment and may change, so that the relevant PR in Turing can proceed.
But will wait to hear from @devmotion.
My impression from https://github.com/TuringLang/Turing.jl/pull/2231#issuecomment-2213698650 and related comments in Turing.jl was that there's no clear need for such an API currently? One reason for such an API would be a case where calling ADgradient from scratch would be less efficient than a dedicated replace_l function (BTW, IMO an official API, even if it is experimental, should probably use a non-Unicode name, potentially with a Unicode alias, although an alias seems a bit much for such simple functionality). But at least for the ReverseDiff example here there's no efficiency gain?
Regarding the implementation: couldn't we achieve this functionality by overloading setproperty!?
> It is unfortunate that the ADgradient constructor takes keywords, not structs, for the legacy interface, since if the gradient spec was all wrapped in a single container then we could just do
Does the ADTypes.jl extension not effectively solve this? Or are there some kwargs that are still missing from the ADTypes.jl structs?
I have a new proposal: add an interface function getADtype (or some other better name) and don't add the interface function this PR is trying to introduce. getADtype should return an ADTypes.ADType. Then packages can just use ADgradient with the ADType to create the wrapper.
EDIT: just realized this is exactly what @tpapp was suggesting 👍
The motivation is that I don't think replace_l would be enough. At least for ReverseDiff, one failure mode is that a tape compiled without specifying the input (i.e. kwarg x) can be incorrect for some inputs (something related to control flow maybe? @yebai). In that case, we really need the ability to call ADgradient with kwargs.
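To make the getADtype idea concrete, here is a minimal sketch in plain Julia. Every name in it (ToyBackend, ToyGradientWrapper, toy_adgradient, get_adtype, rebuild) is a hypothetical stand-in and not part of LogDensityProblemsAD.jl or ADTypes.jl. It only shows the shape of the API: the backend spec lives in one struct, so a caller can read it back and construct a fresh wrapper around a new log density.

```julia
# Hypothetical stand-ins only; none of these names exist in LogDensityProblemsAD.jl.

# Plays the role of an ADTypes-style backend spec: all options live in one struct.
struct ToyBackend
    name::Symbol
    compile::Bool
end

# Plays the role of an ADGradientWrapper: a log density plus the spec used to build it.
struct ToyGradientWrapper{L}
    backend::ToyBackend
    ℓ::L
end

# Plays the role of `ADgradient(adtype, ℓ)`.
toy_adgradient(backend::ToyBackend, ℓ) = ToyGradientWrapper(backend, ℓ)

# The proposed accessor: recover the backend spec from an existing wrapper.
get_adtype(∇ℓ::ToyGradientWrapper) = ∇ℓ.backend

# With the accessor, swapping the log density needs no dedicated `replace_ℓ`:
# the caller rebuilds the wrapper from scratch with the same spec.
rebuild(∇ℓ, new_ℓ) = toy_adgradient(get_adtype(∇ℓ), new_ℓ)

old = toy_adgradient(ToyBackend(:ReverseDiff, true), x -> -sum(abs2, x) / 2)
new = rebuild(old, x -> -sum(abs, x))
```

Because the wrapper is rebuilt rather than patched, anything derived from the old log density (such as a compiled tape) would be recreated for the new one in a real implementation, which is the correctness concern raised above.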
Sorry for the late responses, I am on holiday with limited net access.
@torfjelde: the problem is that not all of the API is using ADTypes.
@sunxd3: yes, the cleanest solution would be that, see my comment above. But we need to clean up the API first.
I am not sure how pressing the need for this solution is. We could introduce something interim that solves the problem for Turing, with the understanding that it is internal and would be removed once we solve this.
@tpapp understood your point now.
There are motivations to introduce such an API. Correctness of ReverseDiff's tape is one. Do we want to wait for #29 or maybe introduce something like getADtype, which is just one function that returns the ADType if supported?
> the problem is that not all the API is using ADtypes.
Gotcha 👍
> I am not sure how pressing is the need for this solution. We could introduce something interim that solves the problem for Turing, with the understanding that it is internal and would be removed once we solve this.
We have a work-around on our side, so I think it's less pressing atm
Ref https://github.com/tpapp/LogDensityProblemsAD.jl/issues/32#issuecomment-2212339696
Brief summary:
- replace_ℓ interface function
- ADgradient takes in an ADGradientWrapper, then recreates a new gradient wrapper with its log density function

I only added some implementations for ReverseDiff.
This is very much a draft right now; everything is open to change.