jonkhler / s2cnn

Spherical CNNs
MIT License

Theoretical Problems about SO(3) Fourier Transformation #62

Closed (LeoDuhz closed this issue 2 years ago)

LeoDuhz commented 2 years ago

Hi,

Thanks for your great work! Your repo really helps me a lot! However, I am running into some theoretical problems with the derivative of the SO(3) Fourier transform while trying to write the backpropagation formula for it. More specifically, if I have

In your code, you write:

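```python
import torch

# so3_rfft and so3_ifft are helper functions defined in the same s2cnn module.


class SO3_fft_real(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        ctx.b_out = b_out
        ctx.b_in = x.size(-1) // 2  # input is sampled on a 2b x 2b x 2b SO(3) grid
        return so3_rfft(x, b_out=ctx.b_out)

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        # ifft of grad_output is not necessarily real, therefore we cannot use rifft.
        # [..., 0] keeps the real part (the last axis holds [real, imag]);
        # None is the gradient for the non-tensor argument b_out.
        return so3_ifft(grad_output, for_grad=True, b_out=ctx.b_in)[..., 0], None
```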
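The same pattern can be tried on a simpler transform. The sketch below is a hypothetical 1D analogue (`DFT1DReal` is an illustrative name, not part of s2cnn) that swaps the SO(3) FFT for PyTorch's ordinary DFT, `torch.fft.fft` with `norm="ortho"` so that the transform matrix is unitary; it assumes a PyTorch version that ships the `torch.fft` module. Like the quoted code, it takes a real input, returns the complex spectrum with a trailing [real, imag] axis, and implements backward as the real part of the inverse transform.

```python
import torch


class DFT1DReal(torch.autograd.Function):
    """Hypothetical 1D analogue of SO3_fft_real (not part of s2cnn).

    Real input of shape [..., n] -> complex spectrum stored as a real tensor
    of shape [..., n, 2], where [..., 0] is the real part and [..., 1] the
    imaginary part.
    """

    @staticmethod
    def forward(ctx, x):
        # norm="ortho" makes the DFT matrix F unitary, so F^H = F^{-1}.
        y = torch.fft.fft(x, norm="ortho")
        return torch.stack((y.real, y.imag), dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        # Rebuild the complex gradient g = dL/dRe(y) + i * dL/dIm(y).
        g = torch.complex(grad_output[..., 0], grad_output[..., 1])
        # Backprop through y = F x applies F^H, which is the orthonormal ifft.
        # The result is not necessarily real, but x is real, so keep the real part.
        return torch.fft.ifft(g, norm="ortho").real


# Usage: compare the hand-written backward against finite differences.
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
print(torch.autograd.gradcheck(DFT1DReal.apply, (x,)))
```

`gradcheck` raises an error if the analytical and numerical Jacobians disagree, so passing it only confirms the "real part of the orthonormal ifft" rule for this 1D DFT, not for the SO(3) transform or its normalization.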

It seems intuitive that the derivative of the SO(3) FFT is the SO(3) inverse FFT, but I still wonder how you arrived at this. Could you explain it in more detail, or is there a paper that derives or discusses this formula?

Thanks a lot for your help!!

tscohen commented 2 years ago

Although it's implemented in a clever way, the FFT is ultimately just a linear transformation, one that is furthermore orthogonal (in the real case) or unitary (in the complex case). So you could write it as a matrix multiplication y = F x, where x is the spatial signal, y is the spectral signal, and F is the Fourier matrix. The Jacobian dy/dx of this matrix-vector product is F itself, so the backward pass multiplies the incoming gradient by its transpose / conjugate transpose. Since F is orthogonal/unitary, that transpose equals F^-1, i.e. the ifft.
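This argument can be checked numerically for the ordinary DFT, used here as a stand-in for the SO(3) transform (an assumption: this does not test the SO(3) transform itself or its normalization). For y = F x with real x, write the loss gradient with respect to y as the complex vector c (gradient for Re y plus i times gradient for Im y). The chain rule then gives dL/dx = Re(F^H c), and for unitary F, F^H c is the inverse DFT of c. The NumPy sketch below builds F explicitly (`unitary_dft_matrix` is an illustrative helper) and confirms each step with finite differences.

```python
import numpy as np


def unitary_dft_matrix(n):
    """Orthonormal DFT matrix: F[k, j] = exp(-2*pi*i*k*j/n) / sqrt(n)."""
    k = np.arange(n).reshape(-1, 1)
    j = np.arange(n).reshape(1, -1)
    return np.exp(-2j * np.pi * k * j / n) / np.sqrt(n)


rng = np.random.default_rng(0)
n = 6
F = unitary_dft_matrix(n)

# F matches NumPy's orthonormal FFT and is unitary: F^H F = I.
x = rng.standard_normal(n)
assert np.allclose(F @ x, np.fft.fft(x, norm="ortho"))
assert np.allclose(F.conj().T @ F, np.eye(n))

# The conjugate transpose is the inverse FFT: F^H c == ifft(c).
c = rng.standard_normal(n) + 1j * rng.standard_normal(n)
assert np.allclose(F.conj().T @ c, np.fft.ifft(c, norm="ortho"))


def loss(x):
    # Scalar loss whose gradient w.r.t. Re(y) is Re(c) and w.r.t. Im(y) is Im(c).
    y = F @ x
    return np.sum(c.real * y.real + c.imag * y.imag)


# Central finite differences for dL/dx (exact up to rounding, since L is linear).
eps = 1e-6
numeric_grad = np.array(
    [(loss(x + eps * e) - loss(x - eps * e)) / (2 * eps) for e in np.eye(n)]
)

# Backprop rule dL/dx = Re(F^H c) = Re(ifft(c)), the pattern in SO3_fft_real.backward.
analytic_grad = np.fft.ifft(c, norm="ortho").real
assert np.allclose(numeric_grad, analytic_grad)
```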

LeoDuhz commented 2 years ago

I understand what you mean. Thank you so much for your help!!