It looks like the finite-difference implementation has a hard time going through `iterate` (see the MWE and full stacktrace below):
julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
Got exception outside of a @test
DimensionMismatch: second dimension of A, 2, does not match length of x, 1
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
...
[7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
...
Below I provide an `rrule()` implementation for `iterate` on tuples for convenience, but perhaps the example can be narrowed down to a direct invocation of `_make_j′vp_call()`. Also, I see the same error when testing with arrays.
MWE
```julia
using ChainRulesCore
import ChainRulesCore.rrule
using ChainRulesTestUtils
"""
    ungetfield(dy, s::Tuple, f::Int)

Build a `Tangent` for the tuple `s` that holds `dy` in position `f` and
`ZeroTangent()` in every other position.
"""
function ungetfield(dy, s::Tuple, f::Int)
    # Only slot `f` carries the incoming cotangent; all other slots are zero.
    slots = ntuple(k -> k == f ? dy : ZeroTangent(), length(s))
    return Tangent{typeof(s)}(slots...)
end
"""
    rrule(::typeof(iterate), t::Tuple)

Reverse rule for `iterate(t)`.

Returns the primal result `iterate(t)` together with a pullback. The pullback
unthunks the output cotangent, takes its first component, and places it in
slot 1 of the tangent of `t` (all other slots are `ZeroTangent()`). The
function itself receives `NoTangent()`.
"""
function rrule(::typeof(iterate), t::Tuple)
    function iterate_pullback(ȳ)
        Δ = unthunk(ȳ)
        ∂t = ungetfield(Δ[1], t, 1)
        return (NoTangent(), ∂t)
    end
    return iterate(t), iterate_pullback
end
"""
    rrule(::typeof(iterate), t::Tuple, i::Integer)

Reverse rule for `iterate(t, i)`.

Returns the primal result `iterate(t, i)` together with a pullback. The
pullback unthunks the output cotangent, takes its first component, and places
it in slot `i` of the tangent of `t` (all other slots are `ZeroTangent()`).
The function itself and the integer state `i` are non-differentiable, so both
receive `NoTangent()`.
"""
function rrule(::typeof(iterate), t::Tuple, i::Integer)
    y = iterate(t, i)
    function iterate_pullback(dy)
        dy = unthunk(dy)
        # `i` is an integer iteration state, not a differentiable quantity,
        # so its cotangent is `NoTangent()` rather than `ZeroTangent()`.
        return NoTangent(), ungetfield(dy[1], t, i), NoTangent()
    end
    return y, iterate_pullback
end
# Reproducer: test the one-argument `iterate` rule on a two-element tuple.
# This is the call that raises the `DimensionMismatch` shown in the stacktrace.
test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
```
Complete stacktrace
```julia
julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
Got exception outside of a @test
DimensionMismatch: second dimension of A, 2, does not match length of x, 1
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493
[2] mul!
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined]
[3] mul!
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
[4] *(tA::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86
[5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:80
[6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{NamedTuple{(), Tuple{}}}, Tuple{typeof(iterate), Tuple{Float64, Float64}}, Tuple{Bool, Bool}}, ȳ::Tangent{Tuple{Float64, Int64}, Tuple{Float64, NoTangent}}, x::Tuple{Float64, Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:73
[7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51
[8] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:224 [inlined]
[9] macro expansion
@ /opt/julia-1.8.0/share/julia/stdlib/v1.8/Test/src/Test.jl:1357 [inlined]
[10] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196
[11] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170
[12] top-level scope
@ REPL[1]:1
[13] eval
@ ./boot.jl:368 [inlined]
[14] eval
@ ./Base.jl:65 [inlined]
[15] repleval(m::Module, code::Expr, #unused#::String)
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:222
[16] (::VSCodeServer.var"#107#109"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:186
[17] with_logstate(f::Function, logstate::Any)
@ Base.CoreLogging ./logging.jl:511
[18] with_logger
@ ./logging.jl:623 [inlined]
[19] (::VSCodeServer.var"#106#108"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
@ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:187
[20] #invokelatest#2
@ ./essentials.jl:729 [inlined]
[21] invokelatest(::Any)
@ Base ./essentials.jl:726
[22] macro expansion
@ ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
[23] (::VSCodeServer.var"#61#62")()
@ VSCodeServer ./task.jl:484
Test Summary: | Pass Error Total Time
test_rrule: iterate on Float64,Float64 | 3 1 4 0.0s
ERROR: Some tests did not pass: 3 passed, 0 failed, 1 errored, 0 broken.
```
It looks like the finite-difference implementation has a hard time going through `iterate` (see the MWE and full stacktrace below):

Below I provide an `rrule()` implementation for `iterate` on tuples for convenience, but perhaps the example can be narrowed down to a direct invocation of `_make_j′vp_call()`. Also, I see the same error when testing with arrays.

MWE
```julia
using ChainRulesCore
import ChainRulesCore.rrule
using ChainRulesTestUtils
function ungetfield(dy, s::Tuple, f::Int)
T = typeof(s)
return Tangent{T}([i == f ? dy : ZeroTangent() for i=1:length(s)]...)
end
function rrule(::typeof(iterate), t::Tuple)
y = iterate(t)
function iterate_pullback(dy)
dy = unthunk(dy)
return NoTangent(), ungetfield(dy[1], t, 1)
end
return y, iterate_pullback
end
function rrule(::typeof(iterate), t::Tuple, i::Integer)
y = iterate(t, i)
function iterate_pullback(dy)
dy = unthunk(dy)
return NoTangent(), ungetfield(dy[1], t, i), ZeroTangent()
end
return y, iterate_pullback
end
test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false)
```
Complete stacktrace
```julia julia> test_rrule(Base.iterate, (3.0, 5.0); check_inferred=false) test_rrule: iterate on Float64,Float64: Error During Test at /home/azbs/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193 Got exception outside of a @test DimensionMismatch: second dimension of A, 2, does not match length of x, 1 Stacktrace: [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool) @ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:493 [2] mul! @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:93 [inlined] [3] mul! @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined] [4] *(tA::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64}) @ LinearAlgebra /opt/julia-1.8.0/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:86 [5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64}) @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:80 [6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{NamedTuple{(), Tuple{}}}, Tuple{typeof(iterate), Tuple{Float64, Float64}}, Tuple{Bool, Bool}}, ȳ::Tangent{Tuple{Float64, Int64}, Tuple{Float64, NoTangent}}, x::Tuple{Float64, Float64}) @ FiniteDifferences ~/.julia/packages/FiniteDifferences/VpgIT/src/grad.jl:73 [7] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any) @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/finite_difference_calls.jl:51 [8] macro expansion @ ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:224 [inlined] [9] macro expansion @ /opt/julia-1.8.0/share/julia/stdlib/v1.8/Test/src/Test.jl:1357 [inlined] [10] test_rrule(config::RuleConfig, f::Any, 
args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}}) @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:196 [11] test_rrule(::Any, ::Vararg{Any}; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}}) @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:170 [12] top-level scope @ REPL[1]:1 [13] eval @ ./boot.jl:368 [inlined] [14] eval @ ./Base.jl:65 [inlined] [15] repleval(m::Module, code::Expr, #unused#::String) @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:222 [16] (::VSCodeServer.var"#107#109"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})() @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:186 [17] with_logstate(f::Function, logstate::Any) @ Base.CoreLogging ./logging.jl:511 [18] with_logger @ ./logging.jl:623 [inlined] [19] (::VSCodeServer.var"#106#108"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})() @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/repl.jl:187 [20] #invokelatest#2 @ ./essentials.jl:729 [inlined] [21] invokelatest(::Any) @ Base ./essentials.jl:726 [22] macro expansion @ ~/.vscode/extensions/julialang.language-julia-1.7.12/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined] [23] (::VSCodeServer.var"#61#62")() @ VSCodeServer ./task.jl:484 Test Summary: | Pass Error Total Time test_rrule: iterate on Float64,Float64 | 3 1 4 0.0s ERROR: Some tests did not pass: 3 passed, 0 failed, 1 errored, 0 broken. ```