mattpitkin closed this issue 7 months ago.
Note: in my version I've also fixed #13 within the forward transform.
Nice! Thanks for this. I wonder whether fixing this and retraining the neural network would affect the final performance.
I need a bit more help understanding your fix. There's a confusing mix of some of my own optimizations (made for efficiency inside the training loop on the GPU), e.g. dropping the creation of `temp0`, and your fix, which brings some of those original nsgt structures back.
To set the stage (and relating this to https://github.com/sevagh/nsgt/issues/3, as you suggested), my `ragged_vs_matrix.py` script, which measures reconstruction error, shows that the error is quite high for a small number of bins:
6 bins: ~1e-06 reconstruction error:

```
(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 \
    --scale=cqlog --fmin 83.0 --fmax 22050 --bins 6 \
    ./gspi_stereo_short.wav | grep mse
recon error (mse): 1.6235659359153942e-06
```
This is a bug, right? Is this what you're fixing, rather than some inherent property of the NSGT?
Rerun with 60 bins: ~1e-16 reconstruction error:

```
(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 \
    --scale=cqlog --fmin 83.0 --fmax 22050 --bins 60 \
    ./gspi_stereo_short.wav | grep mse
recon error (mse): 1.5521547886190033e-16
```
So it makes sense that the 262 bins used in this codebase mask this subtle error. With your proposed fix, would the reconstruction error be very low even for a small number of frequency bins?
Is the special case handling that you added only for the first and last frequency bin (`rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2`)? Or is it the other way around: is the special case handling for every bin except the first and last?

I find the indexing of `t` confusing: `t[:, :, 0].unsqueeze(2), torch.flip(t[:, :, 1:], dims=(2,))`. The last dimension of `t` is the time bin dimension of the NSGT (which is ragged and varies per frequency bin as per the transform, as you know). Are we really concatenating the 0th time bin with the flip of time bins 1 onward?

Hi @sevagh, the main thing my fix tries to do is copy what happens in the `symm` function here: https://github.com/grrrr/nsgt/blob/master/nsgt/nsigtf.py#L96 (and https://github.com/grrrr/nsgt/blob/master/nsgt/nsigtf.py#L88).
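For reference, the mirroring in the linked `symm` code presumably relies on the Hermitian symmetry of a real signal's spectrum, `X[n - k] == conj(X[k])`. The NumPy sketch below is not the nsgt code: `full_spectrum_from_half` is a hypothetical helper that rebuilds the negative-frequency bins of a length-`n` DFT from the `n // 2 + 1` non-negative bins returned by `np.fft.rfft`.

```python
import numpy as np


def full_spectrum_from_half(half: np.ndarray, n: int) -> np.ndarray:
    """Rebuild the length-n DFT of a real signal from its first n // 2 + 1 bins.

    Hypothetical helper: the negative-frequency bins follow from Hermitian
    symmetry, X[n - k] == conj(X[k]).
    """
    # Bins 1 .. ceil(n/2) - 1 have distinct mirror partners; DC (and Nyquist,
    # when n is even) are their own mirrors, so they are not repeated.
    mirrored = np.conj(half[1 : (n + 1) // 2][::-1])
    return np.concatenate((half, mirrored))


rng = np.random.default_rng(0)
for n in (8, 7):
    x = rng.standard_normal(n)
    full = full_spectrum_from_half(np.fft.rfft(x), n)
    print(n, np.allclose(full, np.fft.fft(x)))  # prints "8 True", then "7 True"
```

DC (and Nyquist for even `n`) have no separate mirror partner, which is likely the same reason the `rr` line processes the first and last frequency bins only once.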
Like you, I see that the reconstruction error is worse for smaller numbers of bins (or when the slice lengths are too small compared to the suggested lengths), and my change fixes both of those cases.
> The special case handling that you added is only for the first and last frequency bin? `rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2` Or is it the other way around, is the special case handling you added for every bin except the first and last?
For the first frequency bin (and, per the `rr` line, the last one), it only goes into the

```python
if not k:
    t = fc[:, :, j]
```

part of the `if` statement, but for all the other bins it goes into both parts of the `if...else...` statement.
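To make the two passes concrete, here is a small pure-Python trace of the window indices visited by the `rr`/`mfbin_ptr` scheme, written in the `if k == 1` form used later in this thread. It assumes (not confirmed here) that `loopparams` holds `2 * (nfreqs - 1)` windows: one per non-negative frequency bin plus a mirror window for every bin except the first and last. `mirror_schedule` is a hypothetical name.

```python
def mirror_schedule(nfreqs: int) -> list[int]:
    """Order in which window indices are visited by the rr / mfbin_ptr loop.

    Hypothetical helper; assumes len(loopparams) == 2 * (nfreqs - 1).
    """
    mfbin_ptr = 2 * (nfreqs - 1)  # stands in for len(loopparams)
    visited = []
    for freq_idx in range(nfreqs):
        # first and last bins get a single pass, all others get two
        rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2
        for k in range(rr):
            if k == 1:
                # second pass: step down to the mirrored (negative-frequency) window
                mfbin_ptr -= 1
                visited.append(mfbin_ptr)
            else:
                visited.append(freq_idx)
    return visited


print(mirror_schedule(4))  # [0, 1, 5, 2, 4, 3]
```

Under that assumption every window is visited exactly once, and the mirror of bin `j` is window `2 * (nfreqs - 1) - j`.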
> I find the indexing of `t` to be confusing: `t[:, :, 0].unsqueeze(2), torch.flip(t[:, :, 1:], dims=(2,))`. The last dimension of `t` is the time bin dimension of the NSGT (which is ragged and varies per frequency bin as per the transform, as you know). Are we really concatenating the 0th time bin with the flip of the 1+th time bin?
I'll have to double-check this. I'll get back to you tomorrow.
I've checked this, and it seems to be correct and consistent with what happens in the original NSGT implementation. My comment in the code, `# use upper half frequencies (required to get perfect reconstruction)`, may be wrong though!
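A quick NumPy check of what that indexing does: keeping element 0 and reversing elements `1..N-1` is the index map `n -> (-n) % N`, and conjugating the result turns the FFT of a sequence into the conjugate of its FFT. For the spectrum of a real signal the operation is a no-op, which is Hermitian symmetry. If, as the `t1`/`t2` split suggests, index 0 of each slice sits at the window centre, this is what makes the flipped and conjugated slice a plausible stand-in for the mirrored negative-frequency bin. `mirror_conj` is a hypothetical helper, and NumPy replaces torch here.

```python
import numpy as np


def mirror_conj(t: np.ndarray) -> np.ndarray:
    """Return conj(t[(-n) % N]) along the last axis.

    Same indexing as the torch expression in the thread:
    torch.concatenate((t[..., :1], torch.flip(t[..., 1:], dims=(-1,))), dim=-1).conj()
    Element 0 stays in place; elements 1 .. N-1 are reversed.
    """
    return np.concatenate((t[..., :1], t[..., 1:][..., ::-1]), axis=-1).conj()


rng = np.random.default_rng(0)
t = rng.standard_normal(9) + 1j * rng.standard_normal(9)

# Reversing n -> -n (mod N) and conjugating maps a DFT to its conjugate ...
print(np.allclose(np.fft.fft(mirror_conj(t)), np.conj(np.fft.fft(t))))  # True

# ... and leaves the spectrum of a real signal unchanged (Hermitian symmetry).
X = np.fft.fft(rng.standard_normal(9))
print(np.allclose(mirror_conj(X), X))  # True
```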
Amazing, thanks for the extra context. I'll keep working on it.
OK, cool. I made some progress (in the other nsgt repo; once that's fixed I will revisit this one).
Here's the branch: https://github.com/sevagh/nsgt/tree/missing-fftsym-bug
In `examples/nsgt_orig` (the original code I kept around for testing), I added an early return in the `symm` lambda to skip the `fftsym` step: https://github.com/sevagh/nsgt/commit/a315e4f760702aeef65d123731afe2ba884298bc#diff-4c8bbed6a03e57879f034a8e454f3466bf0ea694ce51ab38fae9cc45f46af092

The result is excellent: the reconstruction error of my transform matches that of the original code when I intentionally disable the `symm` lambda there:
```
recon error (mse): 1.119293990825554e-08
ORIGINAL recon error (mse): 1.119293143977482e-08
```
When I re-enable the correct `symm` behavior in the original library, we see the true, much lower error of the original transform:
```
recon error (mse): 1.119293990825554e-08
ORIGINAL recon error (mse): 5.562861098466932e-34
```
So this is really the only thing missing from my nsgt library. (xumx-sliCQ additionally has the bug of the missing final-block ifft in the forward transform, as you noted in #13, but that bug is not present in the nsgt library.)
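The #13 bug looks like a common pattern: if a loop transforms a bucket only when the next bucket begins, the final bucket is never transformed unless it is handled after the loop. The sketch below is a hypothetical NumPy illustration of that pattern, assuming the forward transform buckets slices this way; `bucketwise_ifft` is not a function in xumx-sliCQ or nsgt. It groups consecutive equal-length slices and inverse-FFTs each group.

```python
import numpy as np


def bucketwise_ifft(slices):
    """Group consecutive equal-length 1-D arrays and inverse-FFT each group.

    Hypothetical helper: a bucket is transformed when the next slice starts a
    new bucket, so the final bucket has to be transformed after the loop.
    """
    buckets, current = [], []
    for s in slices:
        if current and len(s) != len(current[-1]):
            # slice length changed: the previous bucket is complete
            buckets.append(np.fft.ifft(np.stack(current), axis=-1))
            current = []
        current.append(s)
    if current:
        # the step that was missing: nothing inside the loop closes the last bucket
        buckets.append(np.fft.ifft(np.stack(current), axis=-1))
    return buckets


rng = np.random.default_rng(0)
slices = [rng.standard_normal(n) for n in (4, 4, 8, 8, 8)]
print([b.shape for b in bucketwise_ifft(slices)])  # [(2, 4), (3, 8)]
```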
OK, I think I figured it out. Your original fix was close; the only missing part was updating `wr1`, `wr2`, and `Lg` to the values that correspond to the new inverted/flipped `freq_idx`.

Code:
```python
# frequencies are bucketed by same time resolution
fbin_ptr = 0
mfbin_ptr = len(loopparams)
for i, fc in enumerate(cseq):
    nb_fbins = fc.shape[2]
    temp0 = torch.empty(*cseq_shape[:2], maxLg, dtype=fr.dtype, device=torch.device(device))
    for j, (wr1, wr2, Lg) in enumerate(loopparams[fbin_ptr : fbin_ptr + nb_fbins][:fbins]):
        freq_idx = fbin_ptr + j
        rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2
        for k in range(rr):
            # the overlap-add procedure including multiplication with the synthesis windows
            t = fc[:, :, j]
            if k == 1:
                mfbin_ptr -= 1
                freq_idx = mfbin_ptr
                t = torch.concatenate(
                    (
                        t[:, :, :1],
                        torch.flip(t[:, :, 1:], dims=(2,))
                    ),
                    dim=2,
                ).conj()
                # need new params corresponding to adjusted freq_idx
                wr1, wr2, Lg = loopparams[freq_idx]
            r = (Lg + 1) // 2
            l = Lg // 2
            t1 = temp0[:, :, :r]
            t2 = temp0[:, :, Lg - l : Lg]
            t1[:, :, :] = t[:, :, :r]
            t2[:, :, :] = t[:, :, Lg - l : Lg]
            temp0[:, :, :Lg] *= gdiis[freq_idx, :Lg]
            temp0[:, :, :Lg] *= Lg
            fr[:, :, wr1] += t2
            fr[:, :, wr2] += t1
    fbin_ptr += nb_fbins
ftr = fr[:, :, :nn//2+1] if real else fr
sig = ifft(ftr, outn=nn)
sig = sig[:, :, :Ls]  # Truncate the signal to original length (if given)
return sig
```
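For readers following the `t1`/`t2` step: each slice is written into its window range centre-first. The first `r = (Lg + 1) // 2` samples go to `wr2` and the last `l = Lg // 2` samples to `wr1`, so, provided `wr1` and `wr2` are contiguous (lower range followed by upper range), the slice lands in `fr` as `np.fft.fftshift(t)`. The 1-D sketch below is a hypothetical simplification: `overlap_add_segment` is an illustrative name, the arrays are NumPy rather than batched torch tensors, and the synthesis-window scaling by `gdiis[freq_idx, :Lg] * Lg` is left out.

```python
import numpy as np


def overlap_add_segment(fr, t, wr1, wr2):
    """Add one centre-first slice t into the spectrum fr (in place).

    Hypothetical 1-D version of the t1/t2 step: the first r = ceil(Lg/2)
    samples go to the upper range wr2 and the last l = floor(Lg/2) samples to
    the lower range wr1. The synthesis-window scaling (gdiis * Lg) is omitted.
    """
    Lg = len(t)
    r = (Lg + 1) // 2
    l = Lg // 2
    fr[wr1] += t[Lg - l : Lg]  # t2 -> lower part of the window range
    fr[wr2] += t[:r]           # t1 -> upper part, starting at the centre
    return fr


fr = np.zeros(8)
t = np.arange(5.0)
overlap_add_segment(fr, t, wr1=np.arange(2, 4), wr2=np.arange(4, 7))
print(fr[2:7])  # [3. 4. 0. 1. 2.]
```

The printed slice equals `np.fft.fftshift(np.arange(5.0))`, which is why index 0 of a slice corresponds to the first bin of `wr2`.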
The fix definitely improved the reconstruction error of my transform, although it doesn't exactly match the old code (and even beats it in some cases :shrug:):
```
(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 --scale=cqlog --fmin 83.0 --fmax 22050 --bins 60 ~/repos/demucs.cpp/test/data/gspi_stereo_short.wav | grep mse
/home/sevagh/repos/nsgt/examples/nsgt_orig/fft.py:116: UserWarning: nsgt.fft falling back to numpy.fft
  warn("nsgt.fft falling back to numpy.fft")
recon error (mse): 1.5521547886190033e-16
ORIGINAL recon error (mse): 6.540821282082431e-34
(nsgt-torch) sevagh@pop-os:~/repos/nsgt$
(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 --scale=cqlog --fmin 83.0 --fmax 22050 --bins 30 ~/repos/demucs.cpp/test/data/gspi_stereo_short.wav | grep mse
/home/sevagh/repos/nsgt/examples/nsgt_orig/fft.py:116: UserWarning: nsgt.fft falling back to numpy.fft
  warn("nsgt.fft falling back to numpy.fft")
recon error (mse): 1.9574665389761884e-16
ORIGINAL recon error (mse): 8.221279256613032e-34
(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 --scale=cqlog --fmin 83.0 --fmax 22050 --bins 20 ~/repos/demucs.cpp/test/data/gspi_stereo_short.wav | grep mse
/home/sevagh/repos/nsgt/examples/nsgt_orig/fft.py:116: UserWarning: nsgt.fft falling back to numpy.fft
  warn("nsgt.fft falling back to numpy.fft")
recon error (mse): 9.865468837287191e-17
ORIGINAL recon error (mse): 1.734286077074739e-08
```
Would you consider this fixed, or is there further debugging to do to get the Hermitian symmetry to match the old code exactly?
Ah, yes! Sorry about missing the switch to the correct `wr1`, `wr2`, and `Lg` values. In my actual implementation, I'd stopped filling in and looping over the `loopparams` list of parameters and instead accessed everything directly through the correct index, so when I converted back to reading the values from `loopparams` for this issue, I missed the fact that they were not indexed correctly.
I think I'd consider this fixed 😃! I'd be grateful if you could reference me in the commit when you merge this in, but it's no big deal if you can't.
100%, I'll ask you to create a PR when the fix is ready. Also, depending on whether fixing this bug affects the final performance (4.4 dB SDR on the test set), maybe a quick paper on arXiv.
I have two combinations now to test:
OK, to be honest, I don't see much difference. It trains worse with the fix (so omitting the ifft of the final block in nsgtf is a good idea?), but it scores +0.002 dB SDR with the fixed nsigtf.
Worth a try anyway, and thanks again!
Could you contribute the following to nsgtf/nsigtf in a branch? I'll merge it and then add the rest of my code:
```diff
diff --git a/xumx_slicq_v2/nsgt/nsgtf.py b/xumx_slicq_v2/nsgt/nsgtf.py
index 54b5763..fcfe026 100644
--- a/xumx_slicq_v2/nsgt/nsgtf.py
+++ b/xumx_slicq_v2/nsgt/nsgtf.py
@@ -77,5 +77,8 @@ def nsgtf_sl(f_slices, g, wins, nn, M=None, real=False, reducedform=0, device="c
                     [bucketed_tensors[block_ptr], c], dim=2
                 )
+    # run an ifft on the last bucket
+    bucketed_tensors[-1] = torch.fft.ifft(bucketed_tensors[-1])
+
     # bucket-wise ifft
     return bucketed_tensors
diff --git a/xumx_slicq_v2/nsgt/nsigtf.py b/xumx_slicq_v2/nsgt/nsigtf.py
index 920136b..6719ee4 100644
--- a/xumx_slicq_v2/nsgt/nsigtf.py
+++ b/xumx_slicq_v2/nsgt/nsigtf.py
@@ -48,31 +48,54 @@ def nsigtf_sl(cseq, gd, wins, nn, Ls=None, real=False, reducedform=0, device="cp
     # frequencies are bucketed by same time resolution
     fbin_ptr = 0
-    for i, fc in enumerate(cseq):
-        Lg_outer = fc.shape[-1]
+    mfbin_ptr = len(loopparams)
+    for i, fc in enumerate(cseq):
         nb_fbins = fc.shape[2]
-        for j, (wr1, wr2, Lg) in enumerate(
-            loopparams[fbin_ptr : fbin_ptr + nb_fbins][:fbins]
-        ):
+
+        temp0 = torch.empty(*cseq_shape[:2], maxLg, dtype=fr.dtype, device=device)
+
+        for j, (wr1, wr2, Lg) in enumerate(loopparams[fbin_ptr : fbin_ptr + nb_fbins][:fbins]):
             freq_idx = fbin_ptr + j
-            assert Lg == Lg_outer
-            t = fc[:, :, j]
+            rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2
-            r = (Lg + 1) // 2
-            l = Lg // 2
+            for k in range(rr):
+                # the overlap-add procedure including multiplication with the synthesis windows
+                t = fc[:, :, j]
-            t1 = t[:, :, :r]
-            t2 = t[:, :, Lg - l : Lg]
+                if k == 1:
+                    mfbin_ptr -= 1
+                    freq_idx = mfbin_ptr
-            t[:, :, :Lg] *= gdiis[freq_idx, :Lg]
-            t[:, :, :Lg] *= Lg
+                    t = torch.concatenate(
+                        (
+                            t[:, :, :1],
+                            torch.flip(t[:, :, 1:], dims=(2,)),
+                        ),
+                        dim=2,
+                    ).conj()
-            fr[:, :, wr1] += t2
-            fr[:, :, wr2] += t1
-        fbin_ptr += nb_fbins
+                    # need new params corresponding to adjusted freq_idx
+                    wr1, wr2, Lg = loopparams[freq_idx]
+
+                r = (Lg + 1) // 2
+                l = Lg // 2
+                t1 = temp0[:, :, :r]
+                t2 = temp0[:, :, Lg - l : Lg]
+
+                t1[:, :, :] = t[:, :, :r]
+                t2[:, :, :] = t[:, :, Lg - l : Lg]
+
+                temp0[:, :, :Lg] *= gdiis[freq_idx, : Lg]
+                temp0[:, :, :Lg] *= Lg
+
+                fr[:, :, wr1] += t2
+                fr[:, :, wr2] += t1
+
+        fbin_ptr += nb_fbins
+
     ftr = fr[:, :, : nn // 2 + 1] if real else fr
```
You can see them here: https://github.com/sevagh/xumx-sliCQ/commit/ebde9ced9fbc068ab829173a942394ff928ca4b7
Would you also want to contribute to my nsgt fork repo, or is xumx-sliCQ enough? I will eventually port the fix there, too.
Thanks, I've opened a PR with this patch implemented.
Great, it's now merged. Thanks again.
I've found that the implementation of the inverse NSGT (when not using the matrix form) does not give perfect reconstruction compared to the original NSGT implementation, although the values are close. I've been able to fix this by also including the conjugate of the FFTs in the inversion. In https://github.com/sevagh/xumx-sliCQ/blob/v2/xumx_slicq_v2/nsgt/nsigtf.py, this changes:
to:
This may be related to https://github.com/sevagh/nsgt/issues/3.