slimgroup / InvertibleNetworks.jl

A Julia framework for invertible neural networks
MIT License
148 stars · 20 forks

Ridiculously long training times for GLOW on GPU #72

Closed: shaneinglis closed this issue 1 year ago

shaneinglis commented 1 year ago

I tried to use the NetworkGlow structure/example to replicate GLOW on CIFAR10, with 3 scales, 32 flow steps per scale, and 256 channels per convolution. I found that each backward pass took a few minutes. I do agree that the model is huge, at roughly 50 million parameters, but the same model in Python takes about that long for an entire epoch (with batch size 256). I wonder if this behaviour is expected, since InvertibleNetworks.jl focuses more on memory efficiency than on speed.

mloubout commented 1 year ago

Hi, thanks for the details. Would you have a small script that reproduces this issue? While memory efficiency is one of the main goals of this implementation, the performance should not be this bad, so we'd like to check this out.

shaneinglis commented 1 year ago
using InvertibleNetworks, LinearAlgebra, Flux
using MLDatasets
import Flux.Optimise.update!

device = InvertibleNetworks.CUDA.functional() ? gpu : cpu

@info device

# Define network
n_in = 3
n_hidden = 256
batchsize = 256
L = 3   # number of scales
K = 32   # number of flow steps per scale

# Glow network
G = NetworkGlow(n_in, n_hidden, L, K, split_scales = true) |> device

# Objective function
function loss(X)
    Y, logdet = G.forward(X)
    f = -log_likelihood(Y) - logdet
    ΔY = -∇log_likelihood(Y)
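    # Backward pass accumulates the parameter gradients inside G; @time isolates its cost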
    @time ΔX = G.backward(ΔY, Y)[1]
    return f
end

# Training
maxiter = 10
opt = Flux.ADAM(1f-3)
fval = zeros(Float32, maxiter)

# Input for CIFAR10 is of size 32x32x3
train_x = CIFAR10.traintensor(Float32)
test_x = CIFAR10.testtensor(Float32)

train_loader = Flux.DataLoader(train_x |> device, batchsize=batchsize, shuffle=true)

for j=1:maxiter
    for X in train_loader
        fval[j] = loss(X)
        # Update params
        for p in get_params(G)
            update!(opt, p.data, p.grad)
        end
        clear_grad!(G)
    end
end

This is example code (adapted from the NetworkGlow.jl example). I have been running it on an NVIDIA A100 GPU. I also noticed that the speed of the backward pass depends on the batch size far more strongly than I would expect. For this piece of code, I get the output below. Thank you!!

[screenshot: @time output showing backward passes taking several minutes]

mloubout commented 1 year ago

Thank you, looking into it

mloubout commented 1 year ago

The main computational bottleneck you are experiencing comes from computing the gradient of the 1x1 convolutions. In practice, these are usually skipped, with those weights kept fixed (@rafaelorozco please add any details). So in your case, if you initialize your network as NetworkGlow(n_in, n_hidden, L, K, split_scales = true, freeze_conv=true), as in the sketch below, you should see a significant speedup.
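For reference, here is the one-line change against the script above (a minimal sketch; freeze_conv=true is the only new argument, everything else in the script stays the same):

# Freeze the 1x1 convolution weights so their gradients
# are skipped in the backward pass
G = NetworkGlow(n_in, n_hidden, L, K, split_scales = true, freeze_conv = true) |> device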

shaneinglis commented 1 year ago

Thank you for looking into it! The time taken for a backward pass has now gone down significantly; it takes around 12 seconds on the same machine, which I guess is expected.