tpapp / LogDensityProblemsAD.jl

AD backends for LogDensityProblems.jl.
MIT License

Use DI for non-implemented ADTypes #39

Closed: gdalle closed this 2 weeks ago

gdalle commented 1 month ago

This PR adds a teeny tiny extension for DifferentiationInterface (#26). It can compute gradients for any ADTypes.AbstractADType that is not in the following list:

That way, your custom implementations remain the default, but for all other AD backends defined by ADTypes (and not symbols), DifferentiationInterface will kick in. This also allows you to gradually remove custom implementations in favor of DifferentiationInterface, if you so desire.
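The "custom implementations remain the default" behavior can come from ordinary Julia dispatch: a method on the abstract supertype catches every backend, and any more specific method that already exists is selected first. The following is a minimal stdlib-only sketch of that rule, not the PR's code. `AbstractBackend`, `HandWritten`, `Other`, and `gradient_route` are hypothetical stand-ins for `ADTypes.AbstractADType`, a backend with an existing extension, a backend without one, and `ADgradient`.

```julia
# Hypothetical stand-ins; the real extension dispatches on ADTypes.AbstractADType
abstract type AbstractBackend end
struct HandWritten <: AbstractBackend end  # backend with a custom implementation
struct Other <: AbstractBackend end        # backend without one

# Generic fallback: any backend without a dedicated method is routed to DI
gradient_route(::AbstractBackend, ℓ) = :DifferentiationInterface

# More specific method: Julia picks it over the fallback for HandWritten
gradient_route(::HandWritten, ℓ) = :custom_extension

ℓ = nothing  # placeholder for a log density problem
println(gradient_route(HandWritten(), ℓ))
println(gradient_route(Other(), ℓ))
```

Since the fallback is only ever less specific, adding it cannot change which method existing callers hit.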

Ping @willtebbutt @torfjelde @adrhill

Note: since DI imposes Enzyme v0.13 in the tests, it may require merging #38 first.

gdalle commented 1 month ago

@tpapp @devmotion thoughts? This is a strict addition of features, it does not modify any of the existing dispatches.

gdalle commented 1 month ago

I think the test errors are due to the breaking release of Enzyme (v0.13), which is why #38 might have higher priority.

tpapp commented 1 month ago

Thanks for the ping, I was a bit too busy last week to review this.

This looks like a very lightweight addition that at the same time enables the use of DifferentiationInterface (for all supported backends), which extends the functionality of the package, and in the long run also allows replacing existing backends with DI as the code matures.

Tests currently do not run, I think Enzyme compat needs to be broadened.

gdalle commented 1 month ago

No worries, thanks for the review, I'll take your remarks into account.

Tests currently do not run, I think Enzyme compat needs to be broadened.

Not possible: the Enzyme v0.13 change was very breaking, and DI cannot afford to support every past version, so I used their breaking change as an opportunity to tag mine as well. Perhaps as a temporary solution we could run the DI tests in another environment where Enzyme is not installed?

tpapp commented 1 month ago

@gdalle: note: we just merged #38.

gdalle commented 1 month ago

@willtebbutt does this clash with the Mooncake extension for LogDensityProblemsAD?

willtebbutt commented 1 month ago

I have no idea -- if it does, I'm more than happy to remove my extension and rely on the contents of this PR. Me having to look after less code is never a problem.

tpapp commented 1 month ago

@gdalle: I am wondering if a wrapper like

with_preparation(ADgradient(backend, ℓ), zeros(3))

could provide a reasonable API, without keywords. It would not even need a separate DIgradient struct: the existing one could default to prep = nothing, and the above would just replace it with a preparation computed from x.
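A minimal sketch of that wrapper idea, assuming a single gradient struct whose `prep` field defaults to `nothing`. `GradientWrapper`, `toy_prepare`, and `with_preparation` are hypothetical names here, and the "preparation" is just a preallocated buffer standing in for whatever a real backend would cache.

```julia
# Hypothetical gradient wrapper: `prep === nothing` means "unprepared"
struct GradientWrapper{B,L,P}
    backend::B
    ℓ::L
    prep::P
end

# Default constructor path: no preparation yet
GradientWrapper(backend, ℓ) = GradientWrapper(backend, ℓ, nothing)

# Stand-in for backend-specific preparation: allocate a buffer shaped like x
toy_prepare(backend, ℓ, x::AbstractVector) = similar(x)

# Return a new (immutable) wrapper whose `prep` is computed from a typical input x
with_preparation(∇ℓ::GradientWrapper, x::AbstractVector) =
    GradientWrapper(∇ℓ.backend, ∇ℓ.ℓ, toy_prepare(∇ℓ.backend, ∇ℓ.ℓ, x))

∇ℓ = GradientWrapper(:some_backend, sum)
∇ℓ_prepared = with_preparation(∇ℓ, zeros(3))
```

Because the struct is immutable, `with_preparation` returns a new object rather than mutating the old one; the prepared and unprepared versions differ only in the `P` type parameter.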

gdalle commented 1 month ago

My idea here was to mimic the existing API as closely as possible. Some constructors of ADgradient using symbols can also take an x as an optional keyword argument: https://github.com/tpapp/LogDensityProblemsAD.jl/blob/2ce49ce6705bbf35e46ee328f793b9eaaf78546c/ext/LogDensityProblemsADForwardDiffExt.jl#L96-L99 and https://github.com/tpapp/LogDensityProblemsAD.jl/blob/2ce49ce6705bbf35e46ee328f793b9eaaf78546c/ext/LogDensityProblemsADReverseDiffExt.jl#L45-L47. They also take other kwargs like config or compile information, but with ADTypes this is stored in the backend object itself, so we no longer need to pass it.

gdalle commented 1 month ago

Tests pass locally

gdalle commented 1 month ago

@devmotion @tpapp is this better with the latest changes?

tpapp commented 1 month ago

@gdalle: Thanks for the recent updates. I understand and appreciate that you want to keep the API consistent with the existing one.

However, that API predates the AD-unification libraries (like DI) and is not particularly well designed because it does not reify the AD process. Specifically, I now believe that the ideal API would be something like

ADgradient(how, P)

where P is a ℝⁿ→ℝ function and how contains all information on how to perform AD.

In contrast, currently we have

ADgradient(how_backend, P; how_details...)

and your PR (in its current state) extends the existing API in this direction.

In fact, DI does not reify how either: if you want preparation you do it via one of the API functions.

So there are three questions:

  1. do we want to keep the existing API ADgradient(how_backend, P; how_details...), either in the short run or forever,
  2. if not, do we take this opportunity to improve it,
  3. can DI provide an API that reifies how in a way that makes sense (I am assuming this is possible, please correct me if it is not).

I appreciate your work on DI and your PRs here, and please understand that I am not pushing back on changes. I think DI is a great idea, but I want to do it right so that this package breaks its own API the fewest times possible (eventually, I want to encourage users to move on to the new API, and deprecate & remove the existing one).

@devmotion, what do you think?

gdalle commented 1 month ago

Thanks for your kind answer @tpapp, and for your work on this part of the ecosystem.

  1. do we want to keep the existing API ADgradient(how_backend, P; how_details...), either in the short run or forever
  2. if not, do we take this opportunity to improve it,

In my view, the AD extensions of LogDensityProblemsAD filled a crucial void when DI did not exist. Now that DI is in a pretty good state, I don't know if this ADgradient API will remain necessary for much longer. Thus, my proposal was a minimally invasive insertion, designed to encourage gradual pivots to DI in the future without needing breaking changes here or in Turing. Perhaps someday, when DI is truly ready, we won't even need LogDensityProblemsAD at all?

Of course, to get peak performance or avoid some bugs, you still want to tune the bindings for every backend. But if every Julia package does that separately, it is a huge waste of time and LOCs. My hope is that this tuning can be done in a single place and fit 99% of use cases, which is what DI is for. I'm always open to suggestions for performance or design improvements. Besides, the case that we are tackling here (gradient of array-input function with constant contexts) is exactly the case where we can be extremely performant with DI, which makes it a prime candidate for the switch.

  3. can DI provide an API that reifies how in a way that makes sense (I am assuming this is possible, please correct me if it is not).

The DI interface with preparation looks like this:

gradient(f, prep, backend, x, contexts...)

where backend is an object from ADTypes.jl and prep is the result of

prepare_gradient(f, backend, typical_x, typical_contexts...)

In your terms:

This shows that there are two sides to the how, and I think it makes sense to distinguish them.
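For concreteness, here is a sketch of the two DI calls quoted above, assuming DifferentiationInterface v0.6 or later (where contexts such as `Constant` are available) with ForwardDiff installed. The Gaussian log density `logdens` and its scale argument are illustrative only, not part of either package.

```julia
using DifferentiationInterface  # prepare_gradient, gradient, Constant, AutoForwardDiff
import ForwardDiff              # loading it activates DI's ForwardDiff support

# Illustrative log density (up to a constant): isotropic Gaussian with scale σ
logdens(x, σ) = -sum(abs2, x) / (2 * σ^2)

backend = AutoForwardDiff()  # problem-independent side of "how"

# Problem-dependent side of "how": prepared once from a typical input and context
prep = prepare_gradient(logdens, backend, zeros(3), Constant(1.0))

# Reuse `prep` for inputs of the same type and size
x = [1.0, 2.0, 3.0]
g = gradient(logdens, prep, backend, x, Constant(1.0))  # analytically -x / σ^2
```

The backend object can be created without knowing anything about `logdens`, while `prep` cannot.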

gdalle commented 1 month ago

So where do you wanna go from here?

tpapp commented 1 month ago

Perhaps someday, when DI is truly ready, we won't even need LogDensityProblemsAD at all?

Possibly, but that is speculation. At the moment, there is no generic AD wrapper interface that provides what this package does. Preparation, as you explained above, is one example.

So where do you wanna go from here?

I want to reflect a bit on this, and also hear comments from the users.

Currently I am leaning towards cleaning up the interface the following way:

  1. ADgradient(how, P) where how encapsulates everything we need for AD,

  2. each backend gets a constructor that replaces the current Val{symbol} and backend::Symbol API. This constructor takes keywords and whatever else is needed.

We could of course merge your PR as is, then later deprecate this.

gdalle commented 1 month ago

At the moment, there is no generic AD wrapper interface that provides what this package does.

Well, I would love for DI to provide this. What do you think is missing then?

We could of course merge your PR as is, then later deprecate this.

The idea of this PR was to be minimally invasive, so that you can gradually drop extensions in favor of a better-maintained and tested DI implementation. Therefore, I think it is a good idea to merge it before a possible breaking revamp of LogDensityProblemsAD, especially if you want to use more of DI in the revamp.

tpapp commented 1 month ago

What do you think is missing then?

A way to pack everything in how (including prep, and whatever is needed), as explained above.

I think it is a good idea to merge it before a possible breaking revamp of LogDensityProblemsAD

As I said above, that is a possibility I am considering. I will wait for thoughts from @devmotion.

gdalle commented 1 month ago

A way to pack everything in how (including prep, and whatever is needed), as explained above.

If you want to use only DI, this is as simple as something like

struct How{B,P}
    backend::B
    prep::P
end

But if you want to also adapt this to your existing extensions, then of course it's a bit more work. I'll let you weigh the pros and cons.
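A sketch of how such a `How` object could be filled and used with DI alone, again assuming DifferentiationInterface v0.6+ and ForwardDiff. The struct is the one from the comment; `make_how` and `how_gradient` are hypothetical helper names.

```julia
using DifferentiationInterface
import ForwardDiff

# Everything needed to differentiate: the backend plus its preparation result
struct How{B,P}
    backend::B
    prep::P
end

# Hypothetical helper: prepare once from a typical input
make_how(f, backend, typical_x) = How(backend, prepare_gradient(f, backend, typical_x))

# Hypothetical helper: compute a gradient using only the packed information
how_gradient(f, how::How, x) = gradient(f, how.prep, how.backend, x)

f(x) = -sum(abs2, x) / 2
how = make_how(f, AutoForwardDiff(), zeros(2))
how_gradient(f, how, [3.0, 4.0])
```

Note that DI preparation is computed for a specific function, so a `How` built this way is tied to that function and its input size.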

devmotion commented 1 month ago

Hmm... I think conceptually keeping backend and prep separated feels a bit cleaner to me. There's information about the desired AD backend that is independent from the log density problem, its dimension etc (e.g., I want to use Enzyme + reverse mode) and there's information that depends on the problem at hand (e.g., type and length of the input to the log density function, the function itself). Having them separate makes it easier to pass the problem-independent settings around and reuse them for other log-density problems. For instance, in Turing a user might want to specify the AD backend to be used by the sampler, but at that time point (when constructing the sampler) the actual log-density problem is not created yet (that only happens internally in Turing).

What I dislike about the Val interface is that it does not allow passing around any additional information apart from the AD backend, and hence the keyword arguments contain both problem-independent information (like the Enzyme mode or ForwardDiff chunk size or tags) and problem-dependent information (like a typical input).

I think deprecating or directly replacing the Val interface with the ADTypes interface would resolve this issue. Everything that's problem-independent you could store and reuse by specifying the ADType, and problem-dependent settings such as typical inputs you could specify with keyword arguments.
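The separation described above can be sketched without any AD package: a sampler stores only the problem-independent backend choice, and preparation happens once the log density exists. Everything here is hypothetical (`CentralDifferences`, `Sampler`, `prepare`, `grad!`, `run_sampler`), and central finite differences stand in for a real AD backend purely to keep the example self-contained.

```julia
# Problem-independent settings: chosen by the user, reusable across problems
struct CentralDifferences
    h::Float64  # finite-difference step size
end

struct Sampler{B}
    backend::B
end

# Problem-dependent preparation: buffers sized from a typical input
prepare(::CentralDifferences, x::AbstractVector) = (xh = similar(x), g = similar(x))

# Gradient by central differences, reusing the prepared buffers
function grad!(prep, b::CentralDifferences, f, x)
    copyto!(prep.xh, x)
    for i in eachindex(x)
        prep.xh[i] = x[i] + b.h
        fp = f(prep.xh)
        prep.xh[i] = x[i] - b.h
        fm = f(prep.xh)
        prep.xh[i] = x[i]  # restore the coordinate
        prep.g[i] = (fp - fm) / (2 * b.h)
    end
    return prep.g
end

# Internally, the problem (here a function and a starting point) appears only now
function run_sampler(s::Sampler, f, x0)
    prep = prepare(s.backend, x0)
    return copy(grad!(prep, s.backend, f, x0))
end

sampler = Sampler(CentralDifferences(1e-6))  # built before any problem exists
g1 = run_sampler(sampler, x -> -sum(abs2, x) / 2, [1.0, 2.0])
g2 = run_sampler(sampler, x -> -sum(abs, x), [1.0, -1.0, 2.0])  # same settings, new problem
```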

gdalle commented 1 month ago

@devmotion I agree that ADTypes are overall more expressive than symbols, which is why they were introduced. But even deprecating the Val API won't solve the heterogeneity between backends. Currently, you need to pass different keyword arguments depending on which backend you want to use (shadow for forward Enzyme, chunks for ForwardDiff, etc.). The appeal of DI is to perform this preparation in the same way everywhere, so that the user can just pass x and switch backends transparently while preserving performance.

gdalle commented 1 month ago

In my previous attempt #29, the main obstacles to full DI adoption were

The first one has been resolved, the second one is much more a Tracker issue than a DI one. @tpapp concluded his review of #29 by saying (emphasis mine)

Yes, this package by necessity and historical reasons duplicates a lot of functionality in an abstract AD metapackage. This was made much easier by the fact that we only care about R^n → R functions. But the code is already there and in most cases it works fine.

Sure, your own AD interface has already been written, but it still needs to be updated whenever any backend changes (e.g. #37 and #38 for the latest Enzyme). Since DI is to become the standard (already used in Optimization.jl, NonlinearSolve.jl and more), it will remain actively maintained and react to evolutions of the ecosystem (like the new Mooncake.jl package). The way things work at the moment, you also need to perform the same adaptations in parallel, or give up on the latest features, both of which are a bit of a waste.

devmotion commented 1 month ago

But even deprecating the Val API won't solve the heterogeneity between backends. Currently, you need to pass different keyword arguments depending on which backend you want to use (shadow for forward Enzyme, chunks for ForwardDiff, etc.).

I think it would. I think the only keyword left should be a typical input x. The other options already belong to the ADTypes: mode for Enzyme, fdm for FiniteDifferences, tag and chunk for ForwardDiff, and compile for ReverseDiff. shadow is a bit strange, but I don't think anyone has ever used it, and it could be constructed from the typical x, so I think it should be removed.
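With ADTypes, the backend-specific options listed above live in the backend object rather than in keywords. A sketch assuming ADTypes v1, where `AutoForwardDiff` accepts `chunksize` and `tag` and `AutoReverseDiff` accepts `compile`; the Enzyme and FiniteDifferences variants need those packages loaded, so they appear only as comments.

```julia
using ADTypes  # AutoForwardDiff, AutoReverseDiff, AbstractADType, mode

# ForwardDiff: chunk size (and optionally a tag) are part of the backend object
fd = AutoForwardDiff(; chunksize = 4)

# ReverseDiff: tape compilation is part of the backend object
rd = AutoReverseDiff(; compile = true)

# With Enzyme / FiniteDifferences loaded, the analogous constructors would be e.g.
#   AutoEnzyme(; mode = Enzyme.Reverse)
#   AutoFiniteDifferences(; fdm = FiniteDifferences.central_fdm(5, 1))

# Only a typical input x would remain as a problem-dependent keyword
for backend in (fd, rd)
    println(typeof(backend), " => ", ADTypes.mode(backend))
end
```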

gdalle commented 1 month ago

Yes you're right, shadow was the only example in the category of "not backend, not x".

So if Tamas agrees, I guess the question is whether you want to deprecate Vals by switching directly to DI, or first deprecate them on your own.

tpapp commented 1 month ago

@gdalle: I would prefer to do it like this:

  1. add an ADgradient method that is implemented via DI. It should not dispatch on ADTypes though; the user should indicate that they want DI specifically. It is my expectation that in the long run, calling ADgradient on ADTypes directly will dispatch to this method, but I want to keep this level of indirection. We can work out the syntax, suggestions welcome.

  2. once that is in place, make the current Val{} methods forward to it everywhere it is applicable, after careful examination of each case. This would remove a lot of redundant code from this package and make it easier to maintain, as you suggest.

@devmotion:

I think it would. I think the only keyword left should be a typical input x.

So the only use case for this is preparation? I will need some time to look into DI code to see what it does exactly: does it need a type (like a Vector or SVector, does the distinction matter), or a "typical" value, or something else? I am asking because LogDensityProblems can supply some of that, i.e. problems know their input length.

I need some time to read up on this, I will be away from my computer for the weekend but I will get back to this topic on Tuesday.

Suggestions welcome. @gdalle, I appreciate your work on DI a lot and want to move forward with this, but I need to understand the details so that we can make a smooth transition, and for that I need time.

I expect that this package is not fully replaceable by DI, as it does a few extra things (again, a "problem" defined through this API knows about its dimension and AD capabilities), but I agree that we should remove redundancies.

gdalle commented 1 month ago

add an ADgradient method that is implemented via DI. It should not dispatch on ADTypes though; the user should indicate that they want DI specifically. It is my expectation that in the long run, calling ADgradient on ADTypes directly will dispatch to this method, but I want to keep this level of indirection. We can work out the syntax, suggestions welcome.

Fair enough! How about the following toggle?

ADgradient(backend::AbstractADType, l, ::Val{DI}=Val(false); kwargs...) where {DI}
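A stdlib-only sketch of how such a positional `Val` toggle could route calls, written out as the two methods that the default argument expands to. `AbstractBackend`, `SomeBackend`, and `route` are hypothetical stand-ins for `AbstractADType`, a concrete ADTypes backend, and `ADgradient`.

```julia
abstract type AbstractBackend end  # stand-in for ADTypes.AbstractADType
struct SomeBackend <: AbstractBackend end

# Without the toggle: default to the package's own implementation
route(b::AbstractBackend, ℓ; kwargs...) = route(b, ℓ, Val(false); kwargs...)

# DI is a type parameter, known for each compiled specialization
route(::AbstractBackend, ℓ, ::Val{DI}; kwargs...) where {DI} =
    DI ? :DifferentiationInterface : :native_implementation
```

Usage: `route(SomeBackend(), ℓ)` keeps the existing path, while `route(SomeBackend(), ℓ, Val(true))` opts into DI explicitly, which keeps the level of indirection requested above.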

So the only use case for this is preparation?

DI's design is interesting because preparation is effectively unlimited. We can put whatever we want in the prep object, as long as it speeds up gradient computations on similar inputs down the road. So we only need this one "use case" to cover everything the backends do: ForwardDiff configs, ReverseDiff tapes, FiniteDiff caches, Enzyme duplicated buffers, and so on. See examples in the DI tutorial.

does it need a type (like a Vector or SVector, does the distinction matter), or a "typical" value, or something else?

It needs an actual value, because things like the size of the vector are also important (and they are usually not part of the type). You can read more about the preparation system in the DI docs.
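A short stdlib illustration of why a type is not enough: for a plain `Vector`, the length is a runtime property rather than part of the type, while preallocated buffers depend on it. (Fixed-size types such as `SVector` from StaticArrays.jl do encode the length, but LogDensityProblems accepts any `AbstractVector`.)

```julia
x3, x5 = zeros(3), zeros(5)

# Same type, different sizes: the type alone cannot tell a backend how much to allocate
println(typeof(x3) == typeof(x5))  # true
println(typeof(x3))                # Vector{Float64}

# A typical value carries the size, so buffers can be allocated from it
buffer = similar(x3)
println(length(buffer))            # 3
```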

tpapp commented 1 month ago

@gdalle: I have read the DI docs and skimmed the source code. First, kudos on trying to organize all AD approaches into a coherent interface; it is a huge undertaking but should have a large payoff for the ecosystem in the long run.

I have some preliminary thoughts regarding the interface of LogDensityProblems and DI.

First, in LogDensityProblems, the interface is focused on being functional:

  1. the argument can be assumed to have no state (unless explicitly requested, cf #8),

  2. it can be called with arbitrary xs as long as they are AbstractVector{<:Real}, and the implementations have complete freedom. Calls are not enforced to be consistent, you can call it one moment with a Vector{Float64}, then an SVector{3,Float32}, etc (cf #3).

The interface has no API to handle situations when the caller promises to use the same argument types, or values, in exchange for a potential speed benefit.

I am entertaining the idea that we should expose "preparation" in the API (as defined in the main interface package, LogDensityProblems.jl), where the caller promises to call the problem with the same argument type over and over, in exchange for speedups and possibly preallocated buffers. The API should allow querying the argument type above and whether the object is mutable (thread safety).

Once we implement that, we can flesh out the AD interface using DI and that API. That is to say, preparation would not be exposed via DI, but our own API that forwards to DI.
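One way to picture this idea, as a stdlib-only sketch with entirely hypothetical names (`PreparedProblem`, `prepare_problem`, `argument_type`, `logdensity_prepared`); none of them exist in LogDensityProblems.jl. The caller promises a fixed argument type in exchange for cached state, and the wrapper records that promise so it can be queried and checked.

```julia
# Hypothetical: a problem prepared for one argument type T, with cached state
struct PreparedProblem{T<:AbstractVector{<:Real},L,C}
    ℓ::L      # the original log density (any callable here)
    cache::C  # mutable buffer, so this object is not safe to share across threads
end

# The caller promises to call with arguments of the same type as `x`
function prepare_problem(ℓ, x::AbstractVector{<:Real})
    cache = similar(x)
    return PreparedProblem{typeof(x),typeof(ℓ),typeof(cache)}(ℓ, cache)
end

# Queryable: which argument type was promised
argument_type(::PreparedProblem{T}) where {T} = T

# Evaluation enforces the promise instead of silently accepting other types
function logdensity_prepared(p::PreparedProblem{T}, x) where {T}
    x isa T || throw(ArgumentError("expected $(T), got $(typeof(x))"))
    return p.ℓ(x)
end

p = prepare_problem(x -> -sum(abs2, x) / 2, zeros(3))
logdensity_prepared(p, [1.0, 2.0, 2.0])  # -4.5
```

An unprepared problem would keep accepting any `AbstractVector{<:Real}`; only the prepared wrapper narrows the contract.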

I am still thinking about the details, but this is the general direction I am considering; I also need to understand sparse coloring and its relation to preparation.

gdalle commented 1 month ago

Calls are not enforced to be consistent,

This is a big difference indeed, and I understand why you would want to change your interface to accommodate it. Note that, at the moment, some backends already perform preparation when you pass x, so I'm not sure what actually happens when you change the input type?

I also need to understand sparse coloring and its relation to preparation.

Coloring is not relevant for gradients because a gradient is always dense (unless some inputs are useless) and can be computed in O(1) function-call equivalents with reverse mode. Sparse AD is only useful when matrices are returned (Jacobians and Hessians).
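Spelled out for the f: ℝⁿ → ℝ case discussed here (a standard result; the exact constant ω depends on the cost model, so it is left symbolic):

```latex
% f : \mathbb{R}^n \to \mathbb{R}, so the Jacobian is a single row
J_f(x) = \nabla f(x)^{\top} \in \mathbb{R}^{1 \times n}

% one reverse-mode pullback with seed \bar{y} = 1 yields the full gradient
\nabla f(x) = J_f(x)^{\top} \, \bar{y}, \qquad \bar{y} = 1

% hence (cheap gradient principle), for a small constant \omega independent of n
\operatorname{cost}\big(\nabla f(x)\big) \le \omega \cdot \operatorname{cost}\big(f(x)\big)
```

Coloring compresses many Jacobian rows or columns into fewer passes; with a single dense row there is nothing left to compress.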

gdalle commented 2 weeks ago

Sorry I merged the review suggestions without checking typos, fixed now. The tests pass locally.

tpapp commented 2 weeks ago

@devmotion, @gdalle: would a minor version bump be OK for this? After all, we just add new features, even though the change is extensive.

devmotion commented 2 weeks ago

Yes, I think a minor release is appropriate here.