I was wondering whether the `ChainRulesCore.rrule` for `conv` defined here is actually correct.
The following test against finite differences, using ChainRulesTestUtils, gives large errors:
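```julia
using Zygote, ChainRulesTestUtils
# `conv` comes from the package under discussion (loaded separately)

x = randn(ComplexF64, (5, 6))
y = randn(ComplexF64, (5, 6))
test_rrule(
    Zygote.ZygoteRuleConfig(),
    conv,
    x,
    y;
    rrule_f=Zygote.rrule_via_ad,
)
```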
Of course, finite differencing is not the gold standard, but if I delete the custom `rrule`, Zygote falls back to the `rrule`s for `fft` and `ifft` (as defined in AbstractFFTs.jl), and the same test passes just fine.
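A minimal sketch of that fallback path, assuming `conv` is a circular convolution computed with FFTs (my guess at the definition, not checked against the package). `conv_fft` is a hypothetical stand-in with no custom `rrule`, so Zygote has to differentiate it through the `fft`/`ifft` rules; FFTW.jl is assumed as the FFT backend.

```julia
using FFTW, Zygote, ChainRulesTestUtils

# Hypothetical stand-in for `conv`: circular convolution via the convolution theorem.
# No custom rrule is defined, so Zygote differentiates through the fft/ifft rules.
conv_fft(u, v) = ifft(fft(u) .* fft(v))

x = randn(ComplexF64, (5, 6))
y = randn(ComplexF64, (5, 6))

# Same finite-difference test as above, now against the AD-derived rule.
test_rrule(
    Zygote.ZygoteRuleConfig(),
    conv_fft,
    x,
    y;
    rrule_f=Zygote.rrule_via_ad,
)
```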
What is the motivation for a custom rule here anyway? Relying on the AbstractFFTs rules would also give you thunking/unthunking, correct `ProjectTo`s, and tangents for the second argument.
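If a hand-written rule is wanted for performance, something along these lines would cover those points. This is only a sketch under the same assumption that `conv` is an FFT-based circular convolution; `myconv` is a hypothetical name, not the package's function. Because the map is linear in each argument, each pullback is the adjoint operator, i.e. a correlation: `xbar = ifft(conj.(fft(y)) .* fft(zbar))`, and symmetrically for `ybar`.

```julia
using ChainRulesCore, FFTW

# Hypothetical circular convolution used only to illustrate the rule.
myconv(x, y) = ifft(fft(x) .* fft(y))

function ChainRulesCore.rrule(::typeof(myconv), x::AbstractArray, y::AbstractArray)
    X, Y = fft(x), fft(y)
    z = ifft(X .* Y)
    # Projectors map tangents back onto the input types (e.g. real part for real inputs).
    project_x = ProjectTo(x)
    project_y = ProjectTo(y)
    function myconv_pullback(zbar)
        Zbar = fft(unthunk(zbar))
        # Adjoint of convolution with y (resp. x) is correlation with y (resp. x).
        xbar = @thunk project_x(ifft(conj.(Y) .* Zbar))
        ybar = @thunk project_y(ifft(conj.(X) .* Zbar))
        return NoTangent(), xbar, ybar
    end
    return z, myconv_pullback
end
```

It can be checked with the same kind of `test_rrule` call, e.g. `test_rrule(myconv, x, y)`.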