ancapdev opened 2 years ago
A couple of things to confirm:
1. Switching the call order of `setup1` and `setup2` doesn't change anything?
2. Warming up the gradient and loss function for each setup function before running `@btime` doesn't change anything?
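For point 2, a warm-up could look something like the sketch below (hypothetical; it reuses `train`, `loss1`, `loss2`, and `X` from the reduced example further down):

```julia
# Hypothetical warm-up: run one full forward + reverse pass per loss so
# any one-time compilation cost is paid before @btime measures anything.
for loss in (loss1, loss2)
    _, back = pullback(loss, X)  # forward pass, builds the pullback closure
    back(1)                      # reverse pass
end

@btime train($loss1, $X)
@btime train($loss2, $X)
```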
Reduced:
```julia
using BenchmarkTools
using Zygote

setup1() = x -> exp.(x)
setup2() = x -> exp.(x)

function train(loss, data)
    _, back = pullback(loss, data)
    back(1)
    return
end

X = (1,)

loss1 = Base.splat(setup1())
println("initial, world: $(Base.get_world_counter())")
@btime train($loss1, $X)

println("re-run, world: $(Base.get_world_counter())")
@btime train($loss1, $X)

loss2 = Base.splat(setup2())
println("setup structurally identical new types, world: $(Base.get_world_counter())")
@btime train($loss2, $X)

loss1 = Base.splat(setup1())
println("re-setup initial, world: $(Base.get_world_counter())")
@btime train($loss1, $X)
```
Timings:

```
initial, world: 31482
  1.129 μs (29 allocations: 816 bytes)
re-run, world: 31484
  1.120 μs (28 allocations: 800 bytes)
setup structurally identical new types, world: 31486
  11.829 ns (0 allocations: 0 bytes)
re-setup initial, world: 31488
  1.136 μs (28 allocations: 800 bytes)
```
Factors which appear to make a difference: `abs2`, `sin` and `exp`, `identity`, and the binary `x .* 1` (failed) in the loss function itself.

Ok, this appears to be going beyond the scope of what Zygote can handle and into core compiler territory. If you run both functions through `code_adjoint`, they are identical. So are the loss functions themselves, down to at least the typed IR level. However, we already have a divergence in the type-inferred IR of the gradient code in `train`:
Loss 1:

```
MethodInstance for train(::Base.var"#86#87"{var"#1#2"}, ::Tuple{Int64})
  from train(loss, data) in Main at /home/brianc/projects/juliamwes/fluxstuff/1206.jl:8
Arguments
  #self#::Core.Const(train)
  loss::Core.Const(Base.var"#86#87"{var"#1#2"}(var"#1#2"()))
  data::Tuple{Int64}
Locals
  @_4::Int64
  back::ZYGOTE.VAR"#56#57"
Body::Nothing
1 ─ %1 = Main.pullback(loss, data)::TUPLE{ANY, ZYGOTE.VAR"#56#57"}
│   %2 = Base.indexed_iterate(%1, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│        Core.getfield(%2, 1)
│        (@_4 = Core.getfield(%2, 2))
│   %5 = Base.indexed_iterate(%1, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#56#57", Int64}, Any[Zygote.var"#56#57", Core.Const(3)])
│        (back = Core.getfield(%5, 1))
│        (back)(1)
└──      return nothing
```
Loss 2:

```
MethodInstance for train(::Base.var"#86#87"{var"#3#4"}, ::Tuple{Int64})
  from train(loss, data) in Main at /home/brianc/projects/juliamwes/fluxstuff/1206.jl:8
Arguments
  #self#::Core.Const(train)
  loss::Core.Const(Base.var"#86#87"{var"#3#4"}(var"#3#4"()))
  data::Tuple{Int64}
Locals
  @_4::Int64
  back::Zygote.var"#56#57"{typeof(∂(#86))}
Body::Nothing
1 ─ %1 = Main.pullback(loss, data)::Tuple{Float64, Zygote.var"#56#57"{typeof(∂(#86))}}
│   %2 = Base.indexed_iterate(%1, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
│        Core.getfield(%2, 1)
│        (@_4 = Core.getfield(%2, 2))
│   %5 = Base.indexed_iterate(%1, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#56#57"{typeof(∂(#86))}, Int64}, Any[Zygote.var"#56#57"{typeof(∂(#86))}, Core.Const(3)])
│        (back = Core.getfield(%5, 1))
│        (back)(1)
└──      return nothing
```
I can confirm the difference in generated code is even larger once we get down to typed IR or native. Note that the relative order of `setup1` and `setup2` does not appear to matter anywhere, just that they are distinct functions.
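For anyone reproducing this, the IR above can be regenerated with the standard reflection tools (a sketch; it assumes the reduced example has already been run in the current session):

```julia
using InteractiveUtils  # provides @code_warntype

# Type-inferred IR of train for each loss; note the concrete vs. abstract
# Zygote.var"#56#57" pullback type inferred for `back`.
@code_warntype train(loss1, X)
@code_warntype train(loss2, X)

# Adjoint IR Zygote derives for the losses themselves (identical here).
Zygote.@code_adjoint loss1(X)
Zygote.@code_adjoint loss2(X)
```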
Edit: updated with an example that just uses scalar params. I'm not sure why some types are uppercased with this approach; maybe something to do with instability?
One more observation: adding a line `gradient(x -> exp.(x...), (1.0,))` before the calls to `train` eliminates any allocation overhead on either of the subsequent calls. Now, I'm not entirely clear how inference works with nested anonymous functions, but I would assume `x -> exp.(x...)` is not treated identically to `y -> (x -> exp.(x))(y...)`. And yet running one before the other seems to help, at the cost of "poisoning" the first function.
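Concretely, the workaround reads like this (a sketch using the reduced example's names; "poisoning" means the warm-up lambda itself ends up carrying the extra allocations instead):

```julia
# Observed workaround (not a guaranteed fix): differentiate a structurally
# similar anonymous function once before the first call to train.
gradient(x -> exp.(x...), (1.0,))

@btime train($loss1, $X)  # both calls now avoid the extra allocations
@btime train($loss2, $X)
```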
@ToucheSir thanks for looking into this and simplifying the reproducer. It seems pretty low-level and hairy. Do you have any ideas for next steps, or who may be better placed to get to the bottom of this?
You could try raising this with Base Julia, but they may want a Zygote-free MWE. One idea for getting around that would be extracting the `CodeInfo` Zygote generates and either a) splicing it into a new function or b) writing function(s) that reproduce it.
Another lead that came up on Slack was that nested generated functions may be a culprit. Creating a MWE that roughly matches what Zygote does (not semantically, just structurally) could be simpler.
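A purely structural sketch of the nested-generated-function idea might look like the following; all names here are hypothetical and only mimic the shape of Zygote's pullback machinery, not its semantics:

```julia
# Hypothetical MWE shape: a generated function whose generated body calls
# another generated function, roughly mirroring how Zygote's generated
# _pullback wraps the adjoint code it derives.
@generated inner(f, x) = :(f(x))

@generated function outer(f, x)
    quote
        y = inner(f, x)
        y, Δ -> Δ * y  # stand-in for a pullback closure
    end
end

y, back = outer(exp, 1.0)
```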
I've distilled this from a real-world use case where this is causing a 4-fold increase in allocations over what appears to be optimal. I'm working to scale up training of small models over multiple threads and to reduce heap pressure, so this difference is very material.
On first use of a loss function of some structure, it appears that its type is forever cursed with additional allocations. I would guess this has something to do with the internal machinery of IRTools and Zygote, and is a bit beyond my expertise. Hoping someone would be able to help diagnose the root cause and find a robust workaround or fix.

I've tried to eliminate world age as the source of the problem, e.g. by re-running from top level or using `Base.invokelatest()`. I've verified the gradients are identical, so it appears both variants are functionally the same.
I've checked that the allocation differences are indeed from `gradient()`: 52800 bytes for `state1` versus 38880 bytes for `state2`.