tpapp / LogDensityProblemsAD.jl

AD backends for LogDensityProblems.jl.
MIT License

Add interface functions to allow replacing the log density function and replacing AD wrapper type #33

Closed: sunxd3 closed this pull request 4 months ago

sunxd3 commented 4 months ago

Ref https://github.com/tpapp/LogDensityProblemsAD.jl/issues/32#issuecomment-2212339696

Brief summary:

So far, I have only added implementations for ReverseDiff.

This is very much a draft right now; everything is open to modification.

sunxd3 commented 4 months ago

cc @devmotion, @tpapp, @torfjelde, @yebai, @miguelbiron

tpapp commented 4 months ago

It is unfortunate that the ADgradient constructor takes keywords, not structs, for the legacy interface, since if the gradient spec was all wrapped in a single container then we could just do

ADgradient(get_AD(ℓ), new_ℓ)

and just implement get_AD instead. If we wait for #29 then we could have that kind of API instead of replace_ℓ.
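To make the point concrete, here is a toy sketch in plain Julia of why a single-container gradient spec would make a dedicated replace_ℓ unnecessary. Every name in it (GradientSpec, SpecGradient, make_gradient, get_AD, replace_logdensity) is a hypothetical stand-in, not the LogDensityProblemsAD API; get_AD merely echoes the name floated above.

```julia
# Toy sketch: every name below is hypothetical and is NOT the
# LogDensityProblemsAD API; it only models the shape of the idea.

# All gradient settings held in one value instead of separate keywords.
struct GradientSpec
    backend::Symbol
    compile::Bool
end

# A toy AD wrapper that keeps the spec next to the log density it wraps.
struct SpecGradient{L}
    spec::GradientSpec
    ℓ::L
end

# Stand-in for an ADgradient-style constructor that takes a spec value.
make_gradient(spec::GradientSpec, ℓ) = SpecGradient(spec, ℓ)

# The one accessor each wrapper type would need to provide.
get_AD(∇ℓ::SpecGradient) = ∇ℓ.spec

# Replacing the log density is then generic: rebuild from the stored spec.
replace_logdensity(∇ℓ, new_ℓ) = make_gradient(get_AD(∇ℓ), new_ℓ)

ℓ_old(x) = -sum(abs2, x) / 2
ℓ_new(x) = -sum(abs, x)

∇ℓ_old = make_gradient(GradientSpec(:ReverseDiff, true), ℓ_old)
∇ℓ_new = replace_logdensity(∇ℓ_old, ℓ_new)

@show get_AD(∇ℓ_new) == get_AD(∇ℓ_old)  # true: same backend settings
@show ∇ℓ_new.ℓ === ℓ_new                # true: new log density wrapped
```

The design choice being illustrated: once the spec is one value, each backend only has to expose it, and replacing the log density needs no backend-specific code.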

sunxd3 commented 4 months ago

If we wait for https://github.com/tpapp/LogDensityProblemsAD.jl/pull/29 then we could have that kind of API instead of replace_ℓ

It is cleaner. We can opt for it after the PR is merged.

tpapp commented 4 months ago

But then we would have to change the interface again...

I am inclined to go with replace_ℓ for now, with a note saying that it is an experimental API and may change, so that the relevant PR in Turing can proceed.

But I will wait to hear from @devmotion.

devmotion commented 4 months ago

My impression from https://github.com/TuringLang/Turing.jl/pull/2231#issuecomment-2213698650 and related comments in Turing.jl was that there is currently no clear need for such an API? One reason for such an API would be a case where calling ADgradient from scratch is less efficient than a dedicated replace_l function. But at least for the ReverseDiff example here, there is no efficiency gain? (BTW, IMO an official API, even an experimental one, should probably use a non-Unicode name, possibly with a Unicode alias, though an alias seems a bit much for such simple functionality.)

Regarding the implementation: Couldn't we achieve this functionality by overloading setproperty!?

torfjelde commented 4 months ago

It is unfortunate that the ADgradient constructor takes keywords, not structs, for the legacy interface, since if the gradient spec was all wrapped in a single container then we could just do

Does the ADTypes.jl extension not effectively solve this? Or are there some kwargs that are still missing from the ADTypes.jl structs?
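For comparison, a sketch of the two ways of constructing a ReverseDiff gradient wrapper: the legacy Symbol-plus-keywords form and the ADTypes form, where the backend and its settings are a single value that can be stored and passed to ADgradient again. The log density StdNormal2 is a hypothetical toy, and the sketch assumes a LogDensityProblemsAD version with the ADTypes extension.

```julia
using LogDensityProblems, LogDensityProblemsAD, ReverseDiff
using ADTypes: AutoReverseDiff

# Hypothetical toy log density: 2-d standard normal, up to a constant.
struct StdNormal2 end
LogDensityProblems.logdensity(::StdNormal2, x) = -(x[1] * x[1] + x[2] * x[2]) / 2
LogDensityProblems.dimension(::StdNormal2) = 2
LogDensityProblems.capabilities(::Type{StdNormal2}) = LogDensityProblems.LogDensityOrder{0}()

# Legacy interface: backend as a Symbol, settings as keyword arguments.
∇ℓ_legacy = ADgradient(:ReverseDiff, StdNormal2(); compile = Val(false))

# ADTypes interface: backend and settings form one value (`ad`),
# which can be kept around and reused to wrap another log density.
ad = AutoReverseDiff()
∇ℓ_adtypes = ADgradient(ad, StdNormal2())

x = [1.0, 2.0]
LogDensityProblems.logdensity_and_gradient(∇ℓ_legacy, x)   # (-2.5, [-1.0, -2.0])
LogDensityProblems.logdensity_and_gradient(∇ℓ_adtypes, x)  # (-2.5, [-1.0, -2.0])
```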

sunxd3 commented 4 months ago

I have a new proposal: add an interface function getADtype (or some better name) instead of the interface function this PR is trying to introduce. getADtype should return an ADTypes.ADType. Then packages can just call ADgradient with that ADType to create the wrapper.

EDIT: just realized this is exactly what @tpapp was suggesting 👍

The motivation is that I don't think replace_l would be enough. At least for ReverseDiff, one failure mode is that a tape compiled without specifying the input (i.e. the kwarg x) can be incorrect for some inputs (something related to control flow, maybe? @yebai). In that case, we really need the ability to call ADgradient with kwargs.
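A minimal reproduction of the control-flow failure mode, using ReverseDiff directly rather than through ADgradient. The branching function f is a hypothetical toy, not code from this PR: a tape records the operations along the branch taken at the recording input, so replaying it at an input that should take the other branch returns the first branch's gradient.

```julia
using ReverseDiff

# Hypothetical toy function whose branch depends on the input value.
f(x) = x[1] > 0 ? x[1] * x[1] + x[2] * x[2] : -(x[1] + x[2])

# Record the tape at an input that takes the first branch, then compile it.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [1.0, 2.0]))

x = [-1.0, 2.0]  # this input should take the second branch
g_tape = ReverseDiff.gradient!(similar(x), tape, x)  # replays the recorded branch
g_fresh = ReverseDiff.gradient(f, x)                 # records a new tape at x

println(g_tape)   # [-2.0, 4.0]: gradient of the first branch, frozen into the tape
println(g_fresh)  # [-1.0, -1.0]: correct gradient of -(x[1] + x[2])
```

As far as we understand the ReverseDiff backend in LogDensityProblemsAD, its compile and x keywords decide whether a tape is compiled and at which input it is recorded, which is why being able to pass them again matters here.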

tpapp commented 4 months ago

Sorry for the late responses, I am on holiday with limited net access.

@torfjelde: the problem is that not all the API is using ADtypes.

@sunxd3: yes, that would be the cleanest solution; see my comment above. But we need to clean up the API first.

I am not sure how pressing is the need for this solution. We could introduce something interim that solves the problem for Turing, with the understanding that it is internal and would be removed once we solve this.

sunxd3 commented 4 months ago

@tpapp I understand your point now.

There are good reasons to introduce such an API; the correctness of ReverseDiff's tape is one. Do we want to wait for #29, or introduce something like getADtype, a single function that returns the ADType if one is supported?

torfjelde commented 4 months ago

the problem is that not all the API is using ADtypes.

Gotcha 👍

I am not sure how pressing is the need for this solution. We could introduce something interim that solves the problem for Turing, with the understanding that it is internal and would be removed once we solve this.

We have a workaround on our side, so I think it's less pressing at the moment.