FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Difference in allocations of gradient() over structurally identical functions #1206

Open ancapdev opened 2 years ago

ancapdev commented 2 years ago

I've distilled this from a real-world use case where it causes a 4-fold increase in allocations over what appears to be optimal. I'm working on scaling up training of small models over multiple threads and reducing heap pressure, so this difference is very material.

On first use of a loss function of a given structure, it appears that that type is forever cursed with additional allocations. I would guess this has something to do with the internal machinery of IRTools and Zygote, which is a bit beyond my expertise. I'm hoping someone can help diagnose the root cause and find a robust workaround or fix.

I've tried to rule out world age as the source of the problem, e.g., by re-running from the top level or using Base.invokelatest(). I've verified the gradients are identical, so both variants appear to be functionally the same.
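
A rough sketch of that kind of world-age check (the exact commands aren't in this report; train and state1 refer to the script below):

using BenchmarkTools

# If world age were the culprit, routing the call through invokelatest (which
# dispatches in the latest world) would be expected to remove the extra allocations.
Base.invokelatest(train, state1)
@btime Base.invokelatest(train, $state1)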

using BenchmarkTools
using Flux
using MLUtils

function setup1(X, Y, model)
    loss = let m = model
        (x, y) -> Flux.mse(m(x), y)
    end
    data = DataLoader((X, Y), batchsize = 2)
    (; model, loss, data)
end

function setup2(X, Y, model)
    loss = let m = model
        (x, y) -> Flux.mse(m(x), y)
    end
    data = DataLoader((X, Y), batchsize = 2)
    (; model, loss, data)
end

function train(state)
    grads = []
    let ps = Flux.params(state.model)
        for d in state.data
            gs = gradient(ps) do
                state.loss(d...)
            end
            push!(grads, gs.grads)
        end
    end
    return grads  # return the collected gradients so the final equality check compares actual results
end

N = 10
X = rand(Float32, 8, N)
Y = rand(Float32, 1, N)
model = Dense(8 => 1)

state1 = setup1(X, Y, model)
println("initial, world: $(Base.get_world_counter())")
@btime train($state1)

println("re-run, world: $(Base.get_world_counter())")
@btime train($state1)

state2 = setup2(X, Y, model)
println("setup structurally identical new types, world: $(Base.get_world_counter())")
@btime train($state2)

state1 = setup1(X, Y, model)
println("re-setup initial, world: $(Base.get_world_counter())")
@btime train($state1)

println("check functionally identitical: $(train(state1) == train(state2))")

Output

initial, world: 31674
  25.659 μs (419 allocations: 55.69 KiB)
re-run, world: 31676
  25.318 μs (419 allocations: 55.69 KiB)
setup structurally identical new types, world: 31678
  19.307 μs (329 allocations: 41.78 KiB)
re-setup initial, world: 31680
  25.990 μs (419 allocations: 55.69 KiB)
check functionally identical: true

I've checked that the allocation differences are indeed coming from gradient() with:

function train2(state)
    allocs = 0
    let ps = Flux.params(state.model)
        for d in state.data
            allocs += @allocated gradient(ps) do
                state.loss(d...)
            end
        end
    end
    allocs
end

Yielding 52800 bytes for state1 and 38880 bytes for state2.
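
Spread over the loop's gradient() calls (with N = 10 and batchsize = 2 the DataLoader yields 5 batches), that difference works out to roughly:

# total @allocated difference divided by the 5 gradient() calls per train()
(52800 - 38880) / 5   # ≈ 2784 extra bytes per gradient() call for state1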

ToucheSir commented 2 years ago

A couple things to confirm:

  1. Switching the call order of setup1 and setup2 doesn't change anything?
  2. Warming up the gradient and loss function for each setup function before running @btime doesn't change anything?
  3. What does the allocation profiler say? I think a side-by-side comparison would be interesting.
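
One way such a side-by-side comparison could be set up (a sketch, assuming Julia 1.8+ for Profile.Allocs plus the PProf.jl viewer; not the exact commands used here):

using Profile, PProf

# Record every allocation made while differentiating each setup
Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate=1 train(state1)
PProf.Allocs.pprof(; out = "allocs_state1.pb.gz", web = false)

Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate=1 train(state2)
PProf.Allocs.pprof(; out = "allocs_state2.pb.gz", web = false)
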
ToucheSir commented 2 years ago

Reduced:

using BenchmarkTools
using Zygote

setup1() = x -> exp.(x)
setup2() = x -> exp.(x)

function train(loss, data)
  _, back = pullback(loss, data)
  back(1)
  return
end

X = (1,)

loss1 = Base.splat(setup1())
println("initial, world: $(Base.get_world_counter())")
@btime train($loss1, $X)

println("re-run, world: $(Base.get_world_counter())")
@btime train($loss1, $X)

loss2 = Base.splat(setup2())
println("setup structurally identical new types, world: $(Base.get_world_counter())")
@btime train($loss2, $X)

loss1 = Base.splat(setup1())
println("re-setup initial, world: $(Base.get_world_counter())")
@btime train($loss1, $X)

Timings:

initial, world: 31482
  1.129 μs (29 allocations: 816 bytes)
re-run, world: 31484
  1.120 μs (28 allocations: 800 bytes)
setup structurally identical new types, world: 31486
  11.829 ns (0 allocations: 0 bytes)
re-setup initial, world: 31488
  1.136 μs (28 allocations: 800 bytes)

Factors which appear to make a difference:

  1. Creating an inner anonymous function for the loss (even if it doesn't capture anything)
  2. Splatting in the top-level loss closure
  3. Broadcasting a non-trivial unary function in the loss function itself (abs2, sin, and exp all reproduce the overhead; identity and the binary x .* 1 do not; see the sketch below)
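
For concreteness, the broadcast variants in factor 3 might look like this (the setup names are illustrative, not from this thread); per the list above, only the first two reproduce the extra allocations:

# Non-trivial unary broadcasts: reported to reproduce the per-type overhead
setup_exp()  = x -> exp.(x)
setup_abs2() = x -> abs2.(x)

# Trivial unary and binary variants: reported not to reproduce it
setup_id()   = x -> identity.(x)
setup_mul()  = x -> x .* 1
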
ToucheSir commented 2 years ago

OK, this appears to be going beyond the scope of what Zygote can handle and into core compiler territory. If you run both functions through code_adjoint, they are identical. So are the loss functions themselves, down to at least the typed IR level. However, we already have a divergence in the type-inferred IR of the gradient code in train:
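
The dumps below can be reproduced with something along these lines (a sketch; the exact invocations are assumed rather than quoted):

using InteractiveUtils   # @code_warntype
using Zygote

# Zygote's adjoint IR for the two losses (identical)
Zygote.@code_adjoint loss1(X)
Zygote.@code_adjoint loss2(X)

# Type-inferred IR of the driver, where the divergence appears
@code_warntype train(loss1, X)
@code_warntype train(loss2, X)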

Loss 1:

MethodInstance for train(::Base.var"#86#87"{var"#1#2"}, ::Tuple{Int64})
  from train(loss, data) in Main at /home/brianc/projects/juliamwes/fluxstuff/1206.jl:8
Arguments
  #self#::Core.Const(train)
  loss::Core.Const(Base.var"#86#87"{var"#1#2"}(var"#1#2"()))
  data::Tuple{Int64}
Locals
  @_4::Int64
  back::ZYGOTE.VAR"#56#57"
Body::Nothing
1 ─ %1 = Main.pullback(loss, data)::TUPLE{ANY, ZYGOTE.VAR"#56#57"}
│   %2 = Base.indexed_iterate(%1, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│        Core.getfield(%2, 1)
│        (@_4 = Core.getfield(%2, 2))
│   %5 = Base.indexed_iterate(%1, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#56#57", Int64}, Any[Zygote.var"#56#57", Core.Const(3)])
│        (back = Core.getfield(%5, 1))
│        (back)(1)
└──      return nothing

Loss 2:

MethodInstance for train(::Base.var"#86#87"{var"#3#4"}, ::Tuple{Int64})
  from train(loss, data) in Main at /home/brianc/projects/juliamwes/fluxstuff/1206.jl:8
Arguments
  #self#::Core.Const(train)
  loss::Core.Const(Base.var"#86#87"{var"#3#4"}(var"#3#4"()))
  data::Tuple{Int64}
Locals
  @_4::Int64
  back::Zygote.var"#56#57"{typeof(∂(#86))}
Body::Nothing
1 ─ %1 = Main.pullback(loss, data)::Tuple{Float64, Zygote.var"#56#57"{typeof(∂(#86))}}
│   %2 = Base.indexed_iterate(%1, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
│        Core.getfield(%2, 1)
│        (@_4 = Core.getfield(%2, 2))
│   %5 = Base.indexed_iterate(%1, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#56#57"{typeof(∂(#86))}, Int64}, Any[Zygote.var"#56#57"{typeof(∂(#86))}, Core.Const(3)])
│        (back = Core.getfield(%5, 1))
│        (back)(1)
└──      return nothing

I can confirm the difference in generated code is even larger once we get down to typed IR or native. Note that the relative order of setup1 and setup2 does not appear to matter anywhere, just that they are distinct functions.
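
For reference, a sketch of how that lower-level comparison could be made with the standard reflection macros (assumed invocations):

using InteractiveUtils

# Optimized typed IR and LLVM IR for the two drivers
@code_typed train(loss1, X)
@code_typed train(loss2, X)
@code_llvm train(loss1, X)
@code_llvm train(loss2, X)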

Edit: updated with an example that just uses scalar params. I'm not sure why some types are uppercased in this output; maybe something to do with type instability?

ToucheSir commented 2 years ago

One more observation: adding a line gradient(x -> exp.(x...), (1.0,)) before the calls to train eliminates any allocation overhead on either of the subsequent calls. Now I'm not entirely clear how inference works with nested anonymous functions, but I would assume x -> exp.(x...) is not treated identically to y -> (x -> exp.(x))(y...). And yet running one before the other seems to help, at the cost of "poisoning" the first function.
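
A sketch of that warm-up experiment against the reduced script above (run in a fresh session, since the ordering is the whole point):

# Warm up with a structurally similar but distinct gradient first
gradient(x -> exp.(x...), (1.0,))

loss1 = Base.splat(setup1())
loss2 = Base.splat(setup2())

# Both reportedly show no allocation overhead after the warm-up
@btime train($loss1, $X)
@btime train($loss2, $X)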

ancapdev commented 2 years ago

@ToucheSir thanks for looking into this and simplifying the reproducer. It seems pretty low-level and hairy. Do you have any ideas for next steps, or who may be better placed to get to the bottom of this?

ToucheSir commented 2 years ago

You could try raising this with Base Julia, but they may want a Zygote-free MWE. One idea for getting around that would be to extract the CodeInfo Zygote generates and either a) splice it into a new function or b) write functions that reproduce it.

Another lead that came up on Slack was that nested generated functions may be a culprit. Creating a MWE that roughly matches what Zygote does (not semantically, just structurally) could be simpler.