flo-he closed this issue 9 months ago.
Hi, is there any news on this? It would be really useful to be able to train the INNs as simply as any other Flux model.
Hello,
Sorry! I missed this discussion, or probably forgot about it. There is an easy fix: give GlowNetwork an optional logdet argument, and with logdet=false you can train it as you describe above. Would that be helpful?
If so, I can make that PR in a couple of hours, no problem.
Yes, this would be fabulous, thank you!
All right, I pushed that quick fix. To be clear again: this only works with logdet=false. Tracking and differentiating the logdet is currently difficult to do with Julia AD. I think it is possible; it just needs some time, which I should have later.
I added the MWE that you suggested here: https://github.com/slimgroup/InvertibleNetworks.jl/blob/master/examples/chainrules/train_with_flux.jl
I just had to increase the dimensionality of the input, because the actnorm layer was blowing up when computing the variance over a single element.
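Why a single element is a problem: ActNorm, as introduced in the Glow paper, uses a data-dependent initialization that chooses a per-channel bias and scale so the first batch of activations has zero mean and unit variance, i.e. the scale is 1/std. The Python sketch below assumes that initialization scheme. It is not the InvertibleNetworks.jl implementation, and actnorm_init is a hypothetical helper. It shows the scale diverging when a channel only sees one value, because the variance of a single element is zero.

```python
import math


def actnorm_init(channel_values, eps=0.0):
    """Hypothetical data-dependent ActNorm initialization for one channel.

    Returns (bias, scale) such that (x + bias) * scale has zero mean and unit
    (population) variance over channel_values. Assumes the Glow-style scheme.
    """
    n = len(channel_values)
    mean = sum(channel_values) / n
    # Population variance over all elements that share this channel.
    var = sum((v - mean) ** 2 for v in channel_values) / n
    std = math.sqrt(var)
    if std + eps == 0.0:
        # One element (or a constant input) has zero variance, so 1/std diverges.
        return -mean, math.inf
    return -mean, 1.0 / (std + eps)


# A channel with a single element: the scale is infinite.
print(actnorm_init([0.7]))
# A channel with several elements: a finite scale that standardizes the data.
print(actnorm_init([1.0, 2.0, 3.0, 4.0]))
```

With more elements per channel the standard deviation is nonzero, which is why increasing the input dimensionality avoids the blow-up.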
I hope this helps. Thank you for the input!
Hi, I stumbled across this error message (see title) when trying to train a Glow network (it also applies to the HINT network).
MWE: