Closed: ziyiyin97 closed this issue 1 year ago.
Can you try `Chain(N, AN)` instead? I think we only implemented this behavior for `Chain`; it's a bit trickier for plain function calls.
Hmm, same error:
N1 = Chain(N, AN)
g = gradient(X -> sum(N1(X)), X)
ERROR: MethodError: Cannot `convert` an object of type ActNorm to an object of type CouplingLayerHINT
Closest candidates are:
convert(::Type{T}, ::T) where T at Base.jl:61
CouplingLayerHINT(::Any, ::Any, ::Any, ::Any, ::Any) at ~/.julia/packages/InvertibleNetworks/03pWT/src/layers/invertible_layer_hint.jl:54
Stacktrace:
[1] push!(a::Vector{CouplingLayerHINT}, item::ActNorm)
@ Base ./array.jl:1057
[2] forward_update!(state::InvertibleNetworks.InvertibleOperationsTape, X::Array{Float32, 4}, Y::Array{Float32, 4}, logdet::Nothing, net::ActNorm)
@ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:78
[3] rrule(net::ActNorm, X::Array{Float32, 4}; state::InvertibleNetworks.InvertibleOperationsTape)
@ InvertibleNetworks ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:134
[4] rrule
@ ~/.julia/packages/InvertibleNetworks/03pWT/src/utils/chainrules.jl:127 [inlined]
[5] rrule
@ ~/.julia/packages/ChainRulesCore/a4mIA/src/rules.jl:134 [inlined]
[6] chain_rrule
@ ~/.julia/packages/Zygote/g2w9o/src/compiler/chainrules.jl:223 [inlined]
[7] macro expansion
@ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0 [inlined]
[8] _pullback(ctx::Zygote.Context{false}, f::ActNorm, args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:9
[9] macro expansion
@ ~/.julia/packages/Flux/OxB4x/src/layers/basic.jl:53 [inlined]
[10] _pullback
@ ~/.julia/packages/Flux/OxB4x/src/layers/basic.jl:53 [inlined]
[11] _pullback(::Zygote.Context{false}, ::typeof(Flux._applychain), ::Tuple{CouplingLayerHINT, ActNorm}, ::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[12] _pullback
@ ~/.julia/packages/Flux/OxB4x/src/layers/basic.jl:51 [inlined]
[13] _pullback(ctx::Zygote.Context{false}, f::Chain{Tuple{CouplingLayerHINT, ActNorm}}, args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[14] _pullback
@ ./REPL[6]:1 [inlined]
[15] _pullback(ctx::Zygote.Context{false}, f::var"#11#12", args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[16] pullback(f::Function, cx::Zygote.Context{false}, args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:44
[17] pullback
@ ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:42 [inlined]
[18] gradient(f::Function, args::Array{Float32, 4})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:96
[19] top-level scope
@ REPL[6]:1
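Frame [1] of the trace shows where it breaks: `push!` onto a `Vector{CouplingLayerHINT}` with an `ActNorm` item. The sketch below reproduces that mechanism in plain Julia using hypothetical stand-in types (`CouplingStandIn`, `ActNormStandIn`, `AbstractLayer` are not the InvertibleNetworks.jl structs). It assumes, based only on the `Vector{CouplingLayerHINT}` signature in frame [1], that the tape's layer vector takes its element type from the first layer recorded; the abstractly typed container at the end illustrates one possible direction for a fix, not the library's actual change.

```julia
# Hypothetical stand-ins for two different invertible layer types
# (NOT the real InvertibleNetworks.jl structs).
abstract type AbstractLayer end
struct CouplingStandIn <: AbstractLayer end
struct ActNormStandIn <: AbstractLayer end

# Element type is inferred from the first element: Vector{CouplingStandIn}.
concrete_tape = [CouplingStandIn()]

# push! converts the new item to the vector's element type; there is no
# convert method from ActNormStandIn to CouplingStandIn, so this throws the
# same kind of MethodError as frame [1] of the stack trace.
err = try
    push!(concrete_tape, ActNormStandIn())
    nothing
catch e
    e
end
println(typeof(err))

# A vector typed with the abstract supertype accepts any subtype,
# so mixed layer types can be recorded side by side.
abstract_tape = AbstractLayer[CouplingStandIn()]
push!(abstract_tape, ActNormStandIn())
println(length(abstract_tape))
```

The failed `push!` leaves `concrete_tape` unchanged, since the conversion happens before the vector grows. Using an abstract element type gives up some type specialization in exchange for being able to hold heterogeneous layers.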
When an `ActNorm` layer is nested with another network (below it is `CouplingLayerHINT`, but it can also be other ones), `rrule` throws an error because it tries to convert one layer type to the other. MFE below: