JuliaDiff / ChainRulesTestUtils.jl

Utilities for testing custom AD primitives.
MIT License

test_rrule bug? #119

Open mzgubic opened 3 years ago

mzgubic commented 3 years ago

From

https://github.com/JuliaDiff/ChainRules.jl/blob/24318b0321ccd48f16cbbd59dba6ae8bb9e90860/test/rulesets/LinearAlgebra/structured.jl#L120

using ChainRules
using ChainRulesCore
using ChainRulesTestUtils
using LinearAlgebra

f = adjoint
T = Float64
n = 5
m = 3

A = randn(T, n, m)
Y = f(A)
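# Output tangent as a plain matrix with the same shape as Y (m × n)
Ȳ_mat = randn(T, m, n)
# The same tangent expressed structurally: the `parent` field holds the n × m matrix
Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))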

test_rrule(f, A; output_tangent=Ȳ_mat) # works
test_rrule(f, A; output_tangent=Ȳ_composite) # breaks
test_rrule(f, A) # breaks

willtebbutt commented 3 years ago

Could you provide a stack trace please? :)

mzgubic commented 3 years ago

Sure, sorry:

Test Failed at /Users/mzgubic/JuliaEnvs/ChainRules.jl/dev/ChainRulesTestUtils/src/check_result.jl:19
  Expression: isapprox(actual, expected; kwargs...)
   Evaluated: isapprox([-0.9813351316701018 -1.4155573915526176 -0.08802023134709079; 0.5821187406162006 0.25518397597220116 0.6440863483747852; … ; -0.8464238028962722 0.8760142296068198 -1.4627259246978286; 1.3112413776523713 -1.8132103745503358 -1.8489217478396944], [-0.9813351316700425 0.5821187406162589 0.8906072162724507; -0.846423802896213 1.311241377652429 -1.4155573915525579; … ; -1.8132103745503456 -0.08802023134703205 0.6440863483748435; -0.004548957389037324 -1.4627259246978253 -1.8489217478396347]; rtol = 1.0e-9, atol = 1.0e-9)

I looked into it yesterday, and it seems there is an extra adjoint somewhere in the testing code that messes up the comparison. The rrule itself looks fine: it gives the same result for both values of output_tangent.
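
A minimal sketch of how that check on the rrule could be done by hand, bypassing test_rrule entirely. It assumes the ChainRulesCore version used in this thread (where `Composite` still exists) and that the pullback returned by `rrule` accepts both kinds of output tangent; the names `pullback`, `Ā_mat` and `Ā_composite` are introduced here for illustration only.

```julia
using ChainRules
using ChainRulesCore
using LinearAlgebra

f = adjoint
T = Float64
n, m = 5, 3

A = randn(T, n, m)
Y = f(A)
Ȳ_mat = randn(T, m, n)
Ȳ_composite = Composite{typeof(Y)}(parent=collect(f(Ȳ_mat)))

# Call the rule directly; `pullback` maps an output tangent to input cotangents.
_, pullback = rrule(f, A)

# A pullback returns (∂self, ∂args...), so index 2 is the cotangent for A.
Ā_mat = pullback(Ȳ_mat)[2]
Ā_composite = pullback(Ȳ_composite)[2]

# If the rule is consistent, both cotangents are n × m and agree numerically.
# (Assumption: both results can be materialised with `collect`.)
size(Ā_mat) == (n, m)
collect(Ā_mat) ≈ collect(Ā_composite)
```

If both comparisons hold, the mismatch reported by test_rrule would have to come from how the test compares against finite differences rather than from the rule.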

willtebbutt commented 3 years ago

Oh weird, yeah, that sounds like a bug.