EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Enzyme creating a large number of allocations during reverse pass #1680

Open ptiede opened 4 months ago

ptiede commented 4 months ago

This is a fun one. For certain code patterns, Enzyme inserts a large number of allocations into the reverse pass even when the forward pass allocates little or nothing. The MWE is:

using Enzyme
using StructArrays
using StaticArrays
using BenchmarkTools

@inline function fga(x)
    return SMatrix{2,2,typeof(x.g), 4}(x.g, 0, 0, x.g)
end

getparams(x, inds) = map(getindex, x, inds)
getinds(xinds, index, ::Val{1}) = map(x->getindex(x[1], index), xinds)
getinds(xinds, index, ::Val{2}) = map(x->getindex(x[2], index), xinds)

function apply_instrument(vis, x, inds)
    vout = similar(vis, SMatrix{2,2,eltype(vis[1]), 4})
    @inbounds for i in eachindex(vout, vis)
        v = apply_jones(vis[i], i, x, inds)
        vout[i] = v
    end
    return vout
end

@inline inner(j1, j2, v) = j1*SMatrix{2,2,eltype(v), 4}(v[1], v[2], v[3], v[4])*j2'

@inline function apply_jones(v, index, x, inds)
    xind1 = getinds(inds, index, Val(1))
    xind2 = getinds(inds, index, Val(2))
    p1 = getparams(x, xind1)
    p2 = getparams(x, xind2)
    j1 = fga(p1)
    j2 = fga(p2)
    return inner(j1, j2, v)
end

function test(vis, x, inds)
    vout = apply_instrument(vis, x, inds)
    mapreduce(x->sum(abs2, x), +, vout)
end

function dtest(vis, dvis, x, dx, xinds)
    map(x->fill!(x, 0), dx)
    fill!(dvis, SVector(0.0, 0.0, 0.0, 0.0))
    autodiff(Enzyme.Reverse, test, Active, Duplicated(vis, dvis), Duplicated(x, dx), Const(xinds))
end

ngains = 186
nvis = 315

x = (g = rand(ngains),)

vis = StructVector{SVector{4, Float64}}((rand(nvis), rand(nvis), rand(nvis), rand(nvis)))

indg1 = rand(1:ngains, nvis)
indg2 = rand(1:ngains, nvis)
indd1 = rand(1:ngains, nvis)
indd2 = rand(1:ngains, nvis)

xinds = (g = (indg1, indg2), )

@benchmark test($vis, $x, $xinds)

dvis = zero(vis)
dx = Enzyme.make_zero(x)

@benchmark dtest($vis, $dvis, $x, $dx, $xinds)

which gives the output

# Forward pass
BenchmarkTools.Trial: 10000 samples with 7 evaluations.
 Range (min … max):  4.806 μs …  17.213 ms  ┊ GC (min … max):  0.00% … 99.85%
 Time  (median):     5.336 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   7.733 μs ± 172.506 μs  ┊ GC (mean ± σ):  24.72% ±  2.25%

      ▄▇██▇▆▄▃▂▂▂▃▂▁                          ▂▂▃▃▂▃▃▄▃▂      ▂
  ▄▁▅███████████████▇▆▆▆▅▅▃▃▄▃▁▃▅▃▃▁▃▄▁▅▁▁▃▄▇▇████████████▇█▇ █
  4.81 μs      Histogram: log(frequency) by time      8.81 μs <

# Reverse pass
julia> @benchmark dtest($vis, $dvis, $x, $dx, $xinds)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  59.949 μs … 116.741 ms  ┊ GC (min … max):  0.00% … 99.78%
 Time  (median):     63.514 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   78.214 μs ±   1.169 ms  ┊ GC (mean ± σ):  16.12% ±  1.88%

       ▃█▅▂▁▂▁                                                  
  ▁▁▁▁▄███████▇▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁ ▂
  59.9 μs         Histogram: frequency by time         82.4 μs <

 Memory estimate: 57.80 KiB, allocs estimate: 42.

The problem isn't so much the performance as the allocations. Profiling the code, I see the following:

[Image: profile]

So a large portion of the time is spent allocating when calling getindex. Interestingly, if I inline apply_instrument, the problem goes away and performance improves dramatically:

julia> @benchmark dtest($vis, $dvis, $x, $dx, $xinds)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.010 μs … 475.684 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):      9.840 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   10.143 μs ±   4.717 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

         ▁▄▆█▇▄▂                                                
  ▁▁▂▂▂▃▆████████▅▄▃▂▃▂▂▃▃▃▃▃▃▃▃▃▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  9.01 μs         Histogram: frequency by time         13.2 μs <

 Memory estimate: 21.00 KiB, allocs estimate: 8.

I am on Julia 1.10.4 and Enzyme 0.12.25.

danielwe commented 4 months ago

In my experience, calling Enzyme.API.inlineall!(true) at the top of the program can reduce the allocations significantly.
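
For reference, here is a minimal sketch of where that call would go. It assumes Enzyme 0.12.x, the version reported above, where `Enzyme.API.inlineall!` is available. The function `f` is a hypothetical toy stand-in for `test` from the MWE, not the original code.

```julia
using Enzyme

# Ask Enzyme to inline the functions it differentiates. Set this before the
# first autodiff call so it is in effect when the derivative is compiled.
Enzyme.API.inlineall!(true)

# Hypothetical toy function standing in for `test` from the MWE.
f(x) = sum(abs2, x)

x = [1.0, 2.0, 3.0]
dx = zero(x)   # shadow array; the gradient is accumulated into it

autodiff(Reverse, f, Active, Duplicated(x, dx))
# dx now holds the gradient of f at x, i.e. 2 .* x
```

Whether this removes the allocations in the MWE would still need to be checked by rerunning the `@benchmark dtest(...)` call from the original post.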

ptiede commented 3 months ago

That is good to know, and I can definitely add that! Are these allocations something I should expect with this kind of code?

ptiede commented 2 months ago

OK, so I also sporadically get a segfault from code very similar to this. If I change apply_instrument to

@inline function apply_instrument(vis, J::ObservedInstrumentModel, x)
    vout = similar(vis, SMatrix{2,2,eltype(vis[1]), 4})
    xint = x.instrument
    vout .= apply_jones.(vis, eachindex(vis), Ref(J), Ref(xint))
    return vout
end

I get a GC corruption error when calling Ref, which sometimes spits out a ton of info. An example is below.

err_gccorruption_enzyme_sep2.txt

The code to reproduce this is quite long, but the segfault always occurs in this part of the code. I am going to see if I can get a more reliable MWE.

wsmoses commented 2 months ago

@ptiede can you open that as an issue with an MWE and also confirm you’re on 1.10.5?

ptiede commented 2 months ago

Ya for sure! And ya I am on 1.10.5.