mcw519 / PureSound

Make the sound you hear pure and clean by deep learning.

Is it necessary to limit the PyTorch version? I get very poor results on all 4 ns models #4

Closed: zuowanbushiwo closed this issue 1 year ago

zuowanbushiwo commented 1 year ago

Inference results (see attached image): ns_data.zip

Package                   Version
------------------------- ------------
absl-py                   1.4.0
aiohttp                   3.8.4
aiosignal                 1.3.1
alembic                   1.9.4
antlr4-python3-runtime    4.8
appdirs                   1.4.4
asteroid                  0.6.0
asteroid-filterbanks      0.4.0
asttokens                 2.2.1
async-timeout             4.0.2
attrdict                  2.0.1
attrs                     22.2.0
audioread                 3.0.0
backcall                  0.2.0
backports.cached-property 1.0.2
brotlipy                  0.7.0
brouhaha                  0.9.0
cached-property           1.5.2
cachetools                5.3.0
certifi                   2022.12.7
cffi                      1.15.1
charset-normalizer        2.0.4
click                     8.1.3
cmaes                     0.9.1
colorama                  0.4.6
coloredlogs               15.0.1
colorlog                  6.7.0
comm                      0.1.2
commonmark                0.9.1
conda                     23.1.0
conda-package-handling    2.0.2
conda_package_streaming   0.7.0
contourpy                 1.0.7
cryptography              38.0.4
cycler                    0.11.0
Cython                    0.29.34
debugpy                   1.6.6
decorator                 5.1.1
DeepFilterDataLoader      0.4.0
DeepFilterLib             0.4.0
deepfilternet             0.4.0
docopt                    0.6.2
einops                    0.3.2
et-xmlfile                1.1.0
exceptiongroup            1.1.1
executing                 1.2.0
filelock                  3.9.0
flatbuffers               23.1.21
flit_core                 3.6.0
fonttools                 4.38.0
frozenlist                1.3.3
fsspec                    2023.1.0
future                    0.18.3
google-auth               2.16.1
google-auth-oauthlib      0.4.6
greenlet                  2.0.2
grpcio                    1.51.3
hmmlearn                  0.2.8
huggingface-hub           0.12.1
humanfriendly             10.0
Hydra                     2.5
hydra-core                1.1.0
HyperPyYAML               1.1.0
icecream                  2.1.3
idna                      3.4
importlib-metadata        6.0.0
importlib-resources       5.12.0
iniconfig                 2.0.0
ipyhton                   0.1
ipykernel                 6.21.2
ipython                   8.10.0
jedi                      0.18.2
joblib                    1.2.0
julius                    0.2.7
jupyter_client            8.0.3
jupyter_core              5.2.0
kiwisolver                1.4.4
librosa                   0.9.2
llvmlite                  0.39.1
loguru                    0.7.0
Mako                      1.2.4
Markdown                  3.4.1
MarkupSafe                2.1.2
matplotlib                3.7.0
matplotlib-inline         0.1.6
MinDAEC                   0.0.2
mir-eval                  0.7
mkl-fft                   1.3.1
mkl-random                1.2.2
mkl-service               2.4.0
mpmath                    1.2.1
multidict                 6.0.4
multiprocessing-logging   0.3.4
natsort                   8.3.1
nest-asyncio              1.5.6
networkx                  2.8.8
numba                     0.56.4
numpy                     1.23.5
oauthlib                  3.2.2
omegaconf                 2.1.2
onnx                      1.13.1
onnx-simplifier           0.4.17
onnxruntime               1.14.0
openpyxl                  3.1.2
openyxl                   0.1
optuna                    3.1.0
packaging                 23.0
pandas                    1.5.3
parso                     0.8.3
pb-bss-eval               0.0.2
pesq                      0.0.4
pexpect                   4.8.0
pickleshare               0.7.5
Pillow                    9.3.0
pip                       22.3.1
pipdeptree                2.5.0
platformdirs              3.0.0
pluggy                    1.0.0
pooch                     1.6.0
primePy                   1.3
prompt-toolkit            3.0.37
protobuf                  3.20.3
psutil                    5.9.4
ptyprocess                0.7.0
pure-eval                 0.2.2
pyannote.audio            2.1.1
pyannote.core             4.5
pyannote.database         4.1.3
pyannote.metrics          3.2.1
pyannote.pipeline         2.3
pyasn1                    0.4.8
pyasn1-modules            0.2.8
pybind11                  2.10.4
pycosat                   0.6.4
pycparser                 2.21
pyDeprecate               0.3.2
Pygments                  2.14.0
pyOpenSSL                 22.0.0
pyparsing                 3.0.9
pyreadline                2.1
pyroomacoustics           0.7.3
PySocks                   1.7.1
pystoi                    0.3.3
pytest                    7.3.1
python-dateutil           2.8.2
pytorch-lightning         1.6.5
pytorch-metric-learning   1.7.3
pytorch-ranger            0.1.1
pytz                      2022.7.1
PyYAML                    6.0
pyzmq                     25.0.0
requests                  2.28.1
requests-oauthlib         1.3.1
resampy                   0.4.2
rich                      12.6.0
rsa                       4.9
ruamel.yaml               0.17.21
ruamel.yaml.clib          0.2.7
scikit-learn              1.2.1
scipy                     1.10.1
semver                    2.13.0
sentencepiece             0.1.97
setuptools                65.6.3
shellingham               1.5.0.post1
simplejson                3.18.3
singledispatchmethod      1.0
six                       1.16.0
sortedcontainers          2.4.0
SoundFile                 0.10.3.post1
speechbrain               0.5.13
SQLAlchemy                2.0.4
stack-data                0.6.2
sympy                     1.11.1
tabulate                  0.9.0
tensorboard               2.12.0
tensorboard-data-server   0.7.0
tensorboard-plugin-wit    1.8.1
threadpoolctl             3.1.0
toml                      0.10.2
tomli                     2.0.1
toolz                     0.12.0
torch                     1.11.0
torch-audiomentations     0.11.0
torch-optimizer           0.1.0
torch-pitch-shift         1.2.2
torch-stoi                0.1.2
torch-tb-profiler         0.4.1
torchaudio                0.11.0
torchinfo                 1.7.2
torchmetrics              0.7.3
torchsummary              1.5.1
torchvision               0.12.0
tornado                   6.2
tqdm                      4.64.1
traitlets                 5.9.0
typer                     0.7.0
typing_extensions         4.4.0
urllib3                   1.26.14
wcwidth                   0.2.6
Werkzeug                  2.2.3
wheel                     0.38.4
yarl                      1.8.2
zipp                      3.14.0
zstandard                 0.19.0
mcw519 commented 1 year ago

Hi zuowanbushiwo,

It looks like your signal is already saturated. Can you provide your inference configuration and code? I get the results here: eval_audio.zip. That's totally different.
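
For a quick check, something like the sketch below can show whether an output waveform is clipping. It assumes the enhanced signal is a float tensor scaled roughly to [-1, 1]; the 0.99 threshold is only an illustrative choice.

import torch

def report_saturation(wav: torch.Tensor, threshold: float = 0.99) -> None:
    # Peak level and fraction of samples sitting at or above the clipping threshold.
    peak = wav.abs().max().item()
    clipped_ratio = (wav.abs() >= threshold).float().mean().item()
    print(f"peak amplitude: {peak:.4f}, samples >= {threshold}: {clipped_ratio:.2%}")
    if peak >= threshold:
        print("warning: the signal looks saturated; check model loading and input scaling")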

Thanks.

zuowanbushiwo commented 1 year ago

Hi Wu, thanks. I didn't build and install puresound; I created a new test.py in the root directory of the project with the content below. The init_model function is copied from egs/ns/model.py, and I carefully followed the eval process in egs/ns/main.py. No files in the project have been modified.

import torch
import torch.nn as nn
from puresound.streaming.skim_inference import StreamingSkiM
from puresound.nnet.lobe.trivial import FiLM
from puresound.nnet.skim import MemLSTM, SegLSTM
from typing import Optional, Tuple
from puresound.nnet.dparn import DPARN
from puresound.nnet.dpcrn import DPCRN
from puresound.nnet.base_nn import SoTaskWrapModule
from puresound.nnet.lobe.encoder import ConvEncDec
import soundfile 
import os
from puresound.src.audio import AudioIO

# Models
def init_model(name: str, sig_loss: Optional[nn.Module] = None, **kwargs):
    if name == "ns_dpcrn_v0_causal":
        """
        Total params: 1,380,043
        Lookahead(samples): 384
        Receptive Fields(samples): infinite
        """
        model = SoTaskWrapModule(
            encoder=ConvEncDec(
                fft_length=512,
                win_type="hann",
                win_length=512,
                hop_length=128,
                trainable=True,
                output_format="Complex",
            ),
            masker=DPCRN(
                input_type="RI",
                input_dim=512,
                activation_type="PReLU",
                norm_type="bN2d",
                dropout=0.1,
                channels=(1, 32, 32, 32, 64, 128),
                transpose_t_size=2,
                transpose_delay=False,
                skip_conv=False,
                kernel_t=(2, 2, 2, 2, 2),
                kernel_f=(5, 3, 3, 3, 3),
                stride_t=(1, 1, 1, 1, 1),
                stride_f=(2, 2, 1, 1, 1),
                dilation_t=(1, 1, 1, 1, 1),
                dilation_f=(1, 1, 1, 1, 1),
                delay=(0, 0, 0, 0, 0),
                rnn_hidden=128,
            ),
            speaker_net=None,
            loss_func_wav=sig_loss,
            loss_func_spk=None,
            drop_first_bin=True,
            mask_constraint="linear",
            f_type="Complex",
            mask_type="Complex",
            **kwargs
        )

    elif name == "ns_dpcrn_v0":
        """
        Total params: 1,380,043
        Lookahead(samples): 1024; (384+128*(6-1)); semi-causal
        Receptive Fields(samples): infinite
        """
        model = SoTaskWrapModule(
            encoder=ConvEncDec(
                fft_length=512,
                win_type="hann",
                win_length=512,
                hop_length=128,
                trainable=True,
                output_format="Complex",
            ),
            masker=DPCRN(
                input_type="RI",
                input_dim=512,
                activation_type="PReLU",
                norm_type="bN2d",
                dropout=0.1,
                channels=(1, 32, 32, 32, 64, 128),
                transpose_t_size=2,
                transpose_delay=True,
                skip_conv=False,
                kernel_t=(2, 2, 2, 2, 2),
                kernel_f=(5, 3, 3, 3, 3),
                stride_t=(1, 1, 1, 1, 1),
                stride_f=(2, 2, 1, 1, 1),
                dilation_t=(1, 1, 1, 1, 1),
                dilation_f=(1, 1, 1, 1, 1),
                delay=(0, 0, 0, 0, 0),
                rnn_hidden=128,
            ),
            speaker_net=None,
            loss_func_wav=sig_loss,
            loss_func_spk=None,
            drop_first_bin=True,
            mask_constraint="linear",
            f_type="Complex",
            mask_type="Complex",
            **kwargs
        )

    elif name == "ns_dparn_v0_causal":
        """
        Total params: 1,215,179
        Lookahead(samples): 384
        Receptive Fields(samples): infinite
        """
        model = SoTaskWrapModule(
            encoder=ConvEncDec(
                fft_length=512,
                win_type="hann",
                win_length=512,
                hop_length=128,
                trainable=True,
                output_format="Complex",
            ),
            masker=DPARN(
                input_type="RI",
                input_dim=512,
                activation_type="PReLU",
                norm_type="bN2d",
                dropout=0.1,
                channels=(1, 32, 32, 32, 64, 128),
                transpose_t_size=2,
                transpose_delay=False,
                skip_conv=False,
                kernel_t=(2, 2, 2, 2, 2),
                kernel_f=(5, 3, 3, 3, 3),
                stride_t=(1, 1, 1, 1, 1),
                stride_f=(2, 2, 1, 1, 1),
                dilation_t=(1, 1, 1, 1, 1),
                dilation_f=(1, 1, 1, 1, 1),
                delay=(0, 0, 0, 0, 0),
                rnn_hidden=128,
                nhead=8,
            ),
            speaker_net=None,
            loss_func_wav=sig_loss,
            loss_func_spk=None,
            drop_first_bin=True,
            mask_constraint="linear",
            f_type="Complex",
            mask_type="Complex",
            **kwargs
        )

    elif name == "ns_dparn_v0":
        """
        Total params: 1,215,179
        Lookahead(samples): 1024; (384+128*(6-1)); semi-causal
        Receptive Fields(samples): infinite
        """
        model = SoTaskWrapModule(
            encoder=ConvEncDec(
                fft_length=512,
                win_type="hann",
                win_length=512,
                hop_length=128,
                trainable=True,
                output_format="Complex",
            ),
            masker=DPARN(
                input_type="RI",
                input_dim=512,
                activation_type="PReLU",
                norm_type="bN2d",
                dropout=0.1,
                channels=(1, 32, 32, 32, 64, 128),
                transpose_t_size=2,
                transpose_delay=True,
                skip_conv=False,
                kernel_t=(2, 2, 2, 2, 2),
                kernel_f=(5, 3, 3, 3, 3),
                stride_t=(1, 1, 1, 1, 1),
                stride_f=(2, 2, 1, 1, 1),
                dilation_t=(1, 1, 1, 1, 1),
                dilation_f=(1, 1, 1, 1, 1),
                delay=(0, 0, 0, 0, 0),
                rnn_hidden=128,
                nhead=8,
            ),
            speaker_net=None,
            loss_func_wav=sig_loss,
            loss_func_spk=None,
            drop_first_bin=True,
            mask_constraint="linear",
            f_type="Complex",
            mask_type="Complex",
            **kwargs
        )

    else:
        raise NameError

    return model

def test_ns():
    noisy_wav_path = 'test_ns.wav'
    current_directory = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_directory,'egs','ns','pretrained')
    config_model_dict = {'ns_dparn_v0_causal':os.path.join(model_path,'vctk_dparn_1a.ckpt'),
                         'ns_dparn_v0':os.path.join(model_path,'vctk_dparn_1b.ckpt'),
                         'ns_dpcrn_v0_causal':os.path.join(model_path,'vctk_dpcrn_1a.ckpt'),
                         'ns_dpcrn_v0':os.path.join(model_path,'vctk_dpcrn_1b.ckpt'),
                         }

    model_select = 'ns_dparn_v0'
    ckpt_path = config_model_dict[model_select]
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model = init_model(model_select, verbose=False)
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    noisy_wav, wav_sr = AudioIO.open(f_path=noisy_wav_path)
    enh_wav = model.inference(noisy_wav)
    enh_wav = enh_wav.detach().cpu()
    if enh_wav.dim() == 3:
        enh_wav = enh_wav.squeeze(0)
    out_name = model_select + '_enh_' + noisy_wav_path    
    AudioIO.save(
        enh_wav,
        out_name,
        wav_sr,
    )

if __name__ == '__main__':
    test_ns()
    print('Hello world')
mcw519 commented 1 year ago

Hi,

I checked your code and found you missed something, so your model's parameters were not loaded correctly. You should modify your code like this:

........
model = init_model(model_select, verbose=False)
model.load_state_dict(ckpt["state_dict"], strict=False)
model.eval()
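
As a sanity check (a small sketch using standard torch.nn.Module behaviour, nothing specific to this repo): load_state_dict returns the keys it could not match, so when the whole ckpt dict is passed instead of ckpt["state_dict"], essentially every parameter shows up as missing and the model keeps its random initialization.

result = model.load_state_dict(ckpt["state_dict"], strict=False)
# Both lists should be (nearly) empty when the checkpoint matches the model.
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)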

Thanks.

zuowanbushiwo commented 1 year ago

Hi Wu, thank you very much, the problem has been solved. Sorry for making such a low-level mistake and wasting some of your precious time.