kymatio / murenn

Multi-Resolution Neural Networks
MIT License

`MuReNNDirect`: Conv1D should be before abs2 #37

Closed by lostanlen 6 months ago

lostanlen commented 7 months ago

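Compare with the implementation from the WASPAA paper, which applies the Conv1D filters (`self.psis`) to the real and imaginary parts of each level before taking the squared modulus: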

    def forward(self, x):
        x = x.reshape(x.shape[0], 1, x.shape[-1])
        _, x_levels = self.tfm.forward(x)
        Ux = []

        for j_psi in range(1+self.J_psi):
            x_level = x_levels[j_psi].type(torch.complex64) / (2**j_psi)
            Wx_real = self.psis[j_psi](x_level.real)
            Wx_imag = self.psis[j_psi](x_level.imag)
            Ux_j = Wx_real * Wx_real + Wx_imag * Wx_imag
            Ux_j = torch.real(Ux_j)
            if j_psi == 0:
                N_j = Ux_j.shape[-1]
            else:
                Ux_j = Ux_j[:, :, :N_j]
            Ux.append(Ux_j)

        Ux = torch.cat(Ux, axis=1)

https://github.com/lostanlen/lostanlen2023waspaa/blob/main/student.py
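To illustrate why the ordering matters, here is a minimal pure-Python sketch (not the `murenn` or WASPAA code). The helper names, the toy signal, and the two-tap kernel are all hypothetical. It assumes the ordering being replaced took the squared modulus first and filtered afterwards, as the issue title suggests. It compares that with filtering the real and imaginary parts first and squaring afterwards, as the snippet above does.

```python
def conv1d_valid(signal, kernel):
    """Unpadded cross-correlation of two lists (the operation a Conv1D layer applies)."""
    n, k = len(signal), len(kernel)
    return [sum(signal[t + i] * kernel[i] for i in range(k)) for t in range(n - k + 1)]


def conv_then_abs2(z, kernel):
    # Ordering used in the snippet above: filter the real and imaginary
    # parts separately, then sum their squares.
    w_real = conv1d_valid([v.real for v in z], kernel)
    w_imag = conv1d_valid([v.imag for v in z], kernel)
    return [r * r + i * i for r, i in zip(w_real, w_imag)]


def abs2_then_conv(z, kernel):
    # Assumed previous ordering: squared modulus first, then filter.
    return conv1d_valid([v.real ** 2 + v.imag ** 2 for v in z], kernel)


# Toy complex subband: a rotating phasor with constant modulus 1.
z = [1 + 0j, 0 + 1j, -1 + 0j, 0 - 1j]
# Hypothetical stand-in for a learned Conv1D filter (a first difference).
kernel = [1.0, -1.0]

a = conv_then_abs2(z, kernel)
b = abs2_then_conv(z, kernel)
print(a)  # [2.0, 2.0, 2.0]
print(b)  # [0.0, 0.0, 0.0]
```

On this toy input the squared modulus is constant, so filtering after `abs2` discards the phase rotation that the filter still sees when it is applied first. The two orderings are therefore not interchangeable.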

lostanlen commented 6 months ago

Fixed by #38.