FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Increasing memory usage in each call of gradient #1509

Open CarloLucibello opened 7 months ago

CarloLucibello commented 7 months ago

I was experimenting with alternatives to https://github.com/FluxML/Optimisers.jl/pull/57 when I encountered the following weird issue.

Look at the number of allocations when computing the gradient of loss1: it grows with every call of perf(), while loss2 and loss3 stay flat.

using Flux, Functors, BenchmarkTools  # Flux re-exports Zygote.gradient

# Explicit loop, accumulating into a local variable.
function loss1(m)
    ls = 0f0
    for l in Functors.fleaves(m)
        if l isa AbstractArray{<:Number}
            ls += sum(l)
        end
    end
    return ls
end

# Same reduction written as a generator passed to sum.
function loss2(m)
    sum(sum(l) for l in Functors.fleaves(m) if l isa AbstractArray{<:Number})
end

# Same again, materializing the comprehension into an array first.
function loss3(m)
    sum([sum(l) for l in Functors.fleaves(m) if l isa AbstractArray{<:Number}])
end
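
For context, Functors.fleaves collects every leaf of a nested structure, not only numeric arrays, which is why all three losses filter on AbstractArray{<:Number}. A minimal sketch (the toy NamedTuple below is mine, not from the report):

using Functors

nt = (W = rand(Float32, 2, 2), inner = (b = zeros(Float32, 2), name = "bias"))
Functors.fleaves(nt)  # returns [W, b, "bias"]: the String is a leaf too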

function perf()
    # The model is only traversed, never called on data, so the mismatched
    # BatchNorm(3) between the 128-wide Dense layers is harmless here.
    m = Chain(Dense(128 => 128, relu), BatchNorm(3), Dense(128 => 10))
    @btime gradient(loss1, $m)[1]
    @btime gradient(loss2, $m)[1]
    @btime gradient(loss3, $m)[1]
    println()
end

perf(); #1st call
perf(); #2nd call
perf(); #3rd call
# OUTPUT
154.795 ms (1022652 allocations: 39.16 MiB)
1.734 ms (7605 allocations: 352.62 KiB)
1.314 ms (5948 allocations: 288.08 KiB)

258.556 ms (1658450 allocations: 63.37 MiB)
1.735 ms (7605 allocations: 352.62 KiB)
1.316 ms (5948 allocations: 288.08 KiB)

336.418 ms (2154374 allocations: 82.29 MiB)
1.739 ms (7605 allocations: 352.62 KiB)
1.319 ms (5948 allocations: 288.08 KiB)

What's going on?
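
To narrow it down, here is a sketch of a Flux-free reproduction attempt (assuming the relevant part is the explicit loop over Functors.fleaves rather than the layers; the NamedTuple is hypothetical, not from the report):

using Zygote, Functors, BenchmarkTools

nt = (a = rand(Float32, 8), b = (c = rand(Float32, 8), d = rand(Float32, 8)))

# Same loop shape as loss1, but over plain nested NamedTuples.
function loop_loss(m)
    ls = 0f0
    for l in Functors.fleaves(m)
        if l isa AbstractArray{<:Number}
            ls += sum(l)
        end
    end
    return ls
end

@btime gradient(loop_loss, $nt)  # rerun this line: if allocations also grow
                                 # here, the Flux layers are not to blame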

ToucheSir commented 7 months ago

I've seen this before with code that uses Functors (e.g. params), though not every function built on Functors triggers it. Still not sure why.
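
One way to check whether live memory really grows (rather than just the per-call allocation counts) is a sketch like the following, reusing loss1 and the model from above; Base.gc_live_bytes reports the live heap:

m = Chain(Dense(128 => 128, relu), BatchNorm(3), Dense(128 => 10))
GC.gc(); before = Base.gc_live_bytes()
for _ in 1:10
    gradient(loss1, m)  # steady growth across iterations would point at
end                     # something being cached on every call
GC.gc(); after = Base.gc_live_bytes()
println("live bytes grew by ", after - before)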