FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Memory leak #1326

Open freddycct opened 2 years ago

freddycct commented 2 years ago

Package Version

v0.6.49

Julia Version

1.8.2

OS / Environment

macOS (Apple Silicon)

Describe the bug

Memory usage keeps going up on the worker processes despite calling GC.gc() (via @everywhere GC.gc()) after every epoch.

Steps to Reproduce

@everywhere begin
    using Zygote
    using Functors
    using Optimisers
end

@everywhere begin 
    struct NodeLayer
        w₁::Matrix{Float32}
        b₁::Vector{Float32}
        w₂::Matrix{Float32}
        b₂::Float32
    end
    Functors.@functor NodeLayer

    NodeLayer(K::Int) = NodeLayer(randn(Float32, (K,2*K)), randn(Float32, K), randn(Float32, (1,K)), randn(Float32))

    function (nl::NodeLayer)(x₁::AbstractArray, x₂::AbstractArray)
        h₀ = vcat(x₁, x₂)
        h = nl.w₁ * h₀ .+ nl.b₁
        y = nl.w₂ * h .+ nl.b₂ 
        return y[1], h
    end

    abstract type Node end

    struct Leaf <: Node
        i::Int
        x::Vector{Float32}
    end

    function (n::Leaf)(θ::Vector{NodeLayer})::Tuple{Float32, Vector{Float32}}
        nl = θ[n.i]
        return nl(n.x, n.x)
    end

    struct Branch <: Node
        i::Int
        left::Node
        right::Node
    end

    function (n::Branch)(θ::Vector{NodeLayer})::Tuple{Float32, Vector{Float32}}
        y₁, h₁ = n.left(θ)
        y₂, h₂ = n.right(θ)

        nl = θ[n.i]
        y₃, h₃ = nl(h₁, h₂)
        y = y₁ + y₂ + y₃
        return y, h₃
    end

    # generate a random tree; d bounds the depth, and d == 1 yields a single Leaf
    function genTree(N::Int, K::Int, d::Int)::Node
        if d == 1
            return Leaf(rand(1:N), rand(Float32, K))
        else
            return Branch(rand(1:N), genTree(N, K, rand(1:d-1)), genTree(N, K, rand(1:d-1)))
        end
    end

    function loss(t1::Node, t2::Node, θ::Vector{NodeLayer})
        return sqrt((t1(θ)[1] - t2(θ)[1])^2)
    end
end

function main()
    M = 64 # number of tasks

    @everywhere begin
        N = 20 # size of the parameters
        D = 5 # depth of the trees
        K = 32
    end

    # these parameters make up the model returned by genTree
    θ = map(x->NodeLayer(K), 1:N)

    optRule = Optimisers.Adam(1f-3)
    optState = Optimisers.setup(optRule, θ) # initialize the states for the Adam optimizer

    println("start training")
    epoch = 1
    while true
        if epoch >= 100
            break
        end
        stats = @timed begin
            ret = let θ = θ
                pmap(1:M) do x # pmap with distributed=false also causes the memory leak
                    local t₁ = genTree(N, K, D)
                    local t₂ = genTree(N, K, D)

                    local ll, grads = Zygote.withgradient(θ) do θ # something in here is not compatible with asyncmap
                        loss(t₁, t₂, θ)
                    end
                    return ll, grads[1]
                end
            end

            totalLoss = mapreduce(x->x[1], +, ret)
            local grads = Zygote.accum(map(x->x[2], ret)...)
            optState, θ = Optimisers.update!(optState, θ, grads)

            totalLoss
        end
        totalLoss = stats.value
        timeTaken = stats.time
        println("$(epoch)/∞: totalLoss = $(totalLoss), time taken = $(timeTaken)")
        flush(stdout)
        epoch += 1

        @everywhere GC.gc()
    end
    println("end training")
end

main()

Run with:

julia -p 4 <filename>

Expected Results

Memory usage should hit a ceiling in worker nodes and stay there.

Observed Results

Memory usage keeps climbing in worker nodes.
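
One way to quantify the climb (a hedged sketch, not from the original report; it assumes the script is started with -p so the Distributed standard library is loaded) is to log each worker's resident-set size once per epoch:

using Distributed

# Hypothetical helper: print each worker's resident-set high-water mark in MiB.
# Sys.maxrss() only ever grows, but a steady epoch-over-epoch climb matches the
# behaviour described above.
function report_worker_memory(epoch)
    for w in workers()
        rss_mib = remotecall_fetch(() -> Sys.maxrss() / 2^20, w)
        println("epoch $epoch, worker $w: maxrss = $(round(rss_mib, digits = 1)) MiB")
    end
end

Calling report_worker_memory(epoch) right after the @everywhere GC.gc() line in main() makes the per-worker growth visible in the log.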

Relevant log output

No response

CarloLucibello commented 1 year ago

@freddycct it would be nice if you could reduce the example as much as possible, to make it easier to identify the offending part.
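
For reference, a stripped-down single-process variant (a hypothetical reduction, reusing the NodeLayer, genTree and loss definitions from the report and dropping pmap, Distributed and Optimisers) could look roughly like the sketch below; if memory still grows here, repeated Zygote.withgradient calls alone are enough to trigger it, independent of the distributed machinery.

using Zygote

# Hypothetical reduced reproducer: same model code as above, single process,
# no optimiser update, just repeated gradient calls on a fixed pair of trees.
function reduced_main(; epochs = 100, N = 20, K = 32, D = 5)
    θ  = [NodeLayer(K) for _ in 1:N]
    t₁ = genTree(N, K, D)
    t₂ = genTree(N, K, D)
    for epoch in 1:epochs
        ll, _ = Zygote.withgradient(θ -> loss(t₁, t₂, θ), θ)
        GC.gc()
        println("$epoch: loss = $ll, maxrss = $(round(Sys.maxrss() / 2^20, digits = 1)) MiB")
    end
end

reduced_main()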