neurodsp-tools / neurodsp

Digital signal processing for neural time series.
https://neurodsp-tools.github.io/
Apache License 2.0

Wavelet discussion, including normalization #165

Open TomDonoghue opened 5 years ago

TomDonoghue commented 5 years ago

Issue for discussion of our wavelets.

Normalization

For wavelets, right now we have two normalizations available:

* 'sss' - divide by the square root of the sum of squares
* 'amp' - divide by the sum of amplitudes

Both are implemented as written in the code. The question is that, somewhere along the way, the docs fell out of sync with the code: the docstring for 'amp' used to say the result was also divided by two, even though the code never did this.

What is somewhat unclear (to me, at least) is what the best wavelet normalizations are. Maybe it should be divided by two? So if anyone knows the secrets of wavelets, please throw out some info :).
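For reference, the two normalizations as described above amount to something like the following (a minimal standalone sketch, not the library code; `normalize_wavelet` is a hypothetical helper name):

```python
import numpy as np

def normalize_wavelet(wavelet, norm='sss'):
    """Hypothetical helper illustrating the two normalization options."""
    if norm == 'sss':
        # divide by the square root of the sum of squares (the L2 norm)
        return wavelet / np.sqrt(np.sum(np.abs(wavelet) ** 2))
    elif norm == 'amp':
        # divide by the sum of amplitudes (the L1 norm)
        return wavelet / np.sum(np.abs(wavelet))
    raise ValueError("norm should be 'sss' or 'amp'")

# After 'sss' the wavelet has unit L2 norm; after 'amp', unit L1 norm
w = np.exp(2j * np.pi * 5 * np.linspace(-1, 1, 200))
print(np.sum(np.abs(normalize_wavelet(w, 'sss')) ** 2))  # ~1.0
print(np.sum(np.abs(normalize_wavelet(w, 'amp'))))       # ~1.0
```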

TomDonoghue commented 4 years ago

@elybrand - since you've just been looking at timefrequency, I thought I would tag you in here, as you might know the answer. If there is a problem here we can fix it, and if not, we can close this issue.

The current wavelets have a couple possible normalizations. Basically, at some point we ended up with a discrepancy between the docs and the code. Without trying to dig out all the history of the code, I wondered if you had any quick thoughts on this, and on whether there is any reason normalizing by the amplitudes should be divided by 2 (which we don't currently do, but the docs used to say we did).

elybrand commented 4 years ago

@TomDonoghue

So I have a couple of remarks. First, I just realized there is no way for the user to specify the normalization method when they call compute_wavelet_transform. It just defaults to 'sss'!

Second, I don't immediately see why you would need to divide by 2. My understanding of the two normalizations is as follows:

L2 normalization: If you were to select a wavelet transform where the wavelets you used formed a basis (effectively, you'd be doing a discrete wavelet transform), normalizing by the L2 norm ('sss') would preserve the energy of the signal, like the normalized DFT. Otherwise, continuous wavelet transforms, whether or not they are normalized with 'sss', do not preserve energy (numerically). Graphically this does have an effect on the TF plot. Wavelets at scale s have an L2 norm of 1/sqrt(s). Consequently, if your signal were sig = sin(5x) + sin(10x), then even though both frequencies have equal amplitude, the magnitude of the wavelet transform at 5Hz will be larger than that at 10Hz because of the normalization factor. While it appears that the scaling, and therefore the norm, is fixed in the implementation of compute_wavelet_transform as is, in practice it's actually not. This is obscured by the fact that the wavelet length depends on which frequency convolve_wavelet is fed. The Morlet wavelet constructed in convolve_wavelet isn't actually the length of the signal; I imagine you did this to truncate the Gaussian envelope to zero.
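This attenuation is easy to check numerically. The sketch below uses a hand-rolled complex Morlet (an assumption, standing in for neurodsp's `convolve_wavelet`): under 'sss' normalization, a unit-amplitude sinusoid at a lower frequency produces a larger peak response than one at a higher frequency, simply because the lower-frequency wavelet is longer.

```python
import numpy as np

fs = 1000

def morlet(freq, n_cycles=7, fs=1000):
    # Hand-rolled complex Morlet spanning n_cycles of the center frequency
    t = np.arange(-n_cycles / (2 * freq), n_cycles / (2 * freq), 1 / fs)
    gauss = np.exp(-t ** 2 / (2 * (n_cycles / (2 * np.pi * freq)) ** 2))
    return gauss * np.exp(2j * np.pi * freq * t)

times = np.arange(0, 5, 1 / fs)
peaks = {}
for f in (5, 10):
    wav = morlet(f)
    wav = wav / np.sqrt(np.sum(np.abs(wav) ** 2))  # 'sss' normalization
    sig = np.sin(2 * np.pi * f * times)            # unit-amplitude sinusoid
    peaks[f] = np.abs(np.convolve(sig, wav, mode='same')).max()

# The 5 Hz wavelet is twice as long as the 10 Hz one, so its peak
# response comes out roughly sqrt(2) times larger
print(peaks[5] / peaks[10])
```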

L1 normalization: L1 normalization avoids attenuating the amplitude of the wavelet transform at higher frequencies. Going back to sig = sin(5x) + sin(10x), looking at the TF plot using norm='amp' gives roughly equal peaks at 5Hz and 10Hz. Here is a minimal working example that demonstrates all of this (modulo the fact that I fudged compute_wavelet_transform to add the keyword argument "norm"; this should be a PR):

from neurodsp import timefrequency as tf
import numpy as np

from neurodsp.utils import create_times
from neurodsp.sim import sim_combined

import matplotlib.pyplot as plt

# Set the random seed, for consistency simulating data
np.random.seed(0)

# General setting for simulations
fs = 1000
n_seconds = 5

# Generate a times vector, for plotting
times = create_times(n_seconds, fs)

# Set up simulation components for two signals, with oscillations at 5 & 10 Hz
components = {'sim_powerlaw' : {'exponent' : 0},
              'sim_oscillation' : {'freq' : 5}}
components2 = {'sim_powerlaw' : {'exponent' : 0},
               'sim_oscillation' : {'freq' : 10}}
variances = [0, 1]  # zero out the powerlaw component, keeping pure oscillations

# Simulate our signal with frequencies at 5 and 10 Hz.
sig = sim_combined(n_seconds, fs, components, variances) + sim_combined(n_seconds, fs, components2, variances)

# L2 normalized wavelets. Notice that the TF plot at frequency 10 is lower than that at 5.
fig, axes = plt.subplots(2,1, figsize=(8,7))
wt = tf.compute_wavelet_transform(sig, fs, [4,11,1], norm='sss')
axes[0].pcolormesh(times, np.arange(4,12), np.abs(wt.T), cmap='viridis', shading='gouraud')
axes[0].set_title("L2 Normalized Wavelet Transform")

# L1 normalized wavelets. Now the two frequencies in the TF plot are roughly the same.
wt2 = tf.compute_wavelet_transform(sig, fs, [4,11,1], norm='amp')
axes[1].pcolormesh(times, np.arange(4,12), np.abs(wt2.T), cmap='viridis', shading='gouraud')
axes[1].set_title("L1 Normalized Wavelet Transform")

TomDonoghue commented 4 years ago

For findability, and attaching to the codebase, I'm copying in some notes from @elybrand:

There is no way to specify the wavelet length from compute_wavelet_transform , and yet this is a kwarg for convolve_wavelet. Is this intentional? This precludes, for example, computing a convolution of a low-frequency wavelet which is defined over the entire signal unless n_cycles is defined appropriately.

Let's suppose that we add wavelet_len as a kwarg to compute_wavelet_transform.

Then I'd like to test the validity of compute_wavelet_transform by choosing certain frequencies and transforming a Dirac spike, since this will let me read off the wavelets directly. However, if I were to set wavelet_len = len(sig), then wavelets with high frequencies may be clipped (i.e. their tails truncated to zero) by the default argument n_cycles=7. In other words, certain choices of wavelet_len, n_cycles, and freqs prevent one from getting the full Morlet wavelet as defined over the length of the entire signal.
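The Dirac-spike trick works because convolving a delta with a wavelet returns a shifted copy of the wavelet. A tiny self-contained sketch of the idea, with a hand-rolled Morlet standing in for `convolve_wavelet` (an assumption, not the library's implementation):

```python
import numpy as np

fs = 1000

def morlet(freq, n_cycles=7, fs=1000):
    # Hand-rolled complex Morlet spanning n_cycles of the center frequency
    t = np.arange(-n_cycles / (2 * freq), n_cycles / (2 * freq), 1 / fs)
    gauss = np.exp(-t ** 2 / (2 * (n_cycles / (2 * np.pi * freq)) ** 2))
    return gauss * np.exp(2j * np.pi * freq * t)

# A Dirac spike in the middle of a 2-second signal
sig = np.zeros(2 * fs)
sig[len(sig) // 2] = 1.0

wav = morlet(10)
out = np.convolve(sig, wav, mode='same')

# The output is just a shifted copy of the wavelet, so (for example) the
# peak magnitudes match and the wavelet can be read off directly
print(np.isclose(np.abs(out).max(), np.abs(wav).max()))  # True
```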

And replies / notes by @rdgao:

I think n_cycles=7 is a legacy default taken from FIR filters. Basically, there's a rule of thumb somewhere that 7 cycles of the center frequency (of that wavelet/filter) is the magic balance in the time-frequency resolution tradeoff, and my guess is it was ported over here based on that.

The thought process, as far as I understand it, is to treat each wavelet just like an FIR filter, with the additional constraint that there is some relationship between the wavelet frequencies. This is the predominant use case in most neuroscience applications for computing a time-frequency representation.

The reply back from @elybrand is that the FIR interpretation is reasonable: mathematicians use the intuition of thinking of the wavelet transform as a bank of bandpass filters. Given the use case, the implementation should broadly be okay.

TomDonoghue commented 4 years ago

Thanks a lot to @elybrand for detailed analysis and discussion here.

I'm going to try to summarize, and see if we can organize what (if any) the ToDo items are.

Notes & questions:

Perhaps relevant context: I would say it's much more common to use something like a wavelet time-frequency analysis to examine changes across time, within frequencies. I don't think explicit comparison of absolute power between frequencies is common, so while we want to represent these estimations as appropriately as possible, I can't say I'm too worried about this exact quantitative comparison (and in real data, the whole 1/f thing gets in the way here anyways).

Thoughts, @elybrand and/or @rdgao?

elybrand commented 4 years ago

@TomDonoghue This is a great summary. For now, I agree it is best to not add functionality for setting wavelet_len. The normalizations right now are also correctly implemented.

I do want to add a semantic point which adds to the discussion on normalization. Technically speaking, compute_wavelet_transform is not computing a discrete wavelet transform. It is computing a discretized continuous wavelet transform. MATLAB has some really nice documentation explaining the difference in great detail, but the TL;DR is that the difference lies in how you discretize the scaling and translation parameters.

Discrete Wavelet Transform (DWT)

Discrete wavelet transforms discretize the scale and location parameters of the wavelets in a pre-determined way. They use dyadic scales (powers of 2) for the scale, and the translations of the wavelets are always integer multiples of the scale. There are certain benefits to this, namely that the DWT preserves energy and minimizes redundancy of the "information content" in the wavelet coefficients. That's because the wavelets they use are orthogonal. This is a pretty stringent requirement, and not every continuous wavelet admits a DWT. Morlet wavelets, for example, don't admit a DWT. The downside to DWTs is that your discretization of time-frequency space is pretty coarse.
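The energy-preservation point is easiest to see with the simplest orthogonal wavelet, the Haar wavelet. A minimal hand-rolled sketch of one DWT level (not using neurodsp or any wavelet library):

```python
import numpy as np

def haar_dwt_level(x):
    # One level of the orthonormal Haar DWT: pairwise sums (approximation)
    # and pairwise differences (detail), each scaled by 1/sqrt(2)
    approx = (x[0::2] + x[1::2]) / np.sqrt(2)
    detail = (x[0::2] - x[1::2]) / np.sqrt(2)
    return approx, detail

rng = np.random.default_rng(0)
x = rng.standard_normal(16)
approx, detail = haar_dwt_level(x)

# Orthonormality means the energy of the signal is preserved exactly
print(np.sum(x ** 2), np.sum(approx ** 2) + np.sum(detail ** 2))
```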

Discretized Continuous Wavelet Transform (dCWT)

On the other hand, a discretized continuous wavelet transform allows you to choose an arbitrary way of discretizing the time-frequency plane. The only thing that's discretized is the signal itself. You let the user take advantage of this when you let them feed in an array of frequencies. As I mentioned above, dCWTs don't preserve energy. Consequently, I believe the standard practice is to use norm='amp' so that all frequencies/scales are normalized equally.

TomDonoghue commented 4 years ago

Okay, cool, thanks for the info! Looking into and figuring out wavelets has always been on my ToDo list, so this is a super helpful primer!

It seems to me the ToDo items then are to:

rdgao commented 4 years ago

so I actually had it in my head that we've been doing DWT this whole time, with predetermined scales and preserving signal energy, but just realized that makes no sense if we allow arbitrary specification of frequencies. herp derp berp.

is it accurate to say that DWT is a subset of dCWT with more stringent parameter configurations, such that one could use compute_wavelet_transform() as it is now with the right mother wavelet & scale parameters to implement DWT? And otherwise by default, it just behaves like multiple FIR filters? If so, maybe a note in the documentation or an example tutorial on how to get to DWT would be useful, as I think most papers I've read that use wavelets show a DWT spectrogram, with dyadic (or some other multiple) scales.

elybrand commented 4 years ago

is it accurate to say that DWT is a subset of dCWT with more stringent parameter configurations

Definitely.

such that one could use compute_wavelet_transform() as it is now with the right mother wavelet & scale parameters to implement DWT?

Not quite. The discrete wavelet transform goes beyond discretizing which frequencies you look at. It also discretizes the translations of the wavelets. The current implementation of compute_wavelet_transform returns the entire convolution. A DWT would downsample the convolution at dyadic rates.
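The "downsample the convolution at dyadic rates" point can be made concrete with the Haar filters: one DWT level is exactly a full convolution of which you keep only every other sample. A sketch (my own illustration, not how any particular library implements it):

```python
import numpy as np

x = np.arange(8.0)
lo = np.array([1.0, 1.0]) / np.sqrt(2)   # Haar scaling (lowpass) filter
hi = np.array([1.0, -1.0]) / np.sqrt(2)  # Haar wavelet (highpass) filter

# Full convolution (what a CWT-style transform would return for one scale),
# then keep every other sample: that decimation is the dyadic downsampling
approx = np.convolve(x, lo)[1::2]
detail = np.convolve(x, hi[::-1])[1::2]  # reverse the filter to correlate

# Same coefficients as computing pairwise sums/differences directly
print(np.allclose(approx, (x[0::2] + x[1::2]) / np.sqrt(2)))  # True
print(np.allclose(detail, (x[0::2] - x[1::2]) / np.sqrt(2)))  # True
```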

I find the Wikipedia article on the wavelet transform does a nice job of representing graphically what a DWT does. You can see that a DWT tiles the time-frequency plane in such a way that the tiles do not overlap but still cover the space. That's what I mean when I say a DWT "minimizes redundancy".

When I first learned about wavelets, I started with the discrete Haar wavelet transform. There, it's very obvious what's going on with the translations and dilations. Thinking about it in vector language, a DWT is constructing an orthonormal basis of wavelets. The DWT then is just taking dot products of the signal with fixed wavelets that have been dilated and shifted to particular locations. A dCWT guarantees neither orthogonality nor that the wavelets form a basis.

rdgao commented 4 years ago

right, my understanding is that DWT uses a strictly orthogonal basis, but you have to do something funny with the sampling at each scale as well?

elybrand commented 4 years ago

Precisely!