FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Pull request #2007 causes Flux.params() calls to not get cached #2040

Open christiangnrd opened 2 years ago

christiangnrd commented 2 years ago

I isolated the issue to #2007 with bisect, but I'm really not sure where the breakage happens. I'm more than happy to run tests if you need me to.

Ever since I upgraded to 0.13.5, my variational convolutional autoencoder has not been running on the GPU, and there are no error messages. I can see with nvidia-smi that data is being transferred to GPU memory properly, but when it comes to the actual computations, GPU usage fluctuates between 0% and 2% (which I believe is just the arrays being moved to and from main memory) instead of consistently going up to ~60%.

I tried a non-convolutional variational autoencoder with the rest mostly the same, and it worked normally. I also tried models with convolutional layers but without a @functor-tagged struct, and those also seemed to compute on the GPU. So I believe the combination of a struct with convolutional layers that is tagged with @functor is what prevents the math from happening on the GPU.

I'll do my best to be responsive.

mcabbott commented 2 years ago

It's not so easy to guess what's gone wrong. I presume your model contains transpose or adjoint matrices as parameters, which Functors@0.3 will recurse into instead of regarding them as leaf nodes. Maybe that's where to look to construct an MWE.
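
A quick way to check whether that's happening (a sketch, not tied to your actual model):

using Functors

W = rand(Float32, 3, 3)
Functors.isleaf(W)     # true: a plain dense array is a leaf
Functors.isleaf(W')    # per the point above, Functors@0.3 recurses into the Adjoint wrapper, so this should be false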

christiangnrd commented 2 years ago

Here is an MWE. On my desktop, running it with v0.13.4 takes about 2 seconds per epoch, while running it with v0.13.5 takes a couple of minutes per epoch.

using Flux
using CUDA   # needed below for CUDA.has_cuda()
using Flux: @functor, flatten
using Flux.Losses: logitbinarycrossentropy
using Flux.Data: DataLoader
using MLDatasets
using ProgressMeter: ProgressMeter, Progress, next!
ProgressMeter.ijulia_behavior(:clear)
using Random
using MLUtils: unsqueeze

# load MNIST images and return loader
function get_data(batch_size, split=:train)
    xtrain, ytrain = MLDatasets.MNIST(split)[:]
    xtrain = unsqueeze(xtrain, 3)
    DataLoader((xtrain, ytrain), batchsize=batch_size, shuffle=true)
end

struct Encoder
    conv
    μ
    logσ
end
@functor Encoder

Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Encoder(
    Chain(
        Conv((3,3),1 => 32,relu,stride = 2, pad = 1),
        Conv((3,3),32 => 64,relu,stride = 2, pad = 1),
        flatten,
        Dense((input_dim ÷ (2*2))^2 * 64,hidden_dim,relu),
        ),
    # identity as activation function
    Dense(hidden_dim,latent_dim), # μ
    Dense(hidden_dim,latent_dim), # logσ
)

function (encoder::Encoder)(x)
    h = encoder.conv(x)
    encoder.μ(h), encoder.logσ(h)
end

Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Chain(
    Dense(latent_dim,(input_dim ÷ (2*2))^2 * 64,relu),
    x -> reshape(x,(7,7,64,:)),

    # note SamePad() is not possible here
    ConvTranspose((3,3),64 => 64,relu,stride=2, pad = SamePad()),
    ConvTranspose((3,3),64 => 32,relu,stride=2, pad = SamePad()),
    ConvTranspose((3,3),32 => 1, pad = 1)
)

function reconstruct(encoder, decoder, x, device)
    μ, logσ = encoder(x)
    z = μ + device(randn(Float32, size(logσ))) .* exp.(logσ)
    μ, logσ, decoder(z)
end

function model_loss(encoder, decoder, λ, x, device)
    μ, logσ, decoder_z = reconstruct(encoder, decoder, x, device)
    len = size(x)[end]
    # KL-divergence
    kl_q_p = 0.5f0 * sum(@. (exp(2f0 * logσ) + μ^2 -1f0 - 2f0 * logσ)) / len

    logp_x_z = logitbinarycrossentropy(decoder_z, x, agg=sum) / len
    # regularization
    reg = λ * sum(x->sum(x.^2), Flux.params(decoder))

    logp_x_z + kl_q_p + reg
end

η = 1e-3                # learning rate
λ = 0.01f0              # regularization paramater
epochs = 20             # number of epochs
cuda = true             # use GPU
input_dim = 28        # image size
latent_dim = 2          # latent dimension
hidden_dim = 16        # hidden dimension

# GPU config
if cuda && CUDA.has_cuda()
    device = gpu
    @info "Training on GPU"
else
    device = cpu
    @info "Training on CPU"
end

# load MNIST images
loader = get_data(1024)

# initialize encoder and decoder
encoder = Encoder(input_dim, latent_dim, hidden_dim) |> device
decoder = Decoder(input_dim, latent_dim, hidden_dim) |> device

input, _ = first(loader)

# ADAM optimizer
opt = ADAM(η)

# parameters
ps = Flux.params(encoder.conv,encoder.μ, encoder.logσ, decoder)

# training
# train_steps = 0
@info "Start Training, total $(epochs) epochs"
for epoch = 1:epochs
    progress = Progress(length(loader))

    for (x, _) in loader
        loss, back = Flux.pullback(() -> model_loss(encoder, decoder, λ, x |> device, device), ps)
        grad = back(1f0)
        Flux.Optimise.update!(opt, ps, grad)

        # train_steps += 1
        next!(progress; showvalues=[(:epoch, epoch), (:loss, loss)])

    end

end

ToucheSir commented 2 years ago

I can confirm 0.13.5 is slower (around 2-3X) per step over 0.13.4. However, both do make use of the GPU on my machine and neither runs an entire epoch in only 2s!

One suspicious finding on 0.13.5 that may be contributing: when I @time each step, I see a significant amount of compilation time even after the first step. Here's a snippet from the beginning:

118.301900 seconds (260.05 M allocations: 13.222 GiB, 2.63% gc time, 67.30% compilation time: 0% of which was recompilation)
  0.602355 seconds (155.17 k allocations: 10.094 MiB, 11.42% gc time, 22.53% compilation time)
  1.051859 seconds (265.48 k allocations: 14.941 MiB, 34.92% gc time, 18.84% compilation time)
  0.672971 seconds (325.69 k allocations: 17.240 MiB, 4.61% gc time, 23.60% compilation time)
  0.932188 seconds (411.05 k allocations: 20.821 MiB, 30.46% gc time, 18.00% compilation time)
  0.777822 seconds (496.12 k allocations: 24.397 MiB, 2.35% gc time, 29.21% compilation time)
  0.769096 seconds (581.19 k allocations: 27.964 MiB, 2.32% gc time, 36.30% compilation time)
  0.758847 seconds (666.48 k allocations: 31.538 MiB, 4.06% gc time, 26.60% compilation time)
  1.085506 seconds (751.89 k allocations: 35.120 MiB, 28.85% gc time, 19.79% compilation time)
  0.722599 seconds (836.76 k allocations: 38.708 MiB, 1.73% gc time, 30.56% compilation time)
  0.809376 seconds (922.12 k allocations: 42.270 MiB, 2.30% gc time, 29.58% compilation time)

0.13.4 does not exhibit this continual compilation, but does still incur some compilation in steps 2-3:

103.966379 seconds (257.45 M allocations: 13.102 GiB, 2.80% gc time, 74.92% compilation time: 0% of which was recompilation)
  0.463840 seconds (14.39 k allocations: 4.098 MiB, 10.50% gc time, 1.62% compilation time)
  0.501627 seconds (39.38 k allocations: 5.366 MiB, 5.73% gc time, 4.95% compilation time)

ToucheSir commented 2 years ago

Reduced the repeated compilation to taking the gradient of just the regularization term λ * sum(x->sum(x.^2), Flux.params(decoder)) for even a simple model (1 Dense layer). The culprit is https://github.com/FluxML/Flux.jl/commit/0b62a91162ff508f4253054e9609fb7bd2f69202#diff-fb7b52bcd5616e0bebd43199ba13ba86729cd6a0ea17598ec355c3b3fe47c521L39-R48, but I don't understand why pullbacks aren't being cached after the first compilation (in both 0.13.5 and for the first couple of iterations on 0.13.4) :confused:.
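
For reference, a minimal sketch of that reduction (my own reconstruction of the setup described above, not the exact script):

using Flux

decoder = Dense(2, 2)    # "even a simple model (1 Dense layer)"
λ = 0.01f0
ps = Flux.params(decoder)

for i in 1:5
    # on 0.13.5 every one of these steps reports fresh compilation time
    @time gradient(() -> λ * sum(x -> sum(x.^2), Flux.params(decoder)), ps)
end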

christiangnrd commented 2 years ago

I can confirm that when I make no changes other than removing the regularization line in my original code, the performance is back to what it was in 0.13.4.

Is this something that can be fixed in this repo, or does the issue lie somewhere else?

Also, are there any workarounds that would avoid the constant recompilation while keeping the regularization, until a fix lands?

MariusDrulea commented 1 year ago

Unfortunately, the problem is present again in Flux 0.13.10, so we have to reopen this issue.

In the following MWE, loss_slow compiles at every iteration. Additionally, both the runtime and the memory usage increase with every iteration; it looks like loss_slow causes Zygote to continuously accumulate some data. In the vae_mnist example, the runtime starts at 4 minutes per epoch and grows to 1.5 hours per epoch.

The equivalent loss_explicit function behaves as expected.

using Flux
using Flux: norm

model = Dense(2, 2)

loss_slow(m) = sum(p->norm(p), Flux.params(m))
loss_explicit(m) = norm(m.weight) + norm(m.bias)

for i in 1:10
    @time ∇m_slow = gradient(m->loss_slow(m), model)    
end

for i in 1:10
    @time ∇m_explicit = gradient(m->loss_explicit(m), model)    
end

Here is the output:


loss_slow:
 23.518778 seconds (62.17 M allocations: 3.153 GiB, 3.73% gc time, 99.94% compilation time)
  0.018303 seconds (4.03 k allocations: 183.281 KiB, 93.40% compilation time)
  0.018860 seconds (5.14 k allocations: 231.125 KiB, 93.63% compilation time)
  0.019585 seconds (6.24 k allocations: 281.562 KiB, 91.42% compilation time)
  0.019242 seconds (7.33 k allocations: 324.969 KiB, 92.79% compilation time)
  0.019103 seconds (8.44 k allocations: 376.188 KiB, 90.87% compilation time)
  0.019514 seconds (9.53 k allocations: 419.500 KiB, 91.37% compilation time)
  0.019786 seconds (10.63 k allocations: 467.250 KiB, 90.60% compilation time)
  0.022090 seconds (11.73 k allocations: 514.031 KiB, 91.70% compilation time)
  0.019207 seconds (12.83 k allocations: 561.297 KiB, 90.98% compilation time)
  0.038078 seconds (73.32 k allocations: 3.669 MiB, 99.70% compilation time)

loss_explicit:
  0.000017 seconds (29 allocations: 1.766 KiB)
  0.000015 seconds (29 allocations: 1.766 KiB)
  0.000006 seconds (29 allocations: 1.766 KiB)
  0.000005 seconds (29 allocations: 1.766 KiB)
  0.000005 seconds (29 allocations: 1.766 KiB)
  0.000006 seconds (29 allocations: 1.766 KiB)
  0.000006 seconds (29 allocations: 1.766 KiB)
  0.000004 seconds (29 allocations: 1.766 KiB)

svilupp commented 1 year ago

I was wondering if there are any workarounds to keep regularization in the loss while avoiding this issue? E.g., some Zygote tricks in how the loss is constructed?

It seems that the only solution right now is to pin to 0.13.4, right?
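
(For reference, I assume pinning would look something like this with the standard Pkg API:)

using Pkg
Pkg.add(name="Flux", version="0.13.4")
Pkg.pin("Flux")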

ToucheSir commented 1 year ago

The best and most future-proof solution is to use explicit params for the regularization term as shown above, but we don't currently have nice helper functionality for that. If you're using implicit params, you can call params outside the loss (good idea either way) and iterate over it inside.
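
A minimal sketch of that second workaround (illustrative names, not the thread's model): build the Params once outside, then only iterate over the pre-built object inside the loss.

using Flux

model = Dense(2, 2)
λ = 0.01f0
x = ones(Float32, 2)

ps = Flux.params(model)   # built once, outside the differentiated code

# the loss iterates over the existing Params instead of calling Flux.params(model) inside
loss() = sum(abs2, model(x)) + λ * sum(p -> sum(abs2, p), ps)

for i in 1:5
    @time gradient(loss, ps)   # should avoid recompiling the params() call every step
end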

darsnack commented 1 year ago

"we don't currently have nice helper functionality for that"

The issue to track for the helper is https://github.com/FluxML/Optimisers.jl/pull/57. Until then, here is a snippet that should do the same for simple penalties like L2.

using Flux
using Functors

penalty(x::AbstractArray) = sum(x.^2) # example penalty

# further down
grads = Flux.gradient(model) do m
    loss = # ...
    reg = Functors.fmap(penalty, m; exclude = Flux.trainable)
    return loss + lambda * reg
end

mcabbott commented 1 year ago

exclude = Flux.trainable can't be right here; it doesn't return a Bool.

It really needs some kind of trainable-aware walk. exclude = Optimisers.isnumeric ought to run, but will include any non-trainable parameter arrays.
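
For concreteness, a quick check of both points (a sketch, assuming Flux and Optimisers are loaded):

using Flux, Optimisers

Flux.trainable(Dense(2, 2))                 # a NamedTuple of children, not a Bool, so it can't serve as an `exclude` predicate
Optimisers.isnumeric(rand(Float32, 2, 2))   # true — a Bool predicate, usable as `exclude`, but it also keeps non-trainable arrays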