Here the implicit object returns a tuple (y, z), where y is the output of interest and z is a byproduct that we don't differentiate with respect to (it is ignored in the pullback, defined here).
When I pass the tangent for z manually (the session below uses NoTangent()), the pullback succeeds.
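When ChainRulesTestUtils supplies the same output tangent through test_rrule, the test errors with a DimensionMismatch (raised inside FiniteDifferences' _j′vp, per the stacktrace below).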
julia> using ChainRulesCore
julia> using ChainRulesTestUtils
julia> using ImplicitDifferentiation # use the main branch
julia> using Zygote
julia> forward(x) = sqrt.(abs.(x)), 2;
julia> conditions(x, y, z) = abs.(y) .^ z .- abs.(x);
julia> implicit = ImplicitFunction(forward, conditions)
ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing)
julia> x = rand(3)
3-element Vector{Float64}:
0.15182038809215614
0.4724558277117631
0.6889531602780976
julia> y, z = implicit(x)
([0.3896413582926691, 0.6873542228805779, 0.8300320236461347], 2)
julia> dy = similar(y);
julia> dy .= 1;
julia> rc = Zygote.ZygoteRuleConfig();
julia> _, back = rrule_via_ad(rc, implicit, x);
julia> back((dy, NoTangent()))
(NoTangent(), [1.2832313340424137, 0.7274269704848686, 0.6023863968568562])
julia> test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, 0)) # ok
Test Summary: | Pass Total Time
test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64} | 7 7 0.0s
Test.DefaultTestSet("test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64}", Any[], 7, false, false, true, 1.691675539371241e9, 1.691675539408112e9, false, "/home/guillaume/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl")
julia> test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, NoTangent())) # not ok
test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64}: Error During Test at /home/guillaume/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:202
Got exception outside of a @test
DimensionMismatch: second dimension of A, 4, does not match length of x, 3
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:404
[2] generic_matvecmul!
@ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:71 [inlined]
[3] mul!
@ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
[4] mul!
@ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
[5] *(A::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:53
[6] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/BSR84/src/grad.jl:84
[7] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{@NamedTuple{}}, Tuple{ImplicitFunction{typeof(forward), typeof(conditions), IterativeLinearSolver, Nothing}, Vector{Float64}}, Tuple{Bool, Bool}}, ȳ::Tuple{Vector{Float64}, NoTangent}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/BSR84/src/grad.jl:77
[8] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/finite_difference_calls.jl:51
[9] macro expansion
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:233 [inlined]
[10] macro expansion
@ ChainRulesTestUtils ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
[11] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, testset_name::Any, kwargs...)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:205
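The failing frame is a transposed-Jacobian times cotangent product inside _j′vp. The sketch below reproduces only the shape problem in plain Julia. It assumes, without confirmation from the packages involved, that the finite-difference Jacobian has 4 rows because the output (y, z) is flattened to 4 numbers, while the cotangent (dy, NoTangent()) is flattened to 3. The matrix J and the names ybar_short and ybar_padded are hypothetical stand-ins, not values or identifiers from FiniteDifferences.

```julia
using LinearAlgebra  # stdlib; loaded here because the stacktrace goes through it

# Hypothetical stand-in for the Jacobian of x -> (y, z):
# 4 output entries (3 for y, 1 for z, by assumption) and 3 inputs.
J = ones(4, 3)

dy = ones(3)

# Assumed flattening of (dy, NoTangent()): no entry for z, so length 3.
ybar_short = dy
# Assumed flattening of (dy, 0): one extra zero entry for z, so length 4.
ybar_padded = vcat(dy, 0.0)

# transpose(J) is 3x4, so it can only multiply a length-4 vector.
mismatch = try
    transpose(J) * ybar_short
    false
catch err
    err isa DimensionMismatch
end

# With the padded cotangent the shapes agree and the product goes through.
vjp = transpose(J) * ybar_padded
```

Under these assumptions, the short cotangent triggers the same kind of DimensionMismatch as in the stacktrace, and the padded one does not, which would be consistent with output_tangent=(dy, 0) passing while output_tangent=(dy, NoTangent()) errors.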
This is the broken test added in #111.
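As a sanity check on the value returned by the manual pullback call in the session above, the cotangent can be computed by hand. For positive x, forward gives y = sqrt.(x) elementwise, so the Jacobian is diagonal with entries 1 / (2 sqrt(x)). The sketch below uses only Base Julia and the numbers printed in the session; the variable names expected_dx and reported_dx are introduced here for illustration.

```julia
# Input vector printed in the REPL session.
x = [0.15182038809215614, 0.4724558277117631, 0.6889531602780976]

# Output cotangent for y used in the session (all ones).
dy = ones(length(x))

# For x > 0, y = sqrt(x) elementwise, so dy/dx = 1 / (2 sqrt(x)).
# The Jacobian is diagonal, so the vector-Jacobian product is elementwise.
expected_dx = dy ./ (2 .* sqrt.(x))

# Input cotangent printed by back((dy, NoTangent())) in the session.
reported_dx = [1.2832313340424137, 0.7274269704848686, 0.6023863968568562]

# Compare with a loose tolerance, since the session used an iterative solver.
agrees = isapprox(expected_dx, reported_dx; rtol=1e-3)
```

If agrees is true, the manual pullback matches the closed-form derivative, which suggests the rule itself is fine and the failure lies in how the test harness handles the output tangent.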