JuliaDiff / ChainRulesTestUtils.jl

Utilities for testing custom AD primitives.
MIT License
50 stars 15 forks source link

Error in `FiniteDifferences.to_vec` when running `test_rrule` on a structured type #258

Open vpuri3 opened 2 years ago

vpuri3 commented 2 years ago
"""
    Mat(A::AbstractMatrix, trait::Bool)
    Mat(A::AbstractMatrix)

Mutable wrapper that stores a matrix `A` together with a Boolean `trait` flag.
The one-argument constructor sets `trait` to `false`.
"""
mutable struct Mat{AType <: AbstractMatrix}
    A::AType
    trait::Bool
end

function Mat(A::AbstractMatrix)
    return Mat(A, false)
end

# The adjoint of a `Mat` wraps the adjoint of the stored matrix in a new `Mat`,
# built through the one-argument constructor.
function Base.adjoint(M::Mat)
    return Mat(adjoint(M.A))
end

# Multiplying a `Mat` by a vector multiplies the stored matrix by that vector.
function Base.:*(M::Mat, u::AbstractVector)
    return M.A * u
end

"""
    ChainRulesCore.rrule(::typeof(*), M::Mat, u::AbstractVector)

Reverse-mode rule for `M * u`.

Returns the primal result `M * u` and a pullback. The pullback maps the output
cotangent `dv` to `(NoTangent(), dM, du)`, where `dM` is a `Tangent{Mat}` that
carries only the cotangent of the field `A`, and `du` is the cotangent of `u`.
Both cotangents are wrapped in `@thunk`.
"""
function ChainRulesCore.rrule(::typeof(*), M::Mat, u::AbstractVector)
    project_u = ProjectTo(u)
    # Projector for the wrapped matrix, applied to the cotangent of `A`.
    project_A = ProjectTo(M.A)

    function pb(dv)
        du = @thunk(project_u(M' * dv))
        dA = @thunk(project_A(dv * u'))

        # Only `A` receives a cotangent; `trait` is left out of the tangent.
        dM = Tangent{Mat}(; A=dA)

        return NoTangent(), dM, du
    end

    return M * u, pb
end

In this case, `test_rrule` attempts to perturb the Boolean `trait` field, causing this error:

# Seed the RNG so the random inputs below are reproducible.
Random.seed!(0)
N = 8

# Build the wrapper explicitly with `trait = false`, plus a random vector.
M = Mat(rand(N,N), false)
u = rand(N)
# Check the custom `rrule` for `*`; this call produces the error shown below.
test_rrule(*, M, u)
test_rrule: * on Mat{Matrix{Float64}},Vector{Float64}: Error During Test at /Users/vp
/.julia/packages/ChainRulesTestUtils/2VT4F/src/testers.jl:193                        
  Got exception outside of a @test                                                   
  TypeError: non-boolean (Float64) used in boolean context                           
  Stacktrace:                                                                        
    [1] macro expansion
      @ ~/.julia/packages/FiniteDifferences/VpgIT/src/to_vec.jl:0 [inlined]
    [2] _force_construct
      @ ~/.julia/packages/FiniteDifferences/VpgIT/src/to_vec.jl:27 [inlined]

...

  caused by: InexactError: Bool(-0.01)                                               
  Stacktrace:                                                                        
    [1] Bool                                                                         
      @ ./float.jl:158 [inlined]                                                     

...

Test Summary:                                         | Pass  Error  Total  Time
test_rrule: * on Mat{Matrix{Float64}},Vector{Float64} |    3      1      4  5.3s
ERROR: LoadError: Some tests did not pass: 3 passed, 0 failed, 1 errored, 0 broken.
in expression starting at /Users/vp/.julia/dev/smo_adj/test/f3.jl:46
oxinabox commented 2 years ago

I think that if we made `to_vec` on a `Bool` return an empty vector, that would fix this.

But the short-term solution is to define `to_vec` for `Mat` in your package/tests.