slimgroup / InvertibleNetworks.jl

A Julia framework for invertible neural networks
MIT License

`rrule` error when `ActNorm` is nested with other network #74

Closed: ziyiyin97 closed this issue 1 year ago

ziyiyin97 commented 1 year ago

When an `ActNorm` layer is composed with another network (`CouplingLayerHINT` below, but it can be any other layer), `rrule` throws an error because it tries to convert one layer type to the other:

ERROR: LoadError: MethodError: Cannot `convert` an object of type CouplingLayerHINT to an object of type ActNorm
Closest candidates are:
  convert(::Type{T}, ::T) where T at Base.jl:61
  ActNorm(::Any; logdet) at ~/.julia/packages/InvertibleNetworks/03pWT/src/layers/invertible_layer_actnorm.jl:53
  ActNorm(::Any, ::Any, ::Any, ::Any, ::Any) at ~/.julia/packages/InvertibleNetworks/03pWT/src/layers/invertible_layer_actnorm.jl:43
Stacktrace:
  [1] push!(a::Vector{ActNorm}, item::CouplingLayerHINT)
    @ Base ./array.jl:1057
  [2] forward_update!(state::InvertibleNetworks.InvertibleOperationsTape, X::Array{Float32, 4}, Y::Array{Float32, 4}, logdet::Nothing, net::CouplingLayerHINT)
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:78
  [3] rrule(net::CouplingLayerHINT, X::Array{Float32, 4}; state::InvertibleNetworks.InvertibleOperationsTape)
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:134
  [4] rrule
    @ ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:127 [inlined]
  [5] rrule
    @ ~/.julia/packages/ChainRulesCore/a4mIA/src/rules.jl:134 [inlined]
  [6] chain_rrule
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/chainrules.jl:223 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0 [inlined]
  [8] _pullback(ctx::Zygote.Context{false}, f::CouplingLayerHINT, args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:9
  [9] _pullback
    @ ~/.julia/dev/InvertibleNetworks/examples/chainrules/MFE.jl:16 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::var"#5#6", args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [11] pullback(f::Function, cx::Zygote.Context{false}, args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:44
 [12] pullback
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:42 [inlined]
 [13] gradient(f::Function, args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:96
 [14] top-level scope
    @ ~/.julia/dev/InvertibleNetworks/examples/chainrules/MFE.jl:16
 [15] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [16] top-level scope
    @ REPL[1]:1
in expression starting at /Users/francisyin/.julia/dev/InvertibleNetworks/examples/chainrules/MFE.jl:16

Minimal failing example (MFE) below:

using InvertibleNetworks, Flux

nx = 32
ny = 32
n_ch = 16
n_hidden = 64
batchsize = 2
logdet = false
N = CouplingLayerHINT(n_ch, n_hidden; logdet=logdet, permute="full")
AN = ActNorm(n_ch)                      # ActNorm takes the number of channels, not the batch size
X = randn(Float32, nx, ny, n_ch, batchsize)
AN(X)
g1 = gradient(X -> sum(AN(X)), X)       # this is fine
g2 = gradient(X -> sum(N(X)), X)        # this is fine
sum(N(AN(X)))                           # this is fine
g = gradient(X -> sum(N(AN(X))), X)     # ERROR: MethodError: Cannot `convert` an object of type CouplingLayerHINT to an object of type ActNorm
g = gradient(X -> sum(AN(N(X))), X)     # ERROR: MethodError: Cannot `convert` an object of type ActNorm to an object of type CouplingLayerHINT

mloubout commented 1 year ago

Can you try `Chain(N, AN)` instead? I think we only implemented this behavior for `Chain`; it's a bit trickier for direct calls.

ziyiyin97 commented 1 year ago

Hmm, same error:

N1 = Chain(N, AN)
g = gradient(X -> sum(N1(X)), X)
ERROR: MethodError: Cannot `convert` an object of type ActNorm to an object of type CouplingLayerHINT
Closest candidates are:
  convert(::Type{T}, ::T) where T at Base.jl:61
  CouplingLayerHINT(::Any, ::Any, ::Any, ::Any, ::Any) at ~/.julia/packages/InvertibleNetworks/03pWT/src/layers/invertible_layer_hint.jl:54
Stacktrace:
  [1] push!(a::Vector{CouplingLayerHINT}, item::ActNorm)
    @ Base ./array.jl:1057
  [2] forward_update!(state::InvertibleNetworks.InvertibleOperationsTape, X::Array{Float32, 4}, Y::Array{Float32, 4}, logdet::Nothing, net::ActNorm)
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:78
  [3] rrule(net::ActNorm, X::Array{Float32, 4}; state::InvertibleNetworks.InvertibleOperationsTape)
    @ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:134
  [4] rrule
    @ ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:127 [inlined]
  [5] rrule
    @ ~/.julia/packages/ChainRulesCore/a4mIA/src/rules.jl:134 [inlined]
  [6] chain_rrule
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/chainrules.jl:223 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0 [inlined]
  [8] _pullback(ctx::Zygote.Context{false}, f::ActNorm, args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:9
  [9] macro expansion
    @ ~/.julia/packages/Flux/OxB4x/src/layers/basic.jl:53 [inlined]
 [10] _pullback
    @ ~/.julia/packages/Flux/OxB4x/src/layers/basic.jl:53 [inlined]
 [11] _pullback(::Zygote.Context{false}, ::typeof(Flux._applychain), ::Tuple{CouplingLayerHINT, ActNorm}, ::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [12] _pullback
    @ ~/.julia/packages/Flux/OxB4x/src/layers/basic.jl:51 [inlined]
 [13] _pullback(ctx::Zygote.Context{false}, f::Chain{Tuple{CouplingLayerHINT, ActNorm}}, args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [14] _pullback
    @ ./REPL[6]:1 [inlined]
 [15] _pullback(ctx::Zygote.Context{false}, f::var"#11#12", args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [16] pullback(f::Function, cx::Zygote.Context{false}, args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:44
 [17] pullback
    @ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:42 [inlined]
 [18] gradient(f::Function, args::Array{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:96
 [19] top-level scope
    @ REPL[6]:1
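
Frame [1] of both traces is a `push!` into a vector whose element type is one concrete layer type (`Vector{ActNorm}` in the first trace, `Vector{CouplingLayerHINT}` in the second), with an item of the other layer type. The sketch below reproduces only that mechanism in plain Julia. `ToyLayer`, `ToyActNorm` and `ToyHINT` are hypothetical stand-ins, not the InvertibleNetworks.jl types, and the abstractly typed vector at the end is an assumption about one possible remedy, not necessarily how the package resolved the issue.

```julia
# Hypothetical stand-ins for two layer types (not the InvertibleNetworks.jl structs).
abstract type ToyLayer end
struct ToyActNorm <: ToyLayer end
struct ToyHINT <: ToyLayer end

# A one-element vector literal takes the concrete type of that element.
narrow = [ToyActNorm()]
println(typeof(narrow))

# Pushing a different layer type calls `convert`, which has no matching method.
err = try
    push!(narrow, ToyHINT())
    nothing
catch e
    e
end
println(err isa MethodError)

# Declaring the abstract element type up front lets both layers share one vector.
wide = ToyLayer[ToyActNorm()]
push!(wide, ToyHINT())
println(length(wide))
```

If the tape's layer list is created from the first layer it sees, this would explain why each layer differentiates fine on its own but any mixed composition fails, whether written as `N(AN(X))` or as `Chain(N, AN)`.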