JuliaDecisionFocusedLearning / ImplicitDifferentiation.jl

Automatic differentiation of implicit functions
https://juliadecisionfocusedlearning.github.io/ImplicitDifferentiation.jl/
MIT License

ChainRulesTestUtils fails with integer byproduct #112

Open gdalle opened 1 year ago

gdalle commented 1 year ago

Here the implicit object returns a tuple (y, z), where y is the part we actually care about and z is a byproduct that we don't differentiate with respect to (this byproduct is ignored in the pullback, defined here).

julia> using ChainRulesCore

julia> using ChainRulesTestUtils

julia> using ImplicitDifferentiation  # use the main branch

julia> using Zygote

julia> forward(x) = sqrt.(abs.(x)), 2;

julia> conditions(x, y, z) = abs.(y) .^ z .- abs.(x);

julia> implicit = ImplicitFunction(forward, conditions)
ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing)

julia> x = rand(3)
3-element Vector{Float64}:
 0.15182038809215614
 0.4724558277117631
 0.6889531602780976

julia> y, z = implicit(x)
([0.3896413582926691, 0.6873542228805779, 0.8300320236461347], 2)

julia> dy = similar(y);

julia> dy .= 1;

julia> rc = Zygote.ZygoteRuleConfig();

julia> _, back = rrule_via_ad(rc, implicit, x);

julia> back((dy, NoTangent()))
(NoTangent(), [1.2832313340424137, 0.7274269704848686, 0.6023863968568562])

julia> test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, 0))  # ok
Test Summary:                                                                                              | Pass  Total  Time
test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64} |    7      7  0.0s
Test.DefaultTestSet("test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64}", Any[], 7, false, false, true, 1.691675539371241e9, 1.691675539408112e9, false, "/home/guillaume/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl")

julia> test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, NoTangent()))  # not ok
test_rrule: ImplicitFunction(forward, conditions, IterativeLinearSolver(true), nothing) on Vector{Float64}: Error During Test at /home/guillaume/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:202
  Got exception outside of a @test
  DimensionMismatch: second dimension of A, 4, does not match length of x, 3
  Stacktrace:
    [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:404
    [2] generic_matvecmul!
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:71 [inlined]
    [3] mul!
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
    [4] mul!
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
    [5] *(A::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
      @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:53
    [6] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/BSR84/src/grad.jl:84
    [7] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::ChainRulesTestUtils.var"#fnew#53"{ChainRulesTestUtils.var"#call#63"{@NamedTuple{}}, Tuple{ImplicitFunction{typeof(forward), typeof(conditions), IterativeLinearSolver, Nothing}, Vector{Float64}}, Tuple{Bool, Bool}}, ȳ::Tuple{Vector{Float64}, NoTangent}, x::Vector{Float64})
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/BSR84/src/grad.jl:77
    [8] _make_j′vp_call(fdm::Any, f::Any, ȳ::Any, xs::Any, ignores::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/finite_difference_calls.jl:51
    [9] macro expansion
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:233 [inlined]
   [10] macro expansion
      @ ChainRulesTestUtils ~/.julia/juliaup/julia-1.10.0-beta1+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [11] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, testset_name::Any, kwargs...)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:205

This is the broken test added in #111.
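As a sanity check on the pullback printed by `back((dy, NoTangent()))` above, here is a minimal sketch in plain Julia (standard library only, no ImplicitDifferentiation.jl or Zygote) that recomputes it by hand. It assumes the byproduct z = 2 is held fixed, so the conditions c(x, y, z) = |y|^z - |x| = 0 give, elementwise, dy/dx = -(∂c/∂y)⁻¹ ∂c/∂x = sign(x) / (z |y|^(z-1) sign(y)), which is 1 / (2 sqrt(x)) for positive x. The x values are copied from the session; the finite-difference step `h` and the helper name `forward_y` are arbitrary choices for this sketch.

```julia
# Recompute the pullback of the implicit function by hand (standard library only).
# Conditions, applied elementwise: c(x, y, z) = abs(y)^z - abs(x) = 0,
# where z = 2 is the integer byproduct and is held fixed (never differentiated).

x = [0.15182038809215614, 0.4724558277117631, 0.6889531602780976]  # from the session
z = 2

forward_y(x) = sqrt.(abs.(x))  # hypothetical helper: the differentiable part of `forward`
y = forward_y(x)

# Partial derivatives of the conditions. Both Jacobians are diagonal,
# so only their diagonals are stored, as vectors.
dc_dy = z .* abs.(y) .^ (z - 1) .* sign.(y)
dc_dx = .-sign.(x)

# Implicit function theorem: dy/dx = -(∂c/∂y)⁻¹ ∂c/∂x, elementwise here.
jac_diag = .-dc_dx ./ dc_dy

# Vector-Jacobian product with the cotangent dy = ones(3), as in the session.
# The Jacobian is diagonal, so the VJP is an elementwise product.
dy = ones(length(y))
dx_implicit = jac_diag .* dy

# Independent check: central finite differences on the explicit solution.
h = 1e-6
dx_fd = [(forward_y(x[i] + h) - forward_y(x[i] - h)) / (2h) for i in eachindex(x)]

# Values printed by back((dy, NoTangent())) in the session.
ref = [1.2832313340424137, 0.7274269704848686, 0.6023863968568562]
```

If both `dx_implicit` and `dx_fd` match `ref`, that is consistent with the rule itself being correct and the failure lying in how the test harness builds its finite-difference reference.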

gdalle commented 1 year ago

@oxinabox if you have any intuition about this, I'll gladly take it

oxinabox commented 1 year ago

I'm guessing it's related in some way to FiniteDifferences.to_vec
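To illustrate that guess, here is a minimal plain-Julia model of the length mismatch. It does not call FiniteDifferences.jl or ChainRulesCore.jl: `flatten` and `NoTangentLike` are hypothetical stand-ins for `FiniteDifferences.to_vec` and `NoTangent`. The sizes in the stack trace (second dimension 4, vector length 3) suggest, as an assumption not checked against the FiniteDifferences source, that the integer byproduct contributes one entry to the flattened primal output while `NoTangent()` contributes none to the flattened cotangent, so `transpose(J)` ends up with 4 columns but is multiplied by a vector of length 3.

```julia
using LinearAlgebra  # transpose and matrix-vector products

# HYPOTHETICAL stand-ins, not the real FiniteDifferences/ChainRulesCore API:
# `flatten` mimics what FiniteDifferences.to_vec appears to do here, and
# `NoTangentLike` plays the role of ChainRulesCore.NoTangent.
# ASSUMPTION (inferred from the 4-vs-3 sizes in the stack trace): an integer in
# the primal output contributes one entry, a NoTangent contributes none.
struct NoTangentLike end

flatten(v::AbstractVector{<:Real}) = Vector{Float64}(v)
flatten(r::Real) = [float(r)]
flatten(::NoTangentLike) = Float64[]
flatten(t::Tuple) = vcat(Float64[], map(flatten, t)...)

y = [0.39, 0.69, 0.83]  # stands in for the y returned by implicit(x)
dy = ones(3)

n_out = length(flatten((y, 2)))                 # primal output (y, z)
n_ok = length(flatten((dy, 0)))                 # output_tangent=(dy, 0)
n_bad = length(flatten((dy, NoTangentLike())))  # output_tangent=(dy, NoTangent())

# A finite-difference Jacobian of x -> flatten(output) has n_out rows and
# length(x) = 3 columns; the j′vp then computes transpose(J) * flattened cotangent.
J = rand(n_out, 3)
vjp_ok = transpose(J) * flatten((dy, 0))  # 3×4 times length-4: fine

# 3×4 times length-3: throws DimensionMismatch, as in the reported failure.
mismatch = try
    transpose(J) * flatten((dy, NoTangentLike()))
    false
catch err
    err isa DimensionMismatch
end
```

Under this model, passing `output_tangent=(dy, 0)` keeps both flattened vectors at length 4, which matches the test marked `# ok` in the session.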