FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

CUDA memory leak for Flux.Optimizer #148

Open RomeoV opened 1 year ago

RomeoV commented 1 year ago

(This issue has been moved here from https://github.com/FluxML/Flux.jl/issues/2261)

I have a somewhat complicated training setup and have recently started encountering CUDA-out-of-memory issues which only show up after a number of epochs.

I have managed to construct a minimum working example here:

using Flux
using FastAI
using MLUtils
using FastAI.FluxTraining

function main()
    DEVICE = gpu   # unused here; the ToGPU() callback handles device movement
    model = Chain(Dense(32*32*3 => 2048), Dense(2048 => 6), Dense(6 => 32*32*3))

    make_data_sample_test(i) = (rand(Float32, 32*32*3),
                                rand(Float32, 32*32*3))
    data = mapobs(make_data_sample_test, 1:1024)
    dl = DataLoader(data; batchsize=32, collate=true)

    loss = Flux.Losses.logitbinarycrossentropy
    opt = Flux.Adam(3e-4)
    learner = FastAI.Learner(model, loss;
                             optimizer=opt,
                             data=(dl, dl),   # reuse dl as the validation loader
                             callbacks=[FluxTraining.ToGPU()])

    for _ in 1:5
        FluxTraining.epoch!(learner, FluxTraining.TrainingPhase())
        @show length(opt.state)   # keeps growing across epochs
    end
end

After about 50 epochs (~1 minute on my laptop), I get an error that CUDA cannot allocate any more memory. This seems to be because the optimizer's state accumulates GPU arrays over time.

The issue can be fixed by replacing opt = Flux.Adam() with opt = Optimisers.Adam() (an Optimisers.jl rule). However, I think we should fix the problem for the Flux optimizer, since it seems to be "officially" supported.
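For reference, here is what that workaround looks like on the MWE above. This is just a sketch that sidesteps the accumulation rather than fixing the legacy optimizer path, and it assumes Optimisers.jl has been added to the environment:

using Flux, FastAI, MLUtils
using FastAI.FluxTraining
import Optimisers   # Optimisers.jl must be in the project environment

model = Chain(Dense(32*32*3 => 2048), Dense(2048 => 6), Dense(6 => 32*32*3))
data  = mapobs(i -> (rand(Float32, 32*32*3), rand(Float32, 32*32*3)), 1:1024)
dl    = DataLoader(data; batchsize=32, collate=true)

opt     = Optimisers.Adam(3e-4)   # rule-based optimiser instead of the legacy Flux.Adam
learner = FastAI.Learner(model, Flux.Losses.logitbinarycrossentropy;
                         optimizer=opt,
                         data=(dl, dl),
                         callbacks=[FluxTraining.ToGPU()])

for _ in 1:50
    FluxTraining.epoch!(learner, FluxTraining.TrainingPhase())
end

Since the rule-based optimiser keeps its state outside the rule and rebuilds it from the current params, nothing accumulates across epochs.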

@DrChainsaw has suggested in the other issue that the problem is that the ToDevice callback is not applied to the optimizer parameters. However, I haven't looked into the specifics or how one would fix that. Any insights?

ToucheSir commented 1 year ago

I think this is the sequence of events which causes the leak:

  1. Once per epoch, the model is moved from CPU to GPU. This means the identity of the GPU model parameters will vary between epochs.
  2. Subsequently, the optimizer state is initialized from scratch based on the GPU model params, but only when using Optimisers.jl (because state is held externally to the optimization rules themselves). When using legacy Flux optimizers, the optimizer retains the now-obsolete state from the last epoch unchanged.
  3. When it comes time to update parameters, the state IdDict that legacy optimizers use is expanded instead of updated as intended, because the object identities of the params have changed.
  4. Rinse and repeat over multiple epochs.

There are a couple of ways we could address this, but I think it first raises a bigger question: why are we resetting the optimizer state at the beginning of each epoch in the first place? @lorenzoh do you remember the context for this decision?
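For illustration, here is a toy sketch of the identity-keyed growth described in steps 1-3. This is made-up standalone code, not FluxTraining internals, and the effect only shows up with a functional GPU, since gpu is otherwise a no-op:

using Flux, CUDA

opt   = Flux.Adam(3e-4)
model = Dense(4 => 4)              # a CPU model that we re-move every epoch, mimicking step 1
x     = rand(Float32, 4, 8)

for epoch in 1:3
    m  = gpu(model)                # step 1: fresh GPU arrays, fresh object identities
    ps = Flux.params(m)
    gs = gradient(() -> sum(abs2, m(gpu(x))), ps)
    Flux.update!(opt, ps, gs)      # step 3: the IdDict gains entries keyed by the new arrays
    @show length(opt.state)        # grows by 2 (weight + bias) every "epoch"
end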

mashu commented 1 month ago

I think it might be FluxTraining-related, because with the following I get an OOM error:

using Flux
using MLUtils
using FluxTraining
using CUDA

DEVICE = gpu
model = Chain(Dense(32*32*3=>2048), Dense(2048=>6), Dense(6, 32*32*3))

make_data_sample_test(i) = (rand(Float32, 32*32*3),
                            rand(Float32, 32*32*3))
data = mapobs(make_data_sample_test, 1:1024)
dl     = DataLoader(data; batchsize=32, collate=true)

loss = Flux.Losses.logitbinarycrossentropy
opt = Flux.Adam(3e-4)
learner = Learner(model, loss;
                  optimizer=opt,
                  data=(dl, dl),
                  callbacks=[FluxTraining.ToGPU()])

for _ in 1:10000
    FluxTraining.epoch!(learner, FluxTraining.TrainingPhase())
    @show length(opt.state)
end

but not with a plain Flux loop (even though the model makes no sense, it's a good test case):

using Flux
using MLUtils
using FluxTraining
using CUDA

DEVICE = gpu
model = Chain(Dense(32*32*3=>2048), Dense(2048=>6), Dense(6, 32*32*3)) |> DEVICE

make_data_sample_test(i) = (rand(Float32, 32*32*3),
                            rand(Float32, 32*32*3))
data = mapobs(make_data_sample_test, 1:1024)
dl     = DataLoader(data; batchsize=32, collate=true)

loss = Flux.Losses.logitbinarycrossentropy
opt = Flux.Adam(3e-4)
opt_state = Flux.setup(opt, model)

for _ in 1:10000
    for batch in dl
        train, val = batch |> gpu          # input and target of one batch
        λ, Δ = Flux.withgradient(model) do m
            ŷ = m(train)
            loss(ŷ, val)
        end
        Flux.update!(opt_state, model, Δ[1])
    end
end

The code I mentioned to @darsnack is above. Also, the fact that my models have been working fine with CUDA for over a year (without FluxTraining) is probably another hint that this is not a CUDA problem, or at most something very specific in CUDA that only gets triggered by FluxTraining.

mashu commented 1 month ago

@ToucheSir I just saw your response, sorry! Indeed, moving the data to the GPU manually, as below, prevents the OOM from happening. So the model probably should not be moved to the GPU every epoch.

using Flux
using MLUtils
using FluxTraining
using CUDA

DEVICE = gpu
model = Chain(Dense(32*32*3=>2048), Dense(2048=>6), Dense(6, 32*32*3)) |> DEVICE   # moved to the GPU once, up front

make_data_sample_test(i) = (rand(Float32, 32*32*3),
                            rand(Float32, 32*32*3))
data = mapobs(make_data_sample_test, 1:1024)
dl     = DataLoader(data; batchsize=32, collate=true)

loss = Flux.Losses.logitbinarycrossentropy
opt = Flux.Adam(3e-4)
learner = Learner(model, loss;
                  optimizer=opt,
                  data=(dl, dl))

struct MyTrainingPhase <: FluxTraining.AbstractTrainingPhase end
function FluxTraining.step!(learner, phase::MyTrainingPhase, batch)
    xs, ys = batch |> DEVICE   # only the batch is moved; the model already lives on the GPU
    FluxTraining.runstep(learner, phase, (xs=xs, ys=ys)) do handle, state
        state.grads = gradient(learner.params) do
            state.ŷs = learner.model(state.xs)
            handle(FluxTraining.LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)
            handle(FluxTraining.BackwardBegin())
            return state.loss
        end
        handle(FluxTraining.BackwardEnd())
        Flux.update!(learner.optimizer, learner.params, state.grads)
    end
end
struct MyValidationPhase <: FluxTraining.AbstractValidationPhase end
function FluxTraining.step!(learner, phase::MyValidationPhase, batch)
    xs, ys = batch |> DEVICE
    FluxTraining.runstep(learner, phase, (xs=xs, ys=ys)) do _, state
        state.ŷs = learner.model(state.xs)
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end
for epoch in 1:10000
    @info "Iteration $epoch"
    FluxTraining.epoch!(learner, MyTrainingPhase(), dl)
    FluxTraining.epoch!(learner, MyValidationPhase(), dl)
end