EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

`Enzyme.gradient` allocates on `SVector` #1968

Open gdalle opened 1 month ago

gdalle commented 1 month ago

Hi! As you know, @ExpandingMan and I are looking to optimize performance for StaticArrays. Forward mode works splendidly, but reverse mode still makes one allocation during the gradient call:

```julia
using StaticArrays, Enzyme, BenchmarkTools
f(x) = sum(abs2, x);
x = SVector(1.0, 2.0);
@btime Enzyme.gradient(Enzyme.Reverse, f, $x)  # 8.999 ns (1 allocation: 32 bytes)
@btime Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active($x))  # 4.218 ns (0 allocations: 0 bytes)
```

I found it surprising because Enzyme guesses the right activity for SVector:

```julia
Enzyme.guess_activity(typeof(x), Enzyme.Reverse)  # Active{SVector{2, Float64}}
```

The allocation happens on the following line:
https://github.com/EnzymeAD/Enzyme.jl/blob/42ecd12cf5076f8d3db1694e014f69bc0b99173f/src/Enzyme.jl#L1708

From what I understand, the generated function `Enzyme.gradient` puts a `Ref` there so that every argument can be treated as `Duplicated` or `MixedDuplicated`. As a result, all gradients are stored in the passed arguments:
https://github.com/EnzymeAD/Enzyme.jl/blob/42ecd12cf5076f8d3db1694e014f69bc0b99173f/src/Enzyme.jl#L1741

Otherwise, you would have to recover some gradients from the return value and others from the arguments, which is understandably tricky.

Do you think there is an easy fix in Enzyme? If not, since DI only has one differentiated argument, I assume it will be rather straightforward to call `Enzyme.autodiff` directly inside `DI.gradient` and recover allocation-free behavior.
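A minimal sketch of the single-argument route, assuming the wrapper is chosen from `Enzyme.guess_activity`. The helper name `single_arg_gradient` is hypothetical and is not part of Enzyme or DI. For an `Active` argument the derivative is read from the value `autodiff` returns, so no `Ref` is involved; for a `Duplicated` argument it is accumulated into a zeroed shadow.

```julia
using StaticArrays, Enzyme

f(x) = sum(abs2, x)

# Hypothetical helper (not an Enzyme or DI API): reverse-mode gradient of a
# single argument, choosing the wrapper from Enzyme's guessed activity.
function single_arg_gradient(f, x)
    A = Enzyme.guess_activity(typeof(x), Enzyme.Reverse)
    if A <: Enzyme.Active
        # Immutable input such as SVector: the derivative is returned, not stored.
        # autodiff returns a tuple whose first entry holds one derivative per argument.
        return Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Active(x))[1][1]
    elseif A <: Enzyme.Duplicated
        # Mutable input such as Vector: the derivative accumulates into the shadow dx.
        dx = Enzyme.make_zero(x)
        Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Duplicated(x, dx))
        return dx
    else
        # Const and MixedDuplicated inputs are outside the scope of this sketch.
        error("activity $A is not handled in this sketch")
    end
end

g_static = single_arg_gradient(f, SVector(1.0, 2.0))  # gradient of sum(abs2, x) is 2x
g_vector = single_arg_gradient(f, [1.0, 2.0])
```

Whether the branch on `A` folds away at compile time, and therefore whether the call is actually allocation-free, has not been verified here and should be checked with `@btime` as in the report above.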

Related:

wsmoses commented 1 month ago

sure, PR welcome!

gdalle commented 1 month ago

Sure! I'll try to handle this case correctly in DI first, because it still errors at the moment. Once I have a handle on the single-argument solution, I'll try to tamper with the generated function to do the same for multiple arguments.

wsmoses commented 3 weeks ago

bump @gdalle

wsmoses commented 2 weeks ago

gentle ping @gdalle

gdalle commented 2 weeks ago

Essentially this comes down to adding the option for Active inputs here:

https://github.com/EnzymeAD/Enzyme.jl/blob/42ecd12cf5076f8d3db1694e014f69bc0b99173f/src/Enzyme.jl#L1714-L1720

The variable interpolated as `$arg` is a boolean defined as follows. Its definition is a bit obscure to me; could you shed some light on it?

https://github.com/EnzymeAD/Enzyme.jl/blob/42ecd12cf5076f8d3db1694e014f69bc0b99173f/src/Enzyme.jl#L1686-L1692
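To illustrate what an `Active` option for several arguments would have to do, here is a plain (non-generated) sketch. It assumes per-argument activity comes from `Enzyme.guess_activity` rather than from the boolean that `Enzyme.gradient` actually computes, and the names `wrap_for_reverse` and `mixed_gradient` are hypothetical. It relies on `autodiff` returning `nothing` in the derivative tuple for non-`Active` arguments, so each gradient is read either from that tuple or from the argument's shadow.

```julia
using StaticArrays, Enzyme

# Hypothetical helper (not an Enzyme API): wrap one argument using the
# activity Enzyme guesses for its type in reverse mode.
function wrap_for_reverse(x)
    A = Enzyme.guess_activity(typeof(x), Enzyme.Reverse)
    if A <: Enzyme.Active
        return Enzyme.Active(x)
    elseif A <: Enzyme.Duplicated
        return Enzyme.Duplicated(x, Enzyme.make_zero(x))
    else
        error("activity $A is not handled in this sketch")
    end
end

# Hypothetical multi-argument gradient mixing Active and Duplicated inputs.
function mixed_gradient(f, xs...)
    wrapped = map(wrap_for_reverse, xs)
    # First element of the result: one entry per argument, `nothing` for non-Active ones.
    derivs = Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, wrapped...)[1]
    # Active arguments: take the returned derivative; Duplicated ones: read the shadow.
    return map((d, w) -> d === nothing ? w.dval : d, derivs, wrapped)
end

h(a, b) = sum(abs2, a) + sum(abs2, b)
ga, gb = mixed_gradient(h, SVector(1.0, 2.0), [3.0, 4.0])
```

The result is a tuple with one gradient per argument. Because `wrap_for_reverse` picks the wrapper type through a runtime check, this version may not be type-stable; doing the same inside the generated function would let the choice be made per argument type at compile time.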