lostanlen closed this issue 6 months ago
compare with waspaa paper

    def forward(self, x):
        x = x.reshape(x.shape[0], 1, x.shape[-1])
        _, x_levels = self.tfm.forward(x)
        Ux = []
        for j_psi in range(1+self.J_psi):
            x_level = x_levels[j_psi].type(torch.complex64) / (2**j_psi)
            Wx_real = self.psis[j_psi](x_level.real)
            Wx_imag = self.psis[j_psi](x_level.imag)
            Ux_j = Wx_real * Wx_real + Wx_imag * Wx_imag
            Ux_j = torch.real(Ux_j)
            if j_psi == 0:
                N_j = Ux_j.shape[-1]
            else:
                Ux_j = Ux_j[:, :, :N_j]
            Ux.append(Ux_j)
        Ux = torch.cat(Ux, axis=1)

https://github.com/lostanlen/lostanlen2023waspaa/blob/main/student.py
fixed by #38