bionanoimaging / FourierTools.jl

Tools for working with Fourier space.
https://bionanoimaging.github.io/FourierTools.jl/stable/
MIT License

Incorrect `rrule` for `conv`? #35

Closed: trahflow closed this issue 8 months ago

trahflow commented 1 year ago

Hi,

I was wondering whether the `ChainRulesCore.rrule` for `conv` defined here is actually correct.

The following test, which uses ChainRulesTestUtils to compare the rule against finite differences, reports large errors:

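using FourierTools, ChainRulesTestUtils, Zygote

x = randn(ComplexF64, (5,6))
y = randn(ComplexF64, (5,6))
# Compare the pullback of conv(x, y) against finite differences;
# rrule_f=rrule_via_ad differentiates conv through Zygote, which picks up any custom rrule
test_rrule(
    Zygote.ZygoteRuleConfig(),
    conv,
    x,
    y;
    rrule_f=Zygote.rrule_via_ad
)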

Of course, finite differencing is not the gold standard, but if I delete the custom `rrule`, Zygote falls back to the rrules for `fft` and `ifft` (defined in AbstractFFTs.jl), and the test passes without problems.
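As a sketch of what a correct rule has to compute, assuming `conv` is the circular convolution `ifft(fft(x) .* fft(y))`: this map is linear in each argument, so under the ChainRules convention its pullback is the adjoint, i.e. `x̄ = ifft(fft(z̄) .* conj.(fft(y)))` and `ȳ = ifft(fft(z̄) .* conj.(fft(x)))` (the `1/N` factors of `fft` and `ifft` cancel). The stdlib-only Julia below avoids FFTW and uses a direct O(N²) loop instead; `circconv` and `circconv_pullback` are hypothetical helpers, not FourierTools API, and the check is a dot-product (adjoint) test rather than finite differences.

```julia
using LinearAlgebra, Random

# Hypothetical helper: direct 2D circular convolution (periodic boundaries),
# z[i,j] = Σ_{k,l} x[k,l] * y[mod(i-k, m)+1, mod(j-l, n)+1].
# For equal-size arrays this equals ifft(fft(x) .* fft(y)).
function circconv(x::AbstractMatrix, y::AbstractMatrix)
    @assert size(x) == size(y) "x and y must have the same size"
    m, n = size(x)
    z = zeros(promote_type(eltype(x), eltype(y)), m, n)
    for i in 1:m, j in 1:n, k in 1:m, l in 1:n
        z[i, j] += x[k, l] * y[mod(i - k, m) + 1, mod(j - l, n) + 1]
    end
    return z
end

# Hypothetical helper: pullback of circconv under the ChainRules convention.
# Since circconv is linear in each argument, xbar is the adjoint applied to zbar:
# a circular cross-correlation with conj.(y). Swapping x and y gives ybar.
function circconv_pullback(zbar, x, y)
    m, n = size(zbar)
    T = promote_type(eltype(zbar), eltype(x), eltype(y))
    xbar = zeros(T, m, n)
    ybar = zeros(T, m, n)
    for k in 1:m, l in 1:n, i in 1:m, j in 1:n
        s = (mod(i - k, m) + 1, mod(j - l, n) + 1)
        xbar[k, l] += zbar[i, j] * conj(y[s...])
        ybar[k, l] += zbar[i, j] * conj(x[s...])
    end
    return xbar, ybar
end

# Dot-product (adjoint) test: for any perturbation (dx, dy), the pushforward
# dz = circconv(dx, y) + circconv(x, dy) must satisfy
# <zbar, dz> == <xbar, dx> + <ybar, dy>, with <a, b> = dot(a, b).
Random.seed!(0)
x, y = randn(ComplexF64, 5, 6), randn(ComplexF64, 5, 6)
dx, dy = randn(ComplexF64, 5, 6), randn(ComplexF64, 5, 6)
zbar = randn(ComplexF64, 5, 6)

dz = circconv(dx, y) + circconv(x, dy)
xbar, ybar = circconv_pullback(zbar, x, y)
@assert dot(zbar, dz) ≈ dot(xbar, dx) + dot(ybar, dy)
```

Removing the `conj` calls from the pullback makes this assertion fail for complex `y`, while it would still pass for real inputs, so a missing conjugate only shows up in tests with complex data like the `ComplexF64` arrays above.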

What is the motivation for a custom rule here anyway? Note that relying on the AbstractFFTs rules alone would also give you thunking/unthunking, correct `ProjectTo`s, and tangents for the second argument.

roflmaostc commented 1 year ago

Hi!

You're right. The reason back then was performance. Zygote does not support thunk/unthunk.

But I'm not using that definition anywhere, so we should fix it.