Closed zuowanbushiwo closed 1 year ago
Hi zuowanbushiwo,
It looks your signal already saturated. can you provide your inference configuration and codes? I get the results here eval_audio.zip. That's totally different.
Thanks.
Hi Wu Thanks,I don't have build and install puresound, I created a new test.py in the root directory of the project, the content is as follows, the init_model function is copied from egs/ns/model.py. I have carefully looked at the eval process of egs/ns/main.py. No files in the project have been modified.
import torch
import torch.nn as nn
from puresound.streaming.skim_inference import StreamingSkiM
from puresound.nnet.lobe.trivial import FiLM
from puresound.nnet.skim import MemLSTM, SegLSTM
from typing import Optional, Tuple
from puresound.nnet.dparn import DPARN
from puresound.nnet.dpcrn import DPCRN
from puresound.nnet.base_nn import SoTaskWrapModule
from puresound.nnet.lobe.encoder import ConvEncDec
import soundfile
import os
from puresound.src.audio import AudioIO
# Models
def init_model(name: str, sig_loss: Optional[nn.Module] = None, **kwargs):
if name == "ns_dpcrn_v0_causal":
"""
Total params: 1,380,043
Lookahead(samples): 384
Receptive Fields(samples): infinite
"""
model = SoTaskWrapModule(
encoder=ConvEncDec(
fft_length=512,
win_type="hann",
win_length=512,
hop_length=128,
trainable=True,
output_format="Complex",
),
masker=DPCRN(
input_type="RI",
input_dim=512,
activation_type="PReLU",
norm_type="bN2d",
dropout=0.1,
channels=(1, 32, 32, 32, 64, 128),
transpose_t_size=2,
transpose_delay=False,
skip_conv=False,
kernel_t=(2, 2, 2, 2, 2),
kernel_f=(5, 3, 3, 3, 3),
stride_t=(1, 1, 1, 1, 1),
stride_f=(2, 2, 1, 1, 1),
dilation_t=(1, 1, 1, 1, 1),
dilation_f=(1, 1, 1, 1, 1),
delay=(0, 0, 0, 0, 0),
rnn_hidden=128,
),
speaker_net=None,
loss_func_wav=sig_loss,
loss_func_spk=None,
drop_first_bin=True,
mask_constraint="linear",
f_type="Complex",
mask_type="Complex",
**kwargs
)
elif name == "ns_dpcrn_v0":
"""
Total params: 1,380,043
Lookahead(samples): 1024; (384+128*(6-1)); semi-causal
Receptive Fields(samples): infinite
"""
model = SoTaskWrapModule(
encoder=ConvEncDec(
fft_length=512,
win_type="hann",
win_length=512,
hop_length=128,
trainable=True,
output_format="Complex",
),
masker=DPCRN(
input_type="RI",
input_dim=512,
activation_type="PReLU",
norm_type="bN2d",
dropout=0.1,
channels=(1, 32, 32, 32, 64, 128),
transpose_t_size=2,
transpose_delay=True,
skip_conv=False,
kernel_t=(2, 2, 2, 2, 2),
kernel_f=(5, 3, 3, 3, 3),
stride_t=(1, 1, 1, 1, 1),
stride_f=(2, 2, 1, 1, 1),
dilation_t=(1, 1, 1, 1, 1),
dilation_f=(1, 1, 1, 1, 1),
delay=(0, 0, 0, 0, 0),
rnn_hidden=128,
),
speaker_net=None,
loss_func_wav=sig_loss,
loss_func_spk=None,
drop_first_bin=True,
mask_constraint="linear",
f_type="Complex",
mask_type="Complex",
**kwargs
)
elif name == "ns_dparn_v0_causal":
"""
Total params: 1,215,179
Lookahead(samples): 384
Receptive Fields(samples): infinite
"""
model = SoTaskWrapModule(
encoder=ConvEncDec(
fft_length=512,
win_type="hann",
win_length=512,
hop_length=128,
trainable=True,
output_format="Complex",
),
masker=DPARN(
input_type="RI",
input_dim=512,
activation_type="PReLU",
norm_type="bN2d",
dropout=0.1,
channels=(1, 32, 32, 32, 64, 128),
transpose_t_size=2,
transpose_delay=False,
skip_conv=False,
kernel_t=(2, 2, 2, 2, 2),
kernel_f=(5, 3, 3, 3, 3),
stride_t=(1, 1, 1, 1, 1),
stride_f=(2, 2, 1, 1, 1),
dilation_t=(1, 1, 1, 1, 1),
dilation_f=(1, 1, 1, 1, 1),
delay=(0, 0, 0, 0, 0),
rnn_hidden=128,
nhead=8,
),
speaker_net=None,
loss_func_wav=sig_loss,
loss_func_spk=None,
drop_first_bin=True,
mask_constraint="linear",
f_type="Complex",
mask_type="Complex",
**kwargs
)
elif name == "ns_dparn_v0":
"""
Total params: 1,215,179
Lookahead(samples): 1024; (384+128*(6-1)); semi-causal
Receptive Fields(samples): infinite
"""
model = SoTaskWrapModule(
encoder=ConvEncDec(
fft_length=512,
win_type="hann",
win_length=512,
hop_length=128,
trainable=True,
output_format="Complex",
),
masker=DPARN(
input_type="RI",
input_dim=512,
activation_type="PReLU",
norm_type="bN2d",
dropout=0.1,
channels=(1, 32, 32, 32, 64, 128),
transpose_t_size=2,
transpose_delay=True,
skip_conv=False,
kernel_t=(2, 2, 2, 2, 2),
kernel_f=(5, 3, 3, 3, 3),
stride_t=(1, 1, 1, 1, 1),
stride_f=(2, 2, 1, 1, 1),
dilation_t=(1, 1, 1, 1, 1),
dilation_f=(1, 1, 1, 1, 1),
delay=(0, 0, 0, 0, 0),
rnn_hidden=128,
nhead=8,
),
speaker_net=None,
loss_func_wav=sig_loss,
loss_func_spk=None,
drop_first_bin=True,
mask_constraint="linear",
f_type="Complex",
mask_type="Complex",
**kwargs
)
else:
raise NameError
return model
def test_ns():
noisy_wav_path = 'test_ns.wav'
current_directory = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_directory,'egs','ns','pretrained')
config_model_dict = {'ns_dparn_v0_causal':os.path.join(model_path,'vctk_dparn_1a.ckpt'),
'ns_dparn_v0':os.path.join(model_path,'vctk_dparn_1b.ckpt'),
'ns_dpcrn_v0_causal':os.path.join(model_path,'vctk_dpcrn_1a.ckpt'),
'ns_dpcrn_v0':os.path.join(model_path,'vctk_dpcrn_1b.ckpt'),
}
model_select = 'ns_dparn_v0'
ckpt_path = config_model_dict[model_select]
ckpt = torch.load(ckpt_path, map_location="cpu")
model = init_model(model_select, verbose=False)
model.load_state_dict(ckpt, strict=False)
model.eval()
noisy_wav, wav_sr = AudioIO.open(f_path=noisy_wav_path)
enh_wav = model.inference(noisy_wav)
enh_wav = enh_wav.detach().cpu()
if enh_wav.dim() == 3:
enh_wav = enh_wav.squeeze(0)
out_name = model_select + '_enh_' + noisy_wav_path
AudioIO.save(
enh_wav,
out_name,
wav_sr,
)
if __name__ == '__main__':
test_ns()
print('Hello world')
Hi,
I checked your codes and found you missed something so your model's parameters not correctly loaded. You should modify your codes like:
........
model = init_model(model_select, verbose=False)
model.load_state_dict(ckpt["state_dict"], strict=False)
model.eval()
Thanks.
Hi Wu Thank you very much, the problem has been solved. Also sorry for making such a low level mistake and wasting some of your precious time.
inference results:
ns_data.zip