kymatio / murenn

Multi-Resolution Neural Networks
MIT License

dtcwt shift invariant demo #61

Open · xir4n opened this issue 1 month ago

xir4n commented 1 month ago

I think it would be nice to have some demonstrations of the (approximate) shift invariance of the dual-tree complex wavelet transform (DTCWT), as described in Chapter 4 of this paper: Complex Wavelets for Shift Invariant Analysis and Filtering of Signals
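To make the property concrete without depending on murenn, here is a minimal numpy sketch (an illustration of my own, not taken from the paper) of the shift variance that the dual-tree transform is designed to reduce. With a critically sampled real Haar DWT, the energy of the level-1 detail coefficients of a step depends on whether the edge lands on an even or an odd sample. The undecimated Haar detail, which is fully redundant, does not have this problem. The DTCWT aims for near shift invariance with only 2:1 redundancy in 1-D. The helpers `haar_level1`, `undecimated_haar_detail`, `step`, and `detail_energies` are hypothetical names used only for this example.

```python
import numpy as np


def haar_level1(x):
    """One level of the decimated (critically sampled) Haar DWT."""
    x = np.asarray(x, dtype=float)
    approx = (x[0::2] + x[1::2]) / np.sqrt(2)
    detail = (x[0::2] - x[1::2]) / np.sqrt(2)
    return approx, detail


def undecimated_haar_detail(x):
    """Haar detail filter evaluated at every sample (no downsampling)."""
    x = np.asarray(x, dtype=float)
    return (x[:-1] - x[1:]) / np.sqrt(2)


def step(T, edge):
    """Step signal: 0 before index `edge`, 1 from `edge` onwards."""
    x = np.zeros(T)
    x[edge:] = 1.0
    return x


def detail_energies(T=64, shifts=range(8)):
    """Level-1 detail energy of a shifted step, decimated vs. undecimated."""
    dec, undec = [], []
    for tau in shifts:
        x = step(T, T // 2 + tau)
        _, d = haar_level1(x)
        dec.append(np.sum(d ** 2))
        undec.append(np.sum(undecimated_haar_detail(x) ** 2))
    return np.array(dec), np.array(undec)


dec, undec = detail_energies()
print("decimated  :", dec.round(3))
print("undecimated:", undec.round(3))
```

The decimated energies alternate between 0 and 0.5 as the edge moves by one sample, while the undecimated ones stay at 0.5 for every shift.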

What I want to do is make an animation of the response of each subband as the input signal is shifted (a step signal in the code below), like this:

import matplotlib.pyplot as plt
import torch
from matplotlib.animation import FuncAnimation

import murenn


def get_animation(j, **kwargs):
    '''
    This animation shows the subband component of a reconstructed step signal
    as the input signal is shifted.
    Parameters:
    - j: int, the subband to consider. -1 implies lowpass subband.
    - **kwargs: additional arguments to pass to the DTCWT and IDTCWT.
    '''
    # Signal length and the forward/inverse transforms
    T = 2**10
    wt = murenn.DTCWT(**kwargs)
    iwt = murenn.IDTCWT(**kwargs)

    fig, ax = plt.subplots()
    line_ref, = ax.plot([], [], ls='--', color='gray', label='input')
    line, = ax.plot([], [], lw=2, label='output')
    ax.set_xlim(0, T)
    ax.set_ylim(-1, 1.3)
    ax.set_title('Lowpass' if j == -1 else f'Level {j}')
    ax.grid(ls='--')
    ax.legend(loc='upper right')

    def init():
        # Start from empty lines (init_func is required to be defined here)
        line_ref.set_data([], [])
        line.set_data([], [])
        return line_ref, line

    def animate(tau):
        # Step signal whose edge is shifted by tau samples
        x = torch.zeros(T)
        x[(T//2 + tau):] = 1
        x = x.reshape(1, 1, T)
        with torch.no_grad():
            # Keep only subband j (or only the lowpass band if j == -1)
            lp, bps = wt(x)
            if j == -1:
                rec_x = iwt(lp, [torch.zeros_like(bp) for bp in bps])
            else:
                rec_x = iwt(torch.zeros_like(lp),
                            [bp if k == j else torch.zeros_like(bp) for k, bp in enumerate(bps)])
        # Plot the input and its single-subband reconstruction
        line_ref.set_data(range(T), x.squeeze().numpy())
        line.set_data(range(T), rec_x.squeeze().numpy())
        # With blit=True, every artist updated in this frame must be returned
        return line_ref, line

    ani = FuncAnimation(fig, animate, frames=range(T // 4), init_func=init, blit=True)
    return ani
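A possible numeric complement to the animation is to record, for each frame, the energy of the single-subband reconstruction. For a shift-invariant subband this curve should be (nearly) flat as the step moves. Running murenn requires the package to be installed, so the sketch below uses a plain decimated Haar DWT as a stand-in. It therefore shows the shift-variant baseline rather than the DTCWT itself. `haar_dwt`, `haar_idwt`, and `subband_energy_vs_shift` are hypothetical helpers. Replacing them with `wt`/`iwt` from `get_animation` is the intended use, but that variant is untested.

```python
import numpy as np


def haar_dwt(x, J):
    """J-level decimated Haar DWT: returns (lowpass, [detail_1, ..., detail_J])."""
    lp = np.asarray(x, dtype=float)
    bps = []
    for _ in range(J):
        bps.append((lp[0::2] - lp[1::2]) / np.sqrt(2))
        lp = (lp[0::2] + lp[1::2]) / np.sqrt(2)
    return lp, bps


def haar_idwt(lp, bps):
    """Inverse of haar_dwt (perfect reconstruction), coarsest level first."""
    for d in reversed(bps):
        x = np.empty(2 * len(lp))
        x[0::2] = (lp + d) / np.sqrt(2)
        x[1::2] = (lp - d) / np.sqrt(2)
        lp = x
    return lp


def subband_energy_vs_shift(j, T=2**10, J=4, shifts=range(16)):
    """Energy of the single-subband reconstruction of a shifted step, per shift.

    j = -1 keeps only the lowpass band; otherwise only detail band j is kept,
    mirroring the zeroing done inside animate().
    """
    energies = []
    for tau in shifts:
        x = np.zeros(T)
        x[T // 2 + tau:] = 1.0
        lp, bps = haar_dwt(x, J)
        if j == -1:
            rec = haar_idwt(lp, [np.zeros_like(d) for d in bps])
        else:
            rec = haar_idwt(np.zeros_like(lp),
                            [d if k == j else np.zeros_like(d) for k, d in enumerate(bps)])
        energies.append(np.sum(rec ** 2))
    return np.array(energies)


for j in [-1, 0, 1, 2, 3]:
    e = subband_energy_vs_shift(j)
    print(f"subband {j:2d}: energy ranges from {e.min():.3f} to {e.max():.3f}")
```

With Haar, the finest detail band (`j = 0` here) alternates between 0 and 0.5 as the edge moves between even and odd samples. Because the transform is orthogonal, the single-subband energies still sum to the input energy for every shift.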
lostanlen commented 1 month ago

This is a good idea! Thanks.