gdalle / DifferentiationInterface.jl

An interface to various automatic differentiation backends in Julia.
https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface
MIT License

Make Enzyme dispatches compatible with closures #339

Closed. ChrisRackauckas closed this issue 1 month ago.

ChrisRackauckas commented 3 months ago

In the Enzyme setups https://github.com/gdalle/DifferentiationInterface.jl/blob/main/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L13 it looks like you're using the raw f. This omits the handling of any memory associated with caches, particularly within closures. The fix is rather straightforward though: you can just copy SciMLSensitivity. You wrap the f in a Duplicated https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/derivative_wrappers.jl#L697 where the shadow is just a copy f_cache = Enzyme.make_zero(f). To make this safe for repeated application, you need to add a call to Enzyme.make_zero!(f_cache) so its duplicated values are always zero when you reuse it.
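
For reference, a minimal sketch of that pattern, calling Enzyme directly. The toy closure and the helper name gradient_with_closure! are illustrative assumptions, not DI's actual code:

using Enzyme

# A closure over a preallocated cache that it both writes and reads.
cache = zeros(3)
f = let cache = cache
    x -> begin
        @. cache = x^2      # write into the enclosed cache
        sum(cache)          # read it back
    end
end

f_cache = Enzyme.make_zero(f)            # zeroed copy of f, including its captured cache

function gradient_with_closure!(dx, f, f_cache, x)
    Enzyme.make_zero!(f_cache)           # reset the shadow so repeated calls stay correct
    fill!(dx, 0)
    Enzyme.autodiff(Enzyme.Reverse,
                    Enzyme.Duplicated(f, f_cache),   # carry derivatives through the enclosed cache
                    Enzyme.Active,                   # scalar return is active
                    Enzyme.Duplicated(x, dx))
    return dx
end

x = [1.0, 2.0, 3.0]
dx = similar(x)
gradient_with_closure!(dx, f, f_cache, x)   # expected: 2 .* x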

gdalle commented 3 months ago

@wsmoses said this was probably a bad idea due to performance degradation, so I'm leaving PR #341 closed for now. Are there other solutions?

ChrisRackauckas commented 3 months ago

Well, the other option is incorrectness, or just erroring if caches are used; I don't see how that's better.

wsmoses commented 3 months ago

I mean, honestly, this is where activity info / multi-arg support is critical.

If you have a closure (which is required by DI atm), then you'll end up differentiating every var in the original fn. So if you have something like

NN = complex neural network
DI.gradient(x -> NN() + x, AutoEnzyme(), 3.1)

you'll now be forced to AD the entire neural network as opposed to the one scalar. In this case an O(1) derivative becomes unboundedly more expensive. Without the ability to handle multiple args/activity, DI would be forced to AD through the whole NN if the closure were marked active.
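
As a hedged illustration of that cost (a toy stand-in for the neural network, with made-up sizes and names):

using Enzyme

# A closure capturing a large "model", playing the role of `x -> NN() + x` above.
W = randn(1000, 1000)
g = let W = W
    x -> sum(W) + x
end

# If the closure must be marked Duplicated, Enzyme also has to allocate and propagate
# a shadow for all of W, even though only d/dx (a single scalar, equal to 1.0) is wanted:
dg = Enzyme.make_zero(g)     # an extra 1000x1000 zero matrix, plus the work to fill it
Enzyme.autodiff(Reverse, Duplicated(g, dg), Active, Active(3.1))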

wsmoses commented 3 months ago

Frankly, this is where I'd say it makes sense for DI to figure out how it and/or AD.jl wants to handle multiple args, and for now use direct Enzyme autodiff calls, which don't have such limitations, revisiting this question later.

gdalle commented 3 months ago

I'm slowly getting a clearer picture of how I can pull it off. But the initial plan was for AbstractDifferentiation to handle multiple arguments, so I wanna wait for @mohamed82008's approval before I dive into it within DI.

ChrisRackauckas commented 3 months ago

Even if DI handles multiple arguments, though, you'd still want to duplicate the function: if enclosed caches aren't handled correctly you can get incorrect derivatives, so I don't see why this would wait. Indeed, the downside is that you always have to assume that all caches can be differentiated, which is then a good reason to allow multiple arguments so you can Const some of them. But my point is that if we want DI to actually be correct, then we do need to ensure that differentiating enclosed variables carries forward their derivative values.

ChrisRackauckas commented 3 months ago

It at least needs to be an option: AutoEnzyme(duplicate_function = Val(true)) by default, but Val(false) as an optimization if someone wants to forcibly Const all enclosed values (at their own risk). If someone has no enclosed values there's no overhead, and if the enclosed values are non-const then the default is correct, so the toggle is purely a performance optimization and I'd leave it to the user. Adding that to ADTypes would be good for SciMLSensitivity as well, since we'd do the same in our implementation.
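
A hedged sketch of what honoring such a toggle could look like on the backend side. Both the keyword duplicate_function (taken from this comment) and the helper annotate_function are hypothetical, not actual ADTypes/DI API:

using Enzyme

# Hypothetical helper: pick the activity annotation for the function itself.
function annotate_function(f; duplicate_function::Bool = true)
    if duplicate_function
        Enzyme.Duplicated(f, Enzyme.make_zero(f))   # safe default: carries derivatives of enclosed caches
    else
        Enzyme.Const(f)                             # opt-out: user asserts f holds no differentiable state
    end
end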

wsmoses commented 3 months ago

My point about support for multiple arguments and/or activity is that they would potentially remedy the performance issue in my example.

If DI supported specifying the function as const/duplicated [aka its activity], the problem would be trivially remedied.

Alternatively, if multiple arguments were supported [perhaps with a Const input], you could pass the NN and/or closure data as one of them and again avoid the issue.
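
Hedged sketches of both alternatives, calling Enzyme directly with made-up data:

using Enzyme

W = randn(1000, 1000)

# Alternative 1: declare the closure itself Const (safe here because W is only read, never written).
g = let W = W
    x -> sum(W) + x
end
Enzyme.autodiff(Reverse, Const(g), Active, Active(3.1))

# Alternative 2: multiple arguments, with the model data passed as an explicit Const input,
# so only x is active and no shadow of W is needed.
h(x, W) = sum(W) + x
Enzyme.autodiff(Reverse, h, Active, Active(3.1), Const(W))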

ChrisRackauckas commented 3 months ago

I don't disagree with that. My point, though, is that even if DI makes all of the inputs arguments, the default activity on a function would likely be const unless the documentation showed people how to do otherwise. I don't think that's the right default for DI, since then many common Julia functions would give wrong values. You'd basically have to say: don't pass f, the interface is Duplicated(f, make_zero(f)). That shouldn't be left to the user of DI, who should expect that the simple thing is correct. If DI.gradient(f, x) is wrong because they needed DI.gradient(Duplicated(f, make_zero(f)), x) to avoid dropping derivatives on enclosed caches, I would think something has gone wrong with the interface. My suggestion is to have AutoEnzyme make the required assumption by default, which is still optimal in the case that there are no caches. Yes, it is effectively a safety copy done to make caching functions work out of the box, but with an option to turn it off at the user's own risk.

But also, DI shouldn't wait until multi-arg activities are supported before doing any of this. Otherwise it will have issues with user-written closures until multi-arg activities land, which arguably is a pretty nasty bug that requires a hotfix. It does mean that, yes, constants enclosed in functions will slow things down a bit because you'll differentiate more than you need to, but it also means that enclosed cache variables will correctly propagate derivatives, which is more important for a high-level interface.

I didn't test this exactly, but I would think an MWE would be as simple as:

a = [1.0]
function f(x)
  a[1] = 1.0
  a[1] += x
  a[1]^2
end

would give an incorrect derivative with DI without this, which to me is a red flag that needs to be fixed. And then, when the multi-arg form comes, we can argue whether the user needs to enable the fix or whether it comes enabled by default, but I don't think we should wait to make this work.
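
A hedged way to check that claim with direct Enzyme calls. The MWE is rewritten here so the buffer is explicitly captured by the closure (so make_zero(f) covers it); with Const the outcome may be an error rather than a silently wrong value, depending on the Enzyme version:

using Enzyme

f = let a = [1.0]
    x -> begin
        a[1] = 1.0
        a[1] += x
        a[1]^2
    end
end

x = 3.0                      # analytic derivative: 2 * (1 + x) = 8.0

# Duplicating the closure carries the derivative through the captured buffer:
df = Enzyme.make_zero(f)
Enzyme.autodiff(Reverse, Duplicated(f, df), Active, Active(x))   # derivative w.r.t. x is 8.0

# Treating the closure as Const drops (or rejects) the derivative flowing through `a`:
Enzyme.autodiff(Reverse, Const(f), Active, Active(x))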

And to be clear, I don't think Enzyme's interface should do this, but Enzyme is a much lower level utility targeting a different level of user.

gdalle commented 3 months ago

I tend to agree with Chris on this one. Until I add activities or multiple arguments, better safe and slow than fast and wrong.

wsmoses commented 3 months ago

I see what you're saying, but I still feel like this is an edge case that is more likely to cause problems for users than it fixes.

In particular, the only sort of case where this is needed is where you read and write to a captured buffer. With the exception of preallocation tools code, this is immensely rare in Julia functions you want to AD [largely due to non-mutation, especially non-mutation of closures].

However, by marking the entire closure as duplicated, you now need Enzyme to successfully differentiate all closure operations, including those where this read-and-write-to-captured-buffer pattern doesn't apply. If there's a function currently unhandled by Enzyme, you'll error with the duplicated fn, whereas marking it const would succeed.

To be clear, I see the arguments for both sides of this, but I'm wondering what is the better trade-off to make.

wsmoses commented 3 months ago

Honestly, given that I'm doubtful there's much code outside of PreallocationTools where this would apply, I wonder if it makes sense to just add a PreallocationTools mode to DI [which may be separately useful in its own right].

ChrisRackauckas commented 3 months ago

In particular, the only sort of case where this is needed is where you read and write to a captured buffer. With the exception of preallocation tools code, this is immensely rare in Julia functions you want to AD [largely due to non-mutation, especially non-mutation of closures].

That's not really the case though. It's not rare. It's actually very common and explicitly mentioned in the documentation of many packages and tutorials that one should write non-allocating code. Here is one of many examples of that:

https://docs.sciml.ai/DiffEqDocs/stable/tutorials/faster_ode_example/#Example-Accelerating-Linear-Algebra-PDE-Semi-Discretization

Such functions are made to be fully mutating and non-allocating, and also fully type-stable, and so perfectly within the realm of Enzyme. And these functions will not error but give the wrong answer if the closure is not duplicated, which is not the nicest behavior.

I think you're thinking specifically about Flux using functors, where it's effectively allocating, type-unstable functional code carrying around parameters in its objects which may not need to be differentiated. Flux is the weird one, not everything else. I actually can't think of another library that is engineered similarly to Flux, while most scientific models, PDE solvers, etc. are engineered similarly to the example I linked above, where pre-allocated buffers are either passed around or enclosed and then used to get an allocation-free runtime. In any case, I'd argue it should be the Flux example that opts out of duplicating the closure as a performance improvement, not the scientific models, PDE solvers, etc. opting in to duplicating the function in order to ensure they get the right gradient value on repeated applications with caches.

wsmoses commented 3 months ago

Hm, I’m not seeing a read and write to a captured buffer there. It’s reading from or writing to an argument, for sure, which isn’t impacted here.

Mind pasting the example you’re thinking of?

ChrisRackauckas commented 3 months ago
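
# Excerpt from the linked SciML tutorial; N, Ax, Ay, r0, and p are defined earlier in that tutorial.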
Ayu = zeros(N, N)
uAx = zeros(N, N)
Du = zeros(N, N)
Ayv = zeros(N, N)
vAx = zeros(N, N)
Dv = zeros(N, N)
function gm3!(dr, r, p, t)
    a, α, ubar, β, D1, D2 = p
    u = @view r[:, :, 1]
    v = @view r[:, :, 2]
    du = @view dr[:, :, 1]
    dv = @view dr[:, :, 2]
    mul!(Ayu, Ay, u)
    mul!(uAx, u, Ax)
    mul!(Ayv, Ay, v)
    mul!(vAx, v, Ax)
    @. Du = D1 * (Ayu + uAx)
    @. Dv = D2 * (Ayv + vAx)
    @. du = Du + a * u * u ./ v + ubar - α * u
    @. dv = Dv + a * u * u - β * v
end
prob = ODEProblem(gm3!, r0, (0.0, 0.1), p)
@btime solve(prob, Tsit5());

wsmoses commented 3 months ago

Okay yeah that does have it.

If it is indeed common, I think the scales weigh towards duplicating, then.

However, I will say that globals/captured vars like this are likely poor for performance compared to being passed as an argument, and separately these examples (and perhaps a perf guide) should be updated.
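
A hedged sketch of that restructuring (a small self-contained analogue with made-up names, not the tutorial code itself): keep the work buffer in the parameter object so it can be passed, and annotated, explicitly.

using LinearAlgebra

# The preallocated buffer travels with the parameters instead of being captured from global scope.
function rhs!(du, u, p, t)
    A, k, cache = p.A, p.k, p.cache
    mul!(cache, A, u)        # in-place product into the preallocated buffer
    @. du = k * cache
    return nothing
end

A = randn(4, 4)
p = (A = A, k = 2.0, cache = zeros(4))
u = ones(4); du = zeros(4)
rhs!(du, u, p, 0.0)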

gdalle commented 2 months ago

See https://github.com/SciML/SciMLBenchmarks.jl/pull/988 for a more involved discussion

gdalle commented 2 months ago

The first ingredient of the solution is available in the latest release of ADTypes with AutoEnzyme(constant_function=true/false). Now it's on me to implement both variants here.
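
As a usage sketch based on the keyword named in this comment (later ADTypes releases may expose this differently, so treat the exact name as tentative):

using ADTypes

backend_safe = AutoEnzyme(constant_function = false)  # duplicate the function: safe when it encloses caches
backend_fast = AutoEnzyme(constant_function = true)   # assert the function carries no differentiable state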

gdalle commented 2 months ago

@willtebbutt what kind of assumptions does Tapir make vis-a-vis constant functions?

willtebbutt commented 2 months ago

At present, Tapir.jl assumes that all arguments are active, and differentiates through everything. Consequently I'm reasonably confident that there's nothing here that is relevant to Tapir.jl.