slimgroup / InvertibleNetworks.jl

A Julia framework for invertible neural networks
MIT License

NetworkGlow has no field logdet #80

Closed · flo-he closed this issue 9 months ago

flo-he commented 1 year ago

Hi, I stumbled across the error message in the title when trying to train a Glow network (the same error also occurs with the HINT network).

MWE:

using InvertibleNetworks, Flux

# Glow network: 2 input channels, 32 hidden channels, L=2 scales, K=5 flow steps per scale
model = NetworkGlow(2, 32, 2, 5)

# dummy input & target
X = randn(Float32, 1, 1, 2, 1)
Y = 2 .* X .+ 1

# loss fn: model(X) returns a tuple (output, logdet), so take the first element
loss(model, X, Y) = Flux.mse(Y, model(X)[1])

θ = Flux.params(model)
opt = ADAM(0.001f0)

for i = 1:5
    l, grads = Flux.withgradient(θ) do
        loss(model, X, Y)
    end
    @show l
    Flux.update!(opt, θ, grads)
end
flo-he commented 9 months ago

Hi, is there any news on this? It would be really useful if one could train the INNs as simply as any other Flux model.

rafaelorozco commented 9 months ago

Hello,

Sorry! I missed this discussion or probably forgot about it. There is an easy fix: give NetworkGlow an optional logdet argument; with logdet=false you could then train it as you describe above. Would that be helpful?

If so, I can make that PR in a couple of hours, no problem.
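
For reference, usage after such a fix would look roughly like this (a sketch; the keyword name and the return convention with logdet=false are assumptions about the proposed change, not the current API):

using InvertibleNetworks, Flux

# Proposed: construct the network without logdet tracking
model = NetworkGlow(2, 32, 2, 5; logdet=false)

# Assumption: with logdet=false the forward pass returns only the output,
# so the loss would use model(X) directly rather than model(X)[1]
loss(model, X, Y) = Flux.mse(Y, model(X))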

flo-he commented 9 months ago

Yes, this would be fabulous, thank you!

rafaelorozco commented 9 months ago

All right, I pushed that quick fix. To be clear, this will only work for logdet=false. Tracking and differentiating the logdet is currently difficult to do with Julia AD. I think it is possible; it just needs some time, which I will have later.
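
For context, the logdet term matters because normalizing flows are usually trained by maximum likelihood, where the objective includes the log-determinant of the Jacobian. A rough sketch of that loss form follows (standard flow training, not a claim about this package's exact API or normalization conventions); differentiating it end-to-end is what Julia AD currently struggles with here:

# Standard flow MLE loss: Gaussian negative log-likelihood of the
# latent code minus the logdet of the Jacobian (up to constants)
function mle_loss(model, X)
    Zx, lgdet = model(X)           # forward pass with logdet tracking
    N = size(X, 4)                 # batch size
    return 0.5f0 * sum(abs2, Zx) / N - lgdet
end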

I added the MWE that you suggested here: https://github.com/slimgroup/InvertibleNetworks.jl/blob/master/examples/chainrules/train_with_flux.jl

I just had to increase the dimensionality of the input, because the actnorm layer was blowing up when computing the variance over a single element.
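
The adjusted setup is along these lines (a sketch; the exact sizes and loss in the linked example may differ):

# Larger spatial dimensions, so actnorm computes its per-channel
# variance over more than a single element
X = randn(Float32, 16, 16, 2, 4)
Y = 2 .* X .+ 1

model = NetworkGlow(2, 32, 2, 5; logdet=false)
loss(model, X, Y) = Flux.mse(Y, model(X))  # assumes no logdet returned when logdet=false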

I hope this helps. Thank you for the input!