sevagh / xumx-sliCQ

music demixing with the sliCQ Transform and PyTorch
MIT License
24 stars, 6 forks

Non-matrix form NSGT does not give perfect inversion #14

Closed mattpitkin closed 7 months ago

mattpitkin commented 8 months ago

I've found that the implementation of the inverse NSGT (when not using the matrix form) does not give perfect reconstruction compared to the original NSGT implementation - although the values are close. I've been able to fix this by also including the conjugate of the FFTs in the inversion. In https://github.com/sevagh/xumx-sliCQ/blob/v2/xumx_slicq_v2/nsgt/nsigtf.py this changes:

    # frequencies are bucketed by same time resolution
    fbin_ptr = 0
    for i, fc in enumerate(cseq):
        Lg_outer = fc.shape[-1]

        nb_fbins = fc.shape[2]
        for j, (wr1, wr2, Lg) in enumerate(
            loopparams[fbin_ptr : fbin_ptr + nb_fbins][:fbins]
        ):
            freq_idx = fbin_ptr + j
            assert Lg == Lg_outer

            t = fc[:, :, j]

            r = (Lg + 1) // 2
            l = Lg // 2

            t1 = t[:, :, :r]
            t2 = t[:, :, Lg - l : Lg]

            t[:, :, :Lg] *= gdiis[freq_idx, :Lg]
            t[:, :, :Lg] *= Lg

            fr[:, :, wr1] += t2
            fr[:, :, wr2] += t1
        fbin_ptr += nb_fbins

to:

        fbin_ptr = 0
        mfbin_ptr = len(loopparams)
        for i, fc in enumerate(cseq):
            nb_fbins = fc.shape[2]

            temp0 = torch.empty(
                *cseq_shape[:2], maxLg, dtype=fr.dtype, device=self.device
            )  # pre-allocation

            for j, (wr1, wr2, Lg) in enumerate(loopparams[fbin_ptr : fbin_ptr + nb_fbins][:fbins]):
                freq_idx = fbin_ptr + j

                rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2

                for k in range(rr):
                    # the overlap-add procedure including multiplication with the synthesis windows
                    if not k:
                        t = fc[:, :, j]
                    else:
                        # use upper half frequencies (required to get perfect reconstruction)
                        mfbin_ptr -= 1
                        freq_idx = mfbin_ptr
                        t = fc[:, :, j]
                        t = torch.concatenate(
                            (
                                t[:, :, 0].unsqueeze(2),
                                torch.flip(t[:, :, 1:], dims=(2,)),
                            ),
                            dim=2,
                        ).conj()

                    r = (Lg + 1) // 2
                    l = Lg // 2

                    t1 = temp0[:, :, :r]
                    t2 = temp0[:, :, Lg - l : Lg]

                    t1[:, :, :] = t[:, :, :r]
                    t2[:, :, :] = t[:, :, Lg - l : Lg]

                    temp0[:, :, :Lg] *= gdiis[freq_idx, : Lg]
                    temp0[:, :, :Lg] *= Lg

                    fr[:, :, wr1] += t2
                    fr[:, :, wr2] += t1

            fbin_ptr += nb_fbins

This may be related to https://github.com/sevagh/nsgt/issues/3.
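As a generic illustration of the symmetry this fix relies on, here is a minimal pure-Python sketch. It is not the code path in nsigtf.py, and the helpers `dft`, `idft`, `mse` and the test signal are assumptions made up for the example. The DFT of a real-valued signal satisfies X[N-k] = conj(X[k]), so an inverse that leaves the mirrored bins empty cannot reproduce the signal, while one that fills them with the conjugates can:

```python
import cmath


def dft(x):
    """Naive O(n^2) discrete Fourier transform."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def idft(X):
    """Naive O(n^2) inverse discrete Fourier transform."""
    n = len(X)
    return [
        sum(X[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n
        for t in range(n)
    ]


def mse(a, b):
    """Mean squared error between two sequences (real or complex)."""
    return sum(abs(u - v) ** 2 for u, v in zip(a, b)) / len(a)


signal = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 0.0, -0.5]  # arbitrary real test signal
n = len(signal)
spectrum = dft(signal)

# Real input implies Hermitian symmetry: X[n - k] == conj(X[k]).
hermitian_ok = all(
    abs(spectrum[n - k] - spectrum[k].conjugate()) < 1e-9 for k in range(1, n)
)

# Keep bins 0..n/2 only and leave the mirrored bins at zero.
half_only = [spectrum[k] if k <= n // 2 else 0j for k in range(n)]
# Rebuild the mirrored bins from the conjugates of the lower bins.
mirrored = [
    spectrum[k] if k <= n // 2 else spectrum[n - k].conjugate() for k in range(n)
]

err_half = mse(signal, idft(half_only))
err_mirrored = mse(signal, idft(mirrored))
print(hermitian_ok, err_half, err_mirrored)
```

In this toy case the half-only reconstruction has a large error and the mirrored one is exact up to floating-point rounding. The NSGT case is subtler, since only part of the mirrored contribution is missing there, which would be consistent with the values being close but not exact.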

mattpitkin commented 8 months ago

Note: in my version I've also fixed #13 within the forward transform.

sevagh commented 7 months ago

Nice! Thanks for this. I wonder whether fixing this and retraining the neural network would affect the final performance.

sevagh commented 7 months ago

I need a little bit more help understanding your fix. There's a confusing intermixing of some of my own optimizations (for the sake of efficiency in the middle of the training loop on the GPU), e.g. dropping the creation of temp0, and your fix which brings some of those original nsgt structures back.

To set the stage (and related to https://github.com/sevagh/nsgt/issues/3 as you proposed), my ragged_vs_matrix.py script, which measures reconstruction error, shows that the recon error is pretty high when the number of bins is small:

6 bins: ~1e-06 reconstruction error

(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 \
--scale=cqlog --fmin 83.0 --fmax 22050 --bins 6 \
./gspi_stereo_short.wav | grep mse
recon error (mse): 1.6235659359153942e-06

This is bad/buggy/incorrect, right? This is what you're fixing? It's not an inherent property of the NSGT?

Rerun with 60 bins: ~1e-16 reconstruction error

(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 \
--scale=cqlog --fmin 83.0 --fmax 22050 --bins 60 \
./gspi_stereo_short.wav | grep mse
recon error (mse): 1.5521547886190033e-16

So it would make sense that the 262 bins of this codebase mask this subtle error. But, with your proposed fix, there would be very low reconstruction error even with a small number of frequency bins?

  1. The special case handling that you added is only for the first and last frequency bin? rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2 Or is it the other way around, is the special case handling you added for every bin except the first and last?
  2. I find the indexing of t to be confusing: t[:, :, 0].unsqueeze(2), torch.flip(t[:, :, 1:], dims=(2,)), - the last dimension of t is the time bin dimension of the NSGT (which is ragged and varies per frequency bin as per the transform, as you know). Are we really concatenating the 0th time bin with the flip of the 1+th time bin?

mattpitkin commented 7 months ago

Hi @sevagh, the main thing that my fix tries to do is copy what happens in the symm function here https://github.com/grrrr/nsgt/blob/master/nsgt/nsigtf.py#L96 (and https://github.com/grrrr/nsgt/blob/master/nsgt/nsigtf.py#L88).

I think, like you, I do see that the reconstruction error is worse for smaller numbers of bins (or if the slice lengths are too small compared to the suggested lengths), and my change does fix both those cases.

The special case handling that you added is only for the first and last frequency bin? rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2 Or is it the other way around, is the special case handling you added for every bin except the first and last?

For the first and last frequency bins (where rr = 1) it only goes into the:

                    if not k:
                        t = fc[:, :, j]

part of the if statement, but for all others it goes into both parts of the if...else... statement.

I find the indexing of t to be confusing: t[:, :, 0].unsqueeze(2), torch.flip(t[:, :, 1:], dims=(2,)), - the last dimension of t is the time bin dimension of the NSGT (which is ragged and varies per frequency bin as per the transform, as you know). Are we really concatenating the 0th time bin with the flip of the 1+th time bin?

I'll have to double check this. I'll get back to you tomorrow.

mattpitkin commented 7 months ago

I find the indexing of t to be confusing: t[:, :, 0].unsqueeze(2), torch.flip(t[:, :, 1:], dims=(2,)), - the last dimension of t is the time bin dimension of the NSGT (which is ragged and varies per frequency bin as per the transform, as you know). Are we really concatenating the 0th time bin with the flip of the 1+th time bin?

I've checked this and it seems to be correct and consistent with what is happening in the original NSGT implementation. My comment in the code

# use upper half frequencies (required to get perfect reconstruction)

may be wrong though!
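The index pattern in question (keep element 0, reverse the rest, conjugate) can be checked in isolation. The following is a minimal pure-Python sketch under stated assumptions: `fftsymm` here is a stand-in written for this example that mirrors the quoted torch expression on a plain list, and `dft` and the test values are likewise made up. Keeping element 0 in place and flipping the rest is index negation modulo N (index 0 is its own mirror), and combined with the conjugate it turns the DFT of a sequence into the DFT of that sequence's complex conjugate:

```python
import cmath


def dft(x):
    """Naive O(n^2) discrete Fourier transform."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def fftsymm(c):
    """Keep element 0, reverse elements 1..N-1, then conjugate everything.

    Equivalent to result[k] = conj(c[(-k) % N]).
    """
    return [v.conjugate() for v in [c[0]] + c[:0:-1]]


coeffs = [1 + 2j, -0.5 + 0.25j, 3 - 1j, 0.75j, -2 + 0.5j]  # arbitrary complex values

# Identity: fftsymm(DFT(x)) == DFT(conj(x)).
lhs = fftsymm(dft(coeffs))
rhs = dft([v.conjugate() for v in coeffs])
max_diff = max(abs(a - b) for a, b in zip(lhs, rhs))
print(max_diff)

# Applying the operation twice returns the original sequence.
print(fftsymm(fftsymm(coeffs)) == coeffs)
```

So the answer to the indexing question is yes: element 0 is concatenated with the flip of elements 1 onward, and that is exactly what a mirrored, conjugated spectrum looks like.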

sevagh commented 7 months ago

Amazing, thanks for the extra context. I'll keep working on it.

sevagh commented 7 months ago

OK, cool. I made some progress (in the other nsgt repo - when that's fixed I will revisit this one).

Here's the branch: https://github.com/sevagh/nsgt/tree/missing-fftsym-bug

  1. First, in my copy of examples/nsgt_orig (the original code I kept around for some testing), I did an early return in the symm lambda, to skip the fftsym step: https://github.com/sevagh/nsgt/commit/a315e4f760702aeef65d123731afe2ba884298bc#diff-4c8bbed6a03e57879f034a8e454f3466bf0ea694ce51ab38fae9cc45f46af092
  2. Then in the ragged_vs_matrix.py script (that I've been using to post reconstruction error), I added the recon error of the original transform: https://github.com/sevagh/nsgt/commit/a315e4f760702aeef65d123731afe2ba884298bc#diff-00efd77ba09dd10db03cffe7010e6ae06785de729dd340205eba9c06712fca03R90

The result is excellent: when I intentionally disable the symm lambda in the original code, its reconstruction error matches mine to about seven significant figures:

recon error (mse): 1.119293990825554e-08
ORIGINAL recon error (mse): 1.119293143977482e-08

When I re-enable the correct symm behavior in the original library, we see the real low error of the original transform:

recon error (mse): 1.119293990825554e-08
ORIGINAL recon error (mse): 5.562861098466932e-34

So this is really the only missing thing in my nsgt library. (xumx-sliCQ also has the extra bug of the missing final-block ifft in the forward transform, as you noted in #13, but that's not present in the nsgt library.)

sevagh commented 7 months ago

OK, I think I figured it out. Your original fix was close. The only missing part was updating wr1, wr2, Lg to use the values that correspond to the new inverted/flipped freq_idx. Code:

        # frequencies are bucketed by same time resolution
        fbin_ptr = 0
        mfbin_ptr = len(loopparams)

        for i, fc in enumerate(cseq):
            nb_fbins = fc.shape[2]

            temp0 = torch.empty(*cseq_shape[:2], maxLg, dtype=fr.dtype, device=torch.device(device))

            for j, (wr1, wr2, Lg) in enumerate(loopparams[fbin_ptr : fbin_ptr + nb_fbins][:fbins]):
                freq_idx = fbin_ptr + j

                rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2

                for k in range(rr):
                    # the overlap-add procedure including multiplication with the synthesis windows
                    t = fc[:, :, j]

                    if k == 1:
                        mfbin_ptr -= 1
                        freq_idx = mfbin_ptr

                        t = torch.concatenate(
                            (
                                t[:, :, :1],
                                torch.flip(t[:, :, 1:], dims=(2,))
                            ),
                            dim=2,
                        ).conj()

                        # need new params corresponding to adjusted freq_idx
                        wr1, wr2, Lg = loopparams[freq_idx]

                    r = (Lg + 1) // 2
                    l = Lg // 2

                    t1 = temp0[:, :, :r]
                    t2 = temp0[:, :, Lg - l : Lg]

                    t1[:, :, :] = t[:, :, :r]
                    t2[:, :, :] = t[:, :, Lg - l : Lg]

                    temp0[:, :, :Lg] *= gdiis[freq_idx, :Lg]
                    temp0[:, :, :Lg] *= Lg

                    fr[:, :, wr1] += t2
                    fr[:, :, wr2] += t1

            fbin_ptr += nb_fbins

    ftr = fr[:, :, :nn//2+1] if real else fr
    sig = ifft(ftr, outn=nn)
    sig = sig[:, :, :Ls] # Truncate the signal to original length (if given)
    return sig

This definitely improved the reconstruction error of my transform - albeit not exactly matching the old code (but beating it in some cases :shrug: )

(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 --scale=cqlog --fmin 83.0 --fmax 22050 --bins 60 ~/repos/demucs.cpp/test/data/gspi_stereo_short.wav | grep mse
/home/sevagh/repos/nsgt/examples/nsgt_orig/fft.py:116: UserWarning: nsgt.fft falling back to numpy.fft
  warn("nsgt.fft falling back to numpy.fft")
recon error (mse): 1.5521547886190033e-16
ORIGINAL recon error (mse): 6.540821282082431e-34

(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 --scale=cqlog --fmin 83.0 --fmax 22050 --bins 30 ~/repos/demucs.cpp/test/data/gspi_stereo_short.wav | grep mse
/home/sevagh/repos/nsgt/examples/nsgt_orig/fft.py:116: UserWarning: nsgt.fft falling back to numpy.fft
  warn("nsgt.fft falling back to numpy.fft")
recon error (mse): 1.9574665389761884e-16
ORIGINAL recon error (mse): 8.221279256613032e-34

(nsgt-torch) sevagh@pop-os:~/repos/nsgt$ python examples/ragged_vs_matrix.py --sr 44100 --scale=cqlog --fmin 83.0 --fmax 22050 --bins 20 ~/repos/demucs.cpp/test/data/gspi_stereo_short.wav | grep mse
/home/sevagh/repos/nsgt/examples/nsgt_orig/fft.py:116: UserWarning: nsgt.fft falling back to numpy.fft
  warn("nsgt.fft falling back to numpy.fft")
recon error (mse): 9.865468837287191e-17
ORIGINAL recon error (mse): 1.734286077074739e-08

Would you consider this fixed, or is there further debugging to do to get the Hermitian symmetry exactly matching the old code?

mattpitkin commented 7 months ago

Ah, yes! Sorry about missing the switch to the correct wr1, wr2 and Lg values. In my actual implementation, I'd got rid of filling in and looping over the loopparams list of parameters and instead accessed everything directly through the correct index, so when converting back to getting the values out of loopparams for this issue I missed the fact that they were not indexed correctly.
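The index bookkeeping behind that missed switch can be replayed on its own. This is a minimal sketch that copies only the rr / mfbin_ptr logic from the loop posted above. The function name `mirrored_indices` and the sizes used (5 frequency bins, 8 loopparams entries) are hypothetical, and the bucketing over cseq is ignored:

```python
def mirrored_indices(nfreqs, n_params):
    """Replay the freq_idx / mfbin_ptr bookkeeping from the overlap-add loop.

    nfreqs and n_params stand in for nfreqs and len(loopparams).
    Returns (freq_idx, mirrored_idx) pairs, where mirrored_idx is None
    for bins that are visited only once (rr == 1).
    """
    pairs = []
    mfbin_ptr = n_params
    for freq_idx in range(nfreqs):
        # First and last bins are visited once; all others get a mirrored pass.
        rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2
        mirrored = None
        if rr == 2:
            mfbin_ptr -= 1
            mirrored = mfbin_ptr
        pairs.append((freq_idx, mirrored))
    return pairs


# Hypothetical sizes chosen for illustration only.
pairs = mirrored_indices(nfreqs=5, n_params=8)
print(pairs)  # [(0, None), (1, 7), (2, 6), (3, 5), (4, None)]
```

Each mirrored pass lands on a different loopparams entry than the one the enclosing loop unpacked, which is why the corrected code re-reads wr1, wr2, Lg from loopparams[freq_idx] after switching freq_idx.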

I'd consider this fixed I think 😃! I'd be grateful if you could somehow reference me in the commit when you merge this in, but it's no major issue if you can't.

sevagh commented 7 months ago

100%, I'll ask you to create a PR when the fix is ready. Also, depending on whether the final performance (4.4 dB SDR on the test set) is influenced by fixing this bug, maybe a quick paper on arXiv.

sevagh commented 7 months ago

I have two combinations now to test:

  1. Old pretrained model with bugfixed new nsgt code
    • The fixed missing ifft of final block of the forward transform is a small change in the input (that wasn't present on the training inputs, which were all missing the final ifft from the buggy forward transform) - curious what happens
    • The inverse slicqt/nsgt is not present in the training loop, so I'm not too worried about how that is implicated in the trained model, but it would still affect the total SDR after all the inverse transforms
  2. New trained model using bugfixed nsgt as inputs

sevagh commented 7 months ago

OK, to be honest I don't see much difference. It trains worse with the fix (so omitting the ifft of the final block of the nsgtf is a good idea?) but gains +0.002 dB SDR with the fixed nsigtf.

Worth a try anyway, and thanks again!

Could you contribute the following changes to nsgtf/nsigtf in a branch? I'll merge it, then add the rest of my code:

diff --git a/xumx_slicq_v2/nsgt/nsgtf.py b/xumx_slicq_v2/nsgt/nsgtf.py
index 54b5763..fcfe026 100644
--- a/xumx_slicq_v2/nsgt/nsgtf.py
+++ b/xumx_slicq_v2/nsgt/nsgtf.py
@@ -77,5 +77,8 @@ def nsgtf_sl(f_slices, g, wins, nn, M=None, real=False, reducedform=0, device="c
                 [bucketed_tensors[block_ptr], c], dim=2
             )

+    # run an ifft on the last bucket
+    bucketed_tensors[-1] = torch.fft.ifft(bucketed_tensors[-1])
+
     # bucket-wise ifft
     return bucketed_tensors
diff --git a/xumx_slicq_v2/nsgt/nsigtf.py b/xumx_slicq_v2/nsgt/nsigtf.py
index 920136b..6719ee4 100644
--- a/xumx_slicq_v2/nsgt/nsigtf.py
+++ b/xumx_slicq_v2/nsgt/nsigtf.py
@@ -48,31 +48,54 @@ def nsigtf_sl(cseq, gd, wins, nn, Ls=None, real=False, reducedform=0, device="cp

     # frequencies are bucketed by same time resolution
     fbin_ptr = 0
-    for i, fc in enumerate(cseq):
-        Lg_outer = fc.shape[-1]
+    mfbin_ptr = len(loopparams)

+    for i, fc in enumerate(cseq):
         nb_fbins = fc.shape[2]
-        for j, (wr1, wr2, Lg) in enumerate(
-            loopparams[fbin_ptr : fbin_ptr + nb_fbins][:fbins]
-        ):
+
+        temp0 = torch.empty(*cseq_shape[:2], maxLg, dtype=fr.dtype, device=device)
+
+        for j, (wr1, wr2, Lg) in enumerate(loopparams[fbin_ptr : fbin_ptr + nb_fbins][:fbins]):
             freq_idx = fbin_ptr + j
-            assert Lg == Lg_outer

-            t = fc[:, :, j]
+            rr = 1 if freq_idx == 0 or freq_idx == nfreqs - 1 else 2

-            r = (Lg + 1) // 2
-            l = Lg // 2
+            for k in range(rr):
+                # the overlap-add procedure including multiplication with the synthesis windows
+                t = fc[:, :, j]

-            t1 = t[:, :, :r]
-            t2 = t[:, :, Lg - l : Lg]
+                if k == 1:
+                    mfbin_ptr -= 1
+                    freq_idx = mfbin_ptr

-            t[:, :, :Lg] *= gdiis[freq_idx, :Lg]
-            t[:, :, :Lg] *= Lg
+                    t = torch.concatenate(
+                        (
+                            t[:, :, :1],
+                            torch.flip(t[:, :, 1:], dims=(2,)),
+                        ),
+                        dim=2,
+                    ).conj()

-            fr[:, :, wr1] += t2
-            fr[:, :, wr2] += t1
-        fbin_ptr += nb_fbins
+                    # need new params corresponding to adjusted freq_idx
+                    wr1, wr2, Lg = loopparams[freq_idx]
+
+                r = (Lg + 1) // 2
+                l = Lg // 2

+                t1 = temp0[:, :, :r]
+                t2 = temp0[:, :, Lg - l : Lg]
+
+                t1[:, :, :] = t[:, :, :r]
+                t2[:, :, :] = t[:, :, Lg - l : Lg]
+
+                temp0[:, :, :Lg] *= gdiis[freq_idx, : Lg]
+                temp0[:, :, :Lg] *= Lg
+
+                fr[:, :, wr1] += t2
+                fr[:, :, wr2] += t1
+
+        fbin_ptr += nb_fbins
+
     ftr = fr[:, :, : nn // 2 + 1] if real else fr

You can see them here: https://github.com/sevagh/xumx-sliCQ/commit/ebde9ced9fbc068ab829173a942394ff928ca4b7

Would you also want to contribute to my nsgt fork repo, or is xumx-sliCQ enough? I will eventually port the fix there, too.

mattpitkin commented 7 months ago

Thanks, I've opened a PR with this patch implemented.

sevagh commented 7 months ago

Great, it's now merged. Thanks again.