wangtianrui / DCCRN

PyTorch implementation of "DCCRN-Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement"

Inference code #5

Closed agunapal closed 3 years ago

agunapal commented 3 years ago

Hello, thank you for sharing your code. Can you please share the inference script as well?

wangtianrui commented 3 years ago
import os

import soundfile as sf
import torch


def audiowrite(destpath, audio, sample_rate):
    '''Write audio to destpath, creating the parent directory if needed.'''
    destpath = os.path.abspath(destpath)
    destdir = os.path.dirname(destpath)
    # exist_ok avoids a separate existence check before creating the directory
    os.makedirs(destdir, exist_ok=True)
    sf.write(destpath, audio, sample_rate)

def predict_torchmodel(model, noisy_path, save_path):
    assert os.path.exists(noisy_path), "noisy path error:" + noisy_path
    noisy_wave, frq = sf.read(noisy_path)
    assert frq == 16000, "sample rate must equal 16000"
    with torch.no_grad():
        net_inp = torch.tensor(noisy_wave)[None].to(torch.float32)
        estimate = model.istft(model(net_inp)).squeeze(1).cpu().data.numpy().flatten()
        audiowrite(save_path, estimate, frq)
agunapal commented 3 years ago

Thanks. I get this error:

RuntimeError: Expected 3-dimensional input for 3-dimensional weight [514, 1, 400], but got 2-dimensional input of size [1, 4046800] instead

line 93, in forward
    outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
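The message follows from how a 1-D transposed convolution reads shapes: a weight of shape `[514, 1, 400]` means 514 input channels, 1 output channel, and a kernel of 400 samples, so the layer wants a 3-D input `(batch, 514, frames)`, while `[1, 4046800]` has the shape of a waveform `(batch, samples)`. A minimal sketch of that rule in plain Python, without PyTorch: the helper name `conv_transpose1d_output_shape` and the stride of 100 are assumptions for illustration, since the thread does not show the value of `self.stride`. Recent PyTorch releases may also accept an unbatched 2-D `(channels, frames)` input, which this sketch does not model.

```python
def conv_transpose1d_output_shape(input_shape, weight_shape, stride):
    """Mirror the shape rule of a 1-D transposed convolution.

    Assumes no padding, no output padding, dilation 1, and groups 1.
    weight_shape is (in_channels, out_channels, kernel_size).
    """
    in_channels, out_channels, kernel_size = weight_shape
    if len(input_shape) != 3:
        raise ValueError(
            f"expected 3-dimensional input (batch, {in_channels}, frames), "
            f"got {len(input_shape)}-dimensional input of size {list(input_shape)}"
        )
    batch, channels, frames = input_shape
    if channels != in_channels:
        raise ValueError(f"expected {in_channels} channels, got {channels}")
    # Each frame advances the output by `stride` samples; the last frame
    # contributes a full kernel.
    out_len = (frames - 1) * stride + kernel_size
    return (batch, out_channels, out_len)


weight_shape = (514, 1, 400)  # shape reported in the error message
stride = 100                  # assumed hop size; not stated in the thread

# A spectrogram-like input with 514 channels is accepted.
print(conv_transpose1d_output_shape((1, 514, 10), weight_shape, stride))

# A waveform of shape (batch, samples) is rejected, as in the report above.
try:
    conv_transpose1d_output_shape((1, 4046800), weight_shape, stride)
except ValueError as err:
    print(err)
```

Under these assumptions, the 2-D input in the report suggests that the tensor passed to `model.istft` was already a time-domain signal.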

wangtianrui commented 3 years ago

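Oh, sorry! The model's forward pass already returns the waveform, so the extra model.istft call has to go. Change it to this: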

def predict_torchmodel(model, noisy_path, save_path):
    assert os.path.exists(noisy_path), "noisy path error:" + noisy_path
    noisy_wave, frq = sf.read(noisy_path)
    assert frq == 16000, "sample rate must equal 16000"
    with torch.no_grad():
        # add a batch dimension; for a mono file: (samples,) -> (1, samples)
        net_inp = torch.tensor(noisy_wave)[None].to(torch.float32)
        estimate = model(net_inp).squeeze(1).cpu().numpy().flatten()
        audiowrite(save_path, estimate, frq)
agunapal commented 3 years ago

Thank you. That worked.
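To enhance a whole folder rather than a single file, the working `predict_torchmodel` can be wrapped in a small loop. This is a minimal sketch using only the standard library, not part of the repository: `enhance_directory`, its `predict_fn` parameter, and the `*.wav` filter are hypothetical names introduced here for illustration.

```python
from pathlib import Path


def enhance_directory(noisy_dir, save_dir, predict_fn, pattern="*.wav"):
    """Call predict_fn(noisy_path, save_path) for every matching file.

    predict_fn is a placeholder for the function from this thread, e.g.
    lambda src, dst: predict_torchmodel(model, src, dst).
    Returns the list of output paths in sorted order.
    """
    noisy_dir, save_dir = Path(noisy_dir), Path(save_dir)
    written = []
    # sorted() collects all matches first, so files written during the loop
    # are not picked up again.
    for noisy_path in sorted(noisy_dir.rglob(pattern)):
        # Keep the sub-folder layout of the input directory in the output.
        save_path = save_dir / noisy_path.relative_to(noisy_dir)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        predict_fn(str(noisy_path), str(save_path))
        written.append(save_path)
    return written
```

Passing the prediction step as a callable keeps the loop independent of how the model is loaded, which the thread does not cover.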