slimgroup / InvertibleNetworks.jl

A Julia framework for invertible neural networks
MIT License

Flux friendliness #78

Closed: flo-he closed this issue 1 year ago

flo-he commented 1 year ago

Hi,

I think it would be very useful to have a plug-and-play option for training the INNs from this package with the standard Flux API. Currently I fail to train even a simple INN with the following training script:

using InvertibleNetworks, Flux

# dimensions and batch size
nx = 1
ny = 1
n_in = 2
n_hidden = 10
batchsize = 32

# network: ActNorm followed by a Glow coupling layer, wrapped in a Flux Chain
AN = ActNorm(n_in; logdet = false)
C = CouplingLayerGlow(n_in, n_hidden; logdet = false, k1 = 1, k2 = 1, p1 = 0, p2 = 0)
model = Chain(AN, C)

# dummy input & target
X = randn(Float32, nx, ny, n_in, batchsize)
Y = 2 .* X .+ 1

# loss fn
loss(model, X, Y) = Flux.mse(Y, model(X))

# old, implicit-style Flux
θ = Flux.params(model)
opt = ADAM(0.001f0)

for i = 1:5
    l, grads = Flux.withgradient(θ) do
        loss(model, X, Y)
    end

    @info "Loss: $l"

    Flux.update!(opt, θ, grads)
end

Running this code, the loss stays the same (the parameters do not seem to be updated). I do not know whether this style of training is simply not supported at the moment, or whether it is a bug.
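
A quick way to check whether Flux sees any trainable parameters at all, using only the standard Flux.params API and nothing specific to this package, would be:

# count the parameter arrays Flux collects from the model;
# if this prints 0, the layers are not exposing their parameters to Flux.params
@show length(Flux.params(model))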

I think this would be a useful feature that would ease things considerably. For my use case, I want to incorporate INNs into a larger personal project that uses Flux, and I only need the guaranteed invertibility of the models after training; otherwise they should behave just like any other custom Flux model.

mloubout commented 1 year ago

Hi

This was indeed an oversight; it is being fixed in #79. Thanks for reporting it.
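
For reference, the same training loop in Flux's newer explicit-parameter style would look roughly like the sketch below (assuming Flux's Flux.setup / Flux.withgradient / Flux.update! API; the optimizer settings and loop length are just the ones from the snippet above). Once the layer parameters are exposed to Flux, either style should update them:

# explicit-style Flux: optimizer state is set up per model
opt_state = Flux.setup(Adam(0.001f0), model)

for i = 1:5
    # differentiate with respect to the model itself rather than Flux.params
    l, grads = Flux.withgradient(m -> loss(m, X, Y), model)

    @info "Loss: $l"

    # grads[1] holds the gradient for `model`
    Flux.update!(opt_state, model, grads[1])
end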