Closed: Red-Portal closed this issue 1 month ago.
Hmm... seems like a `mapreduce` diffrule is broken... again?
@devmotion I think you're probably the expert on differentiating `mapreduce`. Do you have any clue about what's going on?
A hotfix is just:
```julia
function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
    N = length(sb.bs)
    yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]])
    logjac = sum(linit)
    # HACK: Return early to avoid `mapreduce` over empty-collection which Zygote.jl doesn't like.
    N == 1 && return (yinit, logjac)
    ys = mapreduce(vcat, sb.bs[2:end], sb.ranges[2:end]; init=yinit) do b, r
        y, l = with_logabsdet_jacobian(b, x[r])
        logjac += sum(l)
        y
    end
    return (ys, logjac)
end
```
@torfjelde Hmm... I don't think the hotfix fixes the MWE. In fact, the MWE has `N = 2`, so it never hits the early-return condition.
For anyone who hits this and needs it fixed ASAP, here's a quick-and-dirty workaround (at the cost of doubling the computation time, since every bijector is evaluated twice):
```julia
@eval Bijectors begin
    function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
        ys = mapreduce(vcat, sb.bs, sb.ranges) do b, r
            y, _ = with_logabsdet_jacobian(b, x[r])
            y
        end
        logjac = mapreduce(+, sb.bs, sb.ranges) do b, r
            _, l = with_logabsdet_jacobian(b, x[r])
            first(l)
        end
        return (ys, logjac)
    end
end
```
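If doubling the work is a concern, here is a minimal single-pass sketch (not from the thread, and not tested against Zygote) that evaluates each block once and avoids `mapreduce` with an `init` keyword altogether. To keep it runnable with Base Julia alone, it replaces the bijectors with a hypothetical elementwise-scaling block, `scale_with_logjac`, and the hypothetical wrapper `stacked_forward`; in Bijectors you would call `with_logabsdet_jacobian(b, x[r])` in its place.

```julia
# Hypothetical stand-in for `with_logabsdet_jacobian` on one block:
# y = a .* x, whose log|det J| is length(x) * log(abs(a)).
scale_with_logjac(a, x) = (a .* x, length(x) * log(abs(a)))

# Hypothetical single-pass analogue of `with_logabsdet_jacobian(::Stacked, x)`.
function stacked_forward(scales, ranges, x::AbstractVector)
    # Evaluate every block exactly once; `map` over two collections pairs them up.
    results = map((a, r) -> scale_with_logjac(a, x[r]), scales, ranges)
    # Concatenate the block outputs, without an `init` keyword.
    y = reduce(vcat, map(first, results))
    # Sum the per-block log-Jacobian terms.
    logjac = sum(last, results)
    return (y, logjac)
end

y, logjac = stacked_forward([2.0, 3.0], [1:2, 3:3], [1.0, 2.0, 4.0])
```

The sketch assumes at least one block; the empty case is not handled. The failing path in the traceback goes through `reduce` with `init`, and this version calls `reduce` without it, but whether Zygote differentiates it cleanly has not been checked here.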
Aah, sorry, I saw the `[1, 2]` and immediately thought it was a different issue.
It seems the line https://github.com/FluxML/Zygote.jl/blob/dc16b2e35cdfa9dda62f636e95271490548cc574/src/compiler/chainrules.jl#L235 returns `nothing`, which in turn means we're hitting https://github.com/JuliaDiff/ChainRulesCore.jl/blob/79ba4ef03afdf5715b6fa0294e5accfe4e95c79b/src/rules.jl#L138 for some reason.
In particular, the traceback
```
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:872
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at range.jl:872
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at dict.jl:712
  ...
Stacktrace:
 [1] indexed_iterate(I::Nothing, i::Int64)
   @ Base ./tuple.jl:91
 [2] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::Function, ::Function, ::Vararg{Any})
   @ Main ./REPL[23]:3
 [3] macro expansion
   @ ./compiler/interface2.jl:0 [inlined]
 [4] _pullback(::Zygote.Context{false}, ::Base.var"#reduce##kw", ::NamedTuple{(:init,), Tuple{Vector{Float64}}}, ::typeof(reduce), ::typeof(vcat), ::Vector{Vector{Float64}})
   @ Zygote ./compiler/interface2.jl:9
 [5] _pullback
   @ ./reducedim.jl:359 [inlined]
 [6] _pullback(::Zygote.Context{false}, ::Base.var"##mapreduce#766", ::Base.Pairs{Symbol, Vector{Float64}, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Vector{Float64}}}}, ::typeof(mapreduce), ::var"#3#4"{Vector{Float64}}, ::typeof(vcat), ::Vector{Any}, ::Vector{UnitRange{Int64}})
   @ Zygote ./compiler/interface2.jl:0
```
seems to indicate that we're somehow triggering the rule for `reduce`.
Currently trying to figure out what's going on here.
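A likely reason the `reduce` rule shows up: in Julia's Base, `mapreduce` over more than one collection (as in `mapreduce(f, vcat, sb.bs, sb.ranges; init=...)`) is defined in `reducedim.jl` as `reduce(op, map(f, A...); kw...)`, which lines up with frames [4] to [6] of the traceback. The snippet below checks that equivalence on toy data using only Base; `fs`, `ranges`, and `x` are hypothetical stand-ins for `sb.bs`, `sb.ranges`, and the input vector.

```julia
# Toy stand-ins (hypothetical) for `sb.bs`, `sb.ranges`, and the input vector.
fs = [v -> 2 .* v, v -> v .+ 1]
ranges = [1:2, 3:4]
x = [1.0, 2.0, 3.0, 4.0]

# Multi-collection `mapreduce` with an `init` keyword, shaped like the Bijectors code.
via_mapreduce = mapreduce(vcat, fs, ranges; init=Float64[]) do f, r
    f(x[r])
end

# What Base forwards that call to: `reduce` over the mapped results, same keyword.
via_reduce = reduce(vcat, map((f, r) -> f(x[r]), fs, ranges); init=Float64[])

# Both evaluate to [2.0, 4.0, 4.0, 5.0].
@show via_mapreduce == via_reduce
```

This is consistent with frame [2], where `chain_rrule_kw` receives the `init` keyword for `reduce(vcat, ...)` rather than anything `mapreduce`-specific.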
Here's an MWE:
This results in
Note that if ranges is set as
everything works. Not sure if this is intended behavior?