Open christiangnrd opened 2 years ago
It's not so easy to guess what's gone wrong. I presume your model contains transpose or adjoint matrices as parameters, which Functors@0.3 will recurse into, instead of regarding them as leaf nodes. Maybe that's where to look to construct an MWE.
Here is an MWE. On my desktop, running it with v0.13.4 takes 2 seconds per epoch, while running it with v0.13.5 takes a couple of minutes per epoch.
using Flux
using Flux: @functor, flatten
using Flux.Losses: logitbinarycrossentropy
using Flux.Data: DataLoader
using CUDA  # needed for the CUDA.has_cuda() check in the GPU config
using MLDatasets
using ProgressMeter: ProgressMeter, Progress, next!
ProgressMeter.ijulia_behavior(:clear)
using Random
using MLUtils: unsqueeze
# load MNIST images and return loader
function get_data(batch_size, split=:train)
    xtrain, ytrain = MLDatasets.MNIST(split)[:]
    xtrain = unsqueeze(xtrain, 3)
    DataLoader((xtrain, ytrain), batchsize=batch_size, shuffle=true)
end
struct Encoder
    conv
    μ
    logσ
end
@functor Encoder
Encoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Encoder(
    Chain(
        Conv((3,3), 1 => 32, relu, stride = 2, pad = 1),
        Conv((3,3), 32 => 64, relu, stride = 2, pad = 1),
        flatten,
        Dense((input_dim ÷ (2*2))^2 * 64, hidden_dim, relu),
    ),
    # identity as activation function
    Dense(hidden_dim, latent_dim), # μ
    Dense(hidden_dim, latent_dim), # logσ
)
function (encoder::Encoder)(x)
    h = encoder.conv(x)
    encoder.μ(h), encoder.logσ(h)
end
Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int) = Chain(
    Dense(latent_dim, (input_dim ÷ (2*2))^2 * 64, relu),
    x -> reshape(x, (7, 7, 64, :)),
    # note SamePad() is not possible here
    ConvTranspose((3,3), 64 => 64, relu, stride = 2, pad = SamePad()),
    ConvTranspose((3,3), 64 => 32, relu, stride = 2, pad = SamePad()),
    ConvTranspose((3,3), 32 => 1, pad = 1)
)
function reconstruct(encoder, decoder, x, device)
    μ, logσ = encoder(x)
    # reparameterization: sample z from N(μ, exp(logσ)^2)
    z = μ + device(randn(Float32, size(logσ))) .* exp.(logσ)
    μ, logσ, decoder(z)
end
function model_loss(encoder, decoder, λ, x, device)
    μ, logσ, decoder_z = reconstruct(encoder, decoder, x, device)
    len = size(x)[end]
    # KL-divergence
    kl_q_p = 0.5f0 * sum(@. (exp(2f0 * logσ) + μ^2 - 1f0 - 2f0 * logσ)) / len
    logp_x_z = logitbinarycrossentropy(decoder_z, x, agg=sum) / len
    # regularization
    reg = λ * sum(x->sum(x.^2), Flux.params(decoder))
    logp_x_z + kl_q_p + reg
end
η = 1e-3         # learning rate
λ = 0.01f0       # regularization parameter
epochs = 20      # number of epochs
cuda = true      # use GPU
input_dim = 28   # image size
latent_dim = 2   # latent dimension
hidden_dim = 16  # hidden dimension
# GPU config
if cuda && CUDA.has_cuda()
    device = gpu
    @info "Training on GPU"
else
    device = cpu
    @info "Training on CPU"
end
# load MNIST images
loader = get_data(1024)
# initialize encoder and decoder
encoder = Encoder(input_dim, latent_dim, hidden_dim) |> device
decoder = Decoder(input_dim, latent_dim, hidden_dim) |> device
input, _ = first(loader)
# ADAM optimizer
opt = ADAM(η)
# parameters
ps = Flux.params(encoder.conv,encoder.μ, encoder.logσ, decoder)
# training
# train_steps = 0
@info "Start Training, total $(epochs) epochs"
for epoch = 1:epochs
    progress = Progress(length(loader))
    for (x, _) in loader
        loss, back = Flux.pullback(() -> model_loss(encoder, decoder, λ, x |> device, device), ps)
        grad = back(1f0)
        Flux.Optimise.update!(opt, ps, grad)
        # train_steps += 1
        next!(progress; showvalues=[(:epoch, epoch), (:loss, loss)])
    end
end
I can confirm 0.13.5 is slower (around 2-3X) per step over 0.13.4. However, both do make use of the GPU on my machine and neither runs an entire epoch in only 2s!
One suspicious find on 0.13.5 which may be contributing: when I @time each step, I see a significant amount of compilation time even after the first step. Here's a snippet from the beginning:
118.301900 seconds (260.05 M allocations: 13.222 GiB, 2.63% gc time, 67.30% compilation time: 0% of which was recompilation)
0.602355 seconds (155.17 k allocations: 10.094 MiB, 11.42% gc time, 22.53% compilation time)
1.051859 seconds (265.48 k allocations: 14.941 MiB, 34.92% gc time, 18.84% compilation time)
0.672971 seconds (325.69 k allocations: 17.240 MiB, 4.61% gc time, 23.60% compilation time)
0.932188 seconds (411.05 k allocations: 20.821 MiB, 30.46% gc time, 18.00% compilation time)
0.777822 seconds (496.12 k allocations: 24.397 MiB, 2.35% gc time, 29.21% compilation time)
0.769096 seconds (581.19 k allocations: 27.964 MiB, 2.32% gc time, 36.30% compilation time)
0.758847 seconds (666.48 k allocations: 31.538 MiB, 4.06% gc time, 26.60% compilation time)
1.085506 seconds (751.89 k allocations: 35.120 MiB, 28.85% gc time, 19.79% compilation time)
0.722599 seconds (836.76 k allocations: 38.708 MiB, 1.73% gc time, 30.56% compilation time)
0.809376 seconds (922.12 k allocations: 42.270 MiB, 2.30% gc time, 29.58% compilation time)
0.13.4 does not exhibit this continual compilation, but does still incur some compilation in steps 2-3:
103.966379 seconds (257.45 M allocations: 13.102 GiB, 2.80% gc time, 74.92% compilation time: 0% of which was recompilation)
0.463840 seconds (14.39 k allocations: 4.098 MiB, 10.50% gc time, 1.62% compilation time)
0.501627 seconds (39.38 k allocations: 5.366 MiB, 5.73% gc time, 4.95% compilation time)
Reduced the repeated compilation to taking the gradient of just the regularization term λ * sum(x->sum(x.^2), Flux.params(decoder)) for even a simple model (1 Dense layer). The culprit is https://github.com/FluxML/Flux.jl/commit/0b62a91162ff508f4253054e9609fb7bd2f69202#diff-fb7b52bcd5616e0bebd43199ba13ba86729cd6a0ea17598ec355c3b3fe47c521L39-R48, but I don't understand why pullbacks aren't being cached after the first compilation (in both 0.13.5 and for the first couple of iterations on 0.13.4) :confused:.
I can confirm that when I make no changes other than removing the regularization line in my original code, the performance is back to what it was in 0.13.4.
Is this something that can be fixed in this repo, or does the issue lie somewhere else?
Also, are there any workarounds that would avoid the constant recompilation and keep the regularization while waiting for a fix?
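One possible workaround, sketched here under assumptions rather than taken from the thread, is to drop the penalty term from the loss and let the optimiser apply the equivalent gradient through WeightDecay from Flux.Optimise. The decoder below is a small stand-in layer and the loss is a placeholder. Note two differences from the MWE: the reported loss value no longer includes the penalty, and the decay applies to every parameter passed to update! with this optimiser, so restricting it to the decoder would need a separate optimiser and update! call for the encoder parameters.

```julia
using Flux
using Flux.Optimise: Optimiser, WeightDecay, ADAM

η = 1e-3      # learning rate, as in the MWE
λ = 0.01f0    # regularization strength, as in the MWE

decoder = Dense(2, 4)            # stand-in for the real decoder (assumption)
ps = Flux.params(decoder)        # collected once, outside the loss

# WeightDecay(γ) adds γ .* p to each gradient before ADAM sees it. That is the
# gradient of (γ / 2) * sum(p .^ 2), so γ = 2λ corresponds to λ * sum(p .^ 2).
opt = Optimiser(WeightDecay(2λ), ADAM(η))

x = randn(Float32, 2, 8)
loss() = sum(abs2, decoder(x))   # placeholder loss without the penalty term

gs = Flux.gradient(loss, ps)
Flux.Optimise.update!(opt, ps, gs)
```

This assumes the implicit-parameter API of Flux 0.13.x used elsewhere in this thread.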
Unfortunately, the problem is present again in Flux 0.13.10, so we have to reopen this task.
In the following MWE, loss_slow compiles at every iteration. Additionally, the runtime and the memory usage at each iteration are increasing. It looks like loss_slow causes Zygote to continuously accumulate some data. In the vae_mnist example, the runtime starts at 4 minutes per epoch and reaches 1.5 hours per epoch.
The equivalent loss_explicit function behaves as expected.
using Flux
using Flux: norm
model = Dense(2, 2)
loss_slow(m) = sum(p->norm(p), Flux.params(m))
loss_explicit(m) = norm(m.weight) + norm(m.bias)
for i in 1:10
    @time ∇m_slow = gradient(m->loss_slow(m), model)
end
for i in 1:10
    @time ∇m_explicit = gradient(m->loss_explicit(m), model)
end
Here is the output:
loss_slow:
23.518778 seconds (62.17 M allocations: 3.153 GiB, 3.73% gc time, 99.94% compilation time)
0.018303 seconds (4.03 k allocations: 183.281 KiB, 93.40% compilation time)
0.018860 seconds (5.14 k allocations: 231.125 KiB, 93.63% compilation time)
0.019585 seconds (6.24 k allocations: 281.562 KiB, 91.42% compilation time)
0.019242 seconds (7.33 k allocations: 324.969 KiB, 92.79% compilation time)
0.019103 seconds (8.44 k allocations: 376.188 KiB, 90.87% compilation time)
0.019514 seconds (9.53 k allocations: 419.500 KiB, 91.37% compilation time)
0.019786 seconds (10.63 k allocations: 467.250 KiB, 90.60% compilation time)
0.022090 seconds (11.73 k allocations: 514.031 KiB, 91.70% compilation time)
0.019207 seconds (12.83 k allocations: 561.297 KiB, 90.98% compilation time)
0.038078 seconds (73.32 k allocations: 3.669 MiB, 99.70% compilation time)
loss_explicit:
0.000017 seconds (29 allocations: 1.766 KiB)
0.000015 seconds (29 allocations: 1.766 KiB)
0.000006 seconds (29 allocations: 1.766 KiB)
0.000005 seconds (29 allocations: 1.766 KiB)
0.000005 seconds (29 allocations: 1.766 KiB)
0.000006 seconds (29 allocations: 1.766 KiB)
0.000006 seconds (29 allocations: 1.766 KiB)
0.000004 seconds (29 allocations: 1.766 KiB)
I was wondering if there are any workarounds that allow regularization in the loss while avoiding this issue? E.g., some Zygote tricks in how the loss is constructed?
It seems that the only solution right now is to pin to 0.13.4, right?
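For anyone who does want to pin in the meantime, a minimal sketch using the standard Pkg API, run in the project environment that should be held back:

```julia
using Pkg

# Install the last release reported as fast in this thread ...
Pkg.add(name = "Flux", version = "0.13.4")
# ... and keep later Pkg.update() calls from moving it.
Pkg.pin("Flux")
```

Pkg.free("Flux") undoes the pin once a fixed release is available.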
The best and most future-proof solution is to use explicit params for the regularization term as shown above, but we don't currently have nice helper functionality for that. If you're using implicit params, you can call params outside the loss (good idea either way) and iterate over it inside.
> we don't currently have nice helper functionality for that
The issue to track for the helper is https://github.com/FluxML/Optimisers.jl/pull/57. Until then, here is a snippet that should do the same for simple penalties like L2.
using Flux
using Functors
penalty(x::AbstractArray) = sum(x.^2) # example penalty
# further down
grads = Flux.gradient(model) do m
    loss = # ...
    reg = Functors.fmap(penalty, m; exclude = Flux.trainable)
    return loss + lambda * reg
end
exclude = Flux.trainable can't be right here, it doesn't return a Bool.
It really needs some trainablewalk. exclude=Optimisers.isnumeric ought to run, but will include any non-trainable parameter arrays.
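Another way to write an explicit-mode L2 penalty, offered as a sketch and not as the helper discussed above, is to go through Flux.destructure, which (in Flux 0.13 and later, via Optimisers.jl) flattens only the trainable arrays into one vector and is differentiable with respect to the model. The model, input and data-fit term below are placeholders. Whether this fully avoids the repeated compilation reported in this thread has not been verified here, so the loop is timed to make that easy to check.

```julia
using Flux

model = Dense(2, 2)
λ = 0.01f0
x = randn(Float32, 2, 8)

# Flux.destructure returns (flat_vector, rebuild_function); only the flat
# vector of trainable parameters is needed for the penalty.
penalty(m) = sum(abs2, first(Flux.destructure(m)))

# Placeholder data-fit term plus the penalty, differentiated w.r.t. the model.
total_loss(m) = sum(abs2, m(x)) + λ * penalty(m)

for i in 1:5
    @time Flux.gradient(total_loss, model)
end
```

The returned gradient is a NamedTuple mirroring the model, so the penalty's contribution to the weight gradient is 2λ times the weight itself.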
I'm really not sure where the breakage happens, but I'm more than happy to run tests if you need me to. I isolated the issue to #2007 with bisect.
Ever since I upgraded to 0.13.5, my variational convolutional autoencoder is not running on the GPU and does not display any error messages. I can see using nvidia-smi that things are properly being transferred to GPU memory, but when it comes to the actual computations, the GPU usage fluctuates between 0% and 2% (which I believe is the arrays being moved to and from main memory) instead of consistently going up to ~60% usage.
I tried a non-convolutional variational autoencoder, otherwise mostly the same, which worked normally, and I also tried models with convolutional layers without functors, which also seemed to compute on the GPU. So I believe that the combination of a struct with convolutional layers that is tagged with @functor is what prevents the math from happening on the GPU.
I'll do my best to be responsive.