Le-Xiaohuai-speech / DPCRN_DNS3

Implementation of paper "DPCRN: Dual-Path Convolution Recurrent Network for Single Channel Speech Enhancement"
188 stars 41 forks source link

If i need a file of tflite format,how to convert the stft and istft use conv1d? #30

Open panhu opened 2 years ago

panhu commented 2 years ago

Hi,i find a method https://github.com/huyanxin/phasen/blob/master/model/conv_stft.py use conv1d and conv1d_transpose instead stft and istft,but it is pytorch.When i replace tensorflow with pytorch,the result is error.Can i know you code about conv1d and conv1d_transpose instead stft and istft? Because later I want to compress it and move it to the chip。 Thank you vary much!

Le-Xiaohuai-speech commented 2 years ago

initialize the weights of convolutional layers by the basis function of the FFT.

panhu commented 2 years ago

Thanks,This is modified code:

import os import tensorflow as tf import tensorflow.keras as keras from tensorflow.keras.models import Model from tensorflow.keras.layers import Lambda, Input,Conv1D, Conv2D, BatchNormalization, Conv2DTranspose, Concatenate, LayerNormalization, PReLU from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, EarlyStopping, ModelCheckpoint

from tensorflow.keras.layers import Conv1DTranspose

import soundfile as sf import librosa from random import seed import numpy as np import tqdm from scipy.signal import get_window

from modules import DprnnBlock from utils import reshape, transpose, ParallelModelCheckpoints from data_loader import *

seed(42)

def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): if win_type == 'None' or win_type is None: window = np.ones(win_len) else: window = get_window(win_type, win_len, fftbins=True)**0.5

N = fft_len
fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
real_kernel = np.real(fourier_basis)
imag_kernel = np.imag(fourier_basis)
kernel = np.concatenate([real_kernel, imag_kernel], 1).T

if invers :
    kernel = np.linalg.pinv(kernel).T 

kernel = kernel*window
kernel = kernel[:, None, :]
return tf.convert_to_tensor(kernel,dtype=tf.float32),tf.convert_to_tensor(window[None,:,None],dtype=tf.float32)

kernel = init_kernels(400, 100, 512, win_type='hanning', invers=False)

class ConvSTFT(tf.keras.layers.Layer): def init(self,win_len,win_inc,fft_len = None,win_type='hamming',feature_type='real',fix=True): super(ConvSTFT,self).init()

       if fft_len == None:
          self.fft_len = np.int(2**np.ceil(np.log2(win_len)))
       else:
          self.fft_len = fft_len

       kernel,_ = init_kernels(win_len,win_inc,self.fft_len,win_type)
       self.weight = tf.reshape(kernel,[400,1,402])
       self.feature_type = feature_type
       self.stride = win_inc
       self.win_len = win_len
       self.dim = self.fft_len

   def call(self,inputs):

       outputs = tf.nn.conv1d(inputs,self.weight,stride= self.stride,padding='VALID')

       if self.feature_type == 'complex':
          return outputs
       else:
          #outputs = tf.reshape(outputs,[1,1,-1])
          dim = self.dim//2 + 1
          real = outputs[:,:,:dim]
          imag = outputs[:,:,dim:] 
       return real,imag

class ConviSTFT(tf.keras.layers.Layer): def init(self,win_len,win_inc,fft_len=None,win_type='hamming',feature_type='real',fix=True): super(ConviSTFT,self).init() if fft_len == None: self.fft_len = np.int(2**np.ceil(np.log2(win_len))) else: self.fft_len = fft_len

      kernel,window = init_kernels(win_len,win_inc,self.fft_len,win_type,invers= True)
      self.weight = tf.Variable(kernel,trainable=False)
      self.weight = tf.reshape(self.weight,[400,1,402])
      self.feature_type = feature_type
      self.win_type = win_type
      self.win_len = win_len
      self.win_inc = win_inc
      self.stride = win_inc
      self.dim = self.fft_len

  def call(self,inputs):

      outputs = tf.nn.conv1d_transpose(inputs,filters=self.weight,output_shape=([8,1599,400]),strides=1,padding='VALID')

      #outputs = tf.reshape(outputs,[8,1,-1])

      return outputs
      #t = tf.tile(self.window,[1,1,25597])**2
      #t = to_float(t)
      #t = tf.reshape(t,[1,25597,400])

      #self.enframe = tf.reshape(self.enframe,[400,1,400])

class MK_M(tf.keras.layers.Layer): def init(self,kwargs): super(MK_M,self).init(kwargs)

  def call(self,inputs):
        [noisy_real,noisy_imag,mask] = inputs
        noisy_real = noisy_real[:,:,:,0]
        noisy_imag = noisy_imag[:,:,:,0]            

        mask_real = mask[:,:,:,0]
        mask_imag = mask[:,:,:,1]

        enh_real = noisy_real*mask_real - noisy_imag*mask_imag
        enh_imag = noisy_real * mask_imag + noisy_imag*mask_real

        return [enh_real,enh_imag]

class Overlap_addLayer(tf.keras.layers.Layer): def init(self,kwargs): super(Overlap_addLayer,self).init(kwargs)

  def call(self,inputs):
     return tf.signal.overlap_and_add(inputs,200)

class DPCRN_model(): ''' Class to create and train the DPCRN model '''

def __init__(self, batch_size = 1,
                   length_in_s = 5,
                   fs = 16000,
                   norm = 'iLN',
                   numUnits = 128,
                   numDP = 2,
                   block_len = 400,
                   block_shift = 200,
                   max_epochs = 200,
                   lr = 1e-3):

    # defining default cost function
    self.cost_function = self.snr_cost
    self.model = None
    # defining default parameters
    self.fs = fs
    self.length_in_s = length_in_s
    self.batch_size = batch_size
    # number of the hidden layer size in the LSTM
    self.numUnits = numUnits
    # number of the DPRNN modules
    self.numDP = numDP
    # frame length and hop length in STFT
    self.block_len = block_len
    self.block_shift = block_shift
    self.lr = lr
    self.max_epochs = max_epochs
    # window for STFT: sine win
    win = np.sin(np.arange(.5,self.block_len-.5+1)/self.block_len*np.pi)
    #print(win)
    self.win = tf.constant(win,dtype = 'float32')

    self.L = (16000*length_in_s-self.block_len)//self.block_shift + 1

    self.multi_gpu = False
    # iLN for instant Layer norm and BN for Batch norm
    self.input_norm = norm

@staticmethod
def snr_cost(s_estimate, s_true):
    '''
    Static Method defining the cost function. 
    The negative signal to noise ratio is calculated here. The loss is 
    always calculated over the last dimension. 
    '''
    # calculating the SNR
    snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \
        (tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True) + 1e-8)
    # using some more lines, because TF has no log10
    num = tf.math.log(snr + 1e-8) 
    denom = tf.math.log(tf.constant(10, dtype=num.dtype))
    loss = -10*(num / (denom))

    return loss

@staticmethod
def sisnr_cost(s_hat, s):
    '''
    Static Method defining the cost function. 
    The negative signal to noise ratio is calculated here. The loss is 
    always calculated over the last dimension. 
    '''
    def norm(x):
        return tf.reduce_sum(x**2, axis=-1, keepdims=True)

    s_target = tf.reduce_sum(
        s_hat * s, axis=-1, keepdims=True) * s / norm(s)
    upp = norm(s_target)
    low = norm(s_hat - s_target)

    return -10 * tf.math.log(upp /low) / tf.math.log(10.0)  

def spectrum_loss(self,y_true):
    '''
    spectrum MSE loss 
    '''
    enh_real = self.enh_real
    enh_imag = self.enh_imag
    enh_mag = tf.sqrt(enh_real**2 + enh_imag**2 + 1e-8)

    true_real,true_imag = self.stftLayer(y_true, mode='real_imag')
    true_mag = tf.sqrt(true_real**2 + true_imag**2 + 1e-8)

    loss_real = tf.reduce_mean((enh_real - true_real)**2,)
    loss_imag = tf.reduce_mean((enh_imag - true_imag)**2,)
    loss_mag = tf.reduce_mean((enh_mag - true_mag)**2,) 

    return loss_real + loss_imag + loss_mag

def spectrum_loss_phasen(self, s_hat,s,gamma = 0.3):

    true_real,true_imag = self.stftLayer(s, mode='real_imag')
    hat_real,hat_imag = self.stftLayer(s_hat, mode='real_imag')

    true_mag = tf.sqrt(true_real**2 + true_imag**2+1e-9)
    hat_mag = tf.sqrt(hat_real**2 + hat_imag**2+1e-9)

    true_real_cprs = (true_real / true_mag )*true_mag**gamma
    true_imag_cprs = (true_imag / true_mag )*true_mag**gamma
    hat_real_cprs = (hat_real / hat_mag )* hat_mag**gamma
    hat_imag_cprs = (hat_imag / hat_mag )* hat_mag**gamma

    loss_mag = tf.reduce_mean((hat_mag**gamma - true_mag**gamma)**2,)         
    loss_real = tf.reduce_mean((hat_real_cprs - true_real_cprs)**2,)
    loss_imag = tf.reduce_mean((hat_imag_cprs - true_imag_cprs)**2,)

    return 0.7 * loss_mag + 0.3 * ( loss_imag + loss_real ) 

def lossWrapper(self):
    '''
    A wrapper function which returns the loss function. This is done to
    to enable additional arguments to the loss function if necessary.
    '''
    def lossFunction(y_true,y_pred):
        # calculating loss and squeezing single dimensions away
        loss = tf.squeeze(self.cost_function(y_pred,y_true)) 
        mag_loss = tf.math.log(self.spectrum_loss(y_true) + 1e-8)
        # calculate mean over batches
        loss = tf.reduce_mean(loss)
        return loss + mag_loss 

    return lossFunction

'''
In the following some helper layers are defined.
'''  
def seg2frame(self, x):
    '''
    split signal x to frames
    '''
    frames = tf.signal.frame(x, self.block_len, self.block_shift)
    if self.win is not None:
        frames = self.win*frames
    return frames

def stftLayer(self, x, mode ='mag_pha'):
    '''
    Method for an STFT helper layer used with a Lambda layer
    mode: 'mag_pha'   return magnitude and phase spectrogram
          'real_imag' return real and imaginary parts
    '''
    # creating frames from the continuous waveform
    frames = tf.signal.frame(x, self.block_len, self.block_shift)

    if self.win is not None:
        frames = self.win*frames
    # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
    #print('win.............',type(frames))
    #frames = frames.numpy()
    #print('win.............',type(frames))
    stft_dat = tf.signal.rfft(frames)
    stft_dat = tf.convert_to_tensor(stft_dat)
    # calculating magnitude and phase from the complex signal
    output_list = []
    if mode == 'mag_pha':
        mag = tf.math.abs(stft_dat)
        phase = tf.math.angle(stft_dat)
        output_list = [mag, phase]
    elif mode == 'real_imag':
        real = tf.math.real(stft_dat)
        imag = tf.math.imag(stft_dat)
        output_list = [real, imag]            
    # returning magnitude and phase as list
    return output_list

def fftLayer(self, x):
    '''
    Method for an fft helper layer used with a Lambda layer.
    The layer calculates the rFFT on the last dimension and returns
    the magnitude and phase of the STFT.
    '''
    # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
    stft_dat = tf.signal.rfft(x)
    # calculating magnitude and phase from the complex signal
    mag = tf.abs(stft_dat)
    phase = tf.math.angle(stft_dat)
    # returning magnitude and phase as list
    return [mag, phase]

def ifftLayer(self, x,mode = 'mag_pha'):
    '''
    Method for an inverse FFT layer used with an Lambda layer. This layer
    calculates time domain frames from magnitude and phase information. 
    As input x a list with [mag,phase] is required.
    '''
    if mode == 'mag_pha':
    # calculating the complex representation
        s1_stft = (tf.cast(x[0], tf.complex64) * 
                    tf.exp( (1j * tf.cast(x[1], tf.complex64))))
    elif mode == 'real_imag':
        s1_stft = tf.cast(x[0], tf.complex64) + 1j * tf.cast(x[1], tf.complex64)
    # returning the time domain frames
    return tf.signal.irfft(s1_stft)  

def overlapAddLayer(self, x):
    '''
    Method for an overlap and add helper layer used with a Lambda layer.
    This layer reconstructs the waveform from a framed signal.
    '''
    # calculating and returning the reconstructed waveform
    '''
    if self.move_dc:
        x = x - tf.expand_dims(tf.reduce_mean(x,axis = -1),2)
    '''
    return tf.signal.overlap_and_add(x, self.block_shift)              

def mk_mask(self,x):
    '''
    Method for complex ratio mask and add helper layer used with a Lambda layer.
    '''
    [noisy_real,noisy_imag,mask] = x
    noisy_real = noisy_real[:,:,:,0]
    noisy_imag = noisy_imag[:,:,:,0]

    mask_real = mask[:,:,:,0]
    mask_imag = mask[:,:,:,1]

    enh_real = noisy_real * mask_real - noisy_imag * mask_imag
    enh_imag = noisy_real * mask_imag + noisy_imag * mask_real

    return [enh_real,enh_imag]

def build_DPCRN_model(self, name = 'model0'):

    # input layer for time signal
    time_dat = Input(batch_shape=(8, 320000))
    # calculate STFT

    time_dat_1 = tf.reshape(time_dat,[8,320000,1])
    real,imag = ConvSTFT(400,200,400,win_type='hanning',feature_type='real')(time_dat_1)
    print(real.shape)

    real = tf.reshape(real,[8,-1,201,1])
    imag = tf.reshape(imag,[8,-1,201,1])

    input_complex_spec = Concatenate(axis = -1)([real,imag])
    '''encoder'''
    #print(input_complex_spec.shape)
    if self.input_norm == 'iLN':    
        input_complex_spec = LayerNormalization(axis = [-1,-2], name = 'input_norm')(input_complex_spec)
    elif self.input_norm == 'BN':    
        input_complex_spec =BatchNormalization(name = 'input_norm')(input_complex_spec)

    # causal padding [1,0],[0,2]
    input_complex_spec = tf.pad(input_complex_spec,[[0,0],[1,0],[0,2],[0,0]])
    conv_1 = Conv2D(32, (2,5),(1,2),name = name+'_conv_1',padding = "VALID")(input_complex_spec)
    bn_1 = BatchNormalization(name = name+'_bn_1')(conv_1)
    out_1 = PReLU(shared_axes=[1,2])(bn_1)
    # causal padding [1,0],[0,1]
    out_1_1 = tf.pad(out_1,[[0,0],[1,0],[0,1],[0,0]])
    conv_2 = Conv2D(32, (2,3),(1,2),name = name+'_conv_2',padding = "VALID")(out_1_1)
    bn_2 = BatchNormalization(name = name+'_bn_2')(conv_2)
    out_2 = PReLU(shared_axes=[1,2])(bn_2)
    # causal padding [1,0],[1,1]
    out_2_1 = tf.pad(out_2,[[0,0],[1,0],[1,1],[0,0]])
    conv_3 = Conv2D(32, (2,3),(1,1),name = name+'_conv_3',padding = "VALID")(out_2_1)
    bn_3 = BatchNormalization(name = name+'_bn_3')(conv_3)
    out_3 = PReLU(shared_axes=[1,2])(bn_3)
    # causal padding [1,0],[1,1]
    out_3_1 = tf.pad(out_3,[[0,0],[1,0],[1,1],[0,0]])
    conv_4 = Conv2D(64, (2,3),(1,1),name = name+'_conv_4',padding = "VALID")(out_3_1)
    bn_4 = BatchNormalization(name = name+'_bn_4')(conv_4)
    out_4 = PReLU(shared_axes=[1,2])(bn_4)
    # causal padding [1,0],[1,1]
    out_4_1 = tf.pad(out_4,[[0,0],[1,0],[1,1],[0,0]])
    conv_5 = Conv2D(128, (2,3),(1,1),name = name+'_conv_5',padding = "VALID")(out_4_1)
    bn_5 = BatchNormalization(name = name +'_bn_5')(conv_5)
    out_5 = PReLU(shared_axes=[1,2])(bn_5)

    dp_in = out_5

    print(dp_in.shape)
    for i in range(self.numDP):

        dp_in = DprnnBlock(numUnits = self.numUnits, batch_size = self.batch_size, L = -1,width = 50,channel = 128, causal=True)(dp_in)#self.DPRNN_kernal(dp_in,str(i),last_dp = 0)

    dp_out = dp_in

    '''decoder'''
    skipcon_1 = Concatenate(axis = -1)([out_5,dp_out])

    deconv_1 = Conv2DTranspose(64,(2,3),(1,1),name = name+'_dconv_1',padding = 'same')(skipcon_1)
    dbn_1 = BatchNormalization(name = name+'_dbn_1')(deconv_1)
    dout_1 = PReLU(shared_axes=[1,2])(dbn_1)

    skipcon_2 = Concatenate(axis = -1)([out_4,dout_1])

    deconv_2 = Conv2DTranspose(32,(2,3),(1,1),name = name+'_dconv_2',padding = 'same')(skipcon_2)
    dbn_2 = BatchNormalization(name = name+'_dbn_2')(deconv_2)
    dout_2 = PReLU(shared_axes=[1,2])(dbn_2)

    skipcon_3 = Concatenate(axis = -1)([out_3,dout_2])

    deconv_3 = Conv2DTranspose(32,(2,3),(1,1),name = name+'_dconv_3',padding = 'same')(skipcon_3)
    dbn_3 = BatchNormalization(name = name+'_dbn_3')(deconv_3)
    dout_3 = PReLU(shared_axes=[1,2])(dbn_3)

    skipcon_4 = Concatenate(axis = -1)([out_2,dout_3])

    deconv_4 = Conv2DTranspose(32,(2,3),(1,2),name = name+'_dconv_4',padding = 'same')(skipcon_4)
    dbn_4 = BatchNormalization(name = name+'_dbn_4')(deconv_4)
    dout_4 = PReLU(shared_axes=[1,2])(dbn_4)

    skipcon_5 = Concatenate(axis = -1)([out_1,dout_4])

    deconv_5 = Conv2DTranspose(2,(2,5),(1,2),name = name+'_dconv_5',padding = 'valid')(skipcon_5)

    '''no activation'''        
    deconv_5 = deconv_5[:,:-1,:-2]

    #output_mask = Activation('tanh')(dbn_5)
    output_mask = deconv_5

    #enh_spec = Lambda(self.mk_mask)([real,imag,output_mask])
    enh_spec = MK_M(name='mask')([real,imag,output_mask])

    self.enh_real, self.enh_imag = enh_spec[0],enh_spec[1]

    #enh_frame = Lambda(self.ifftLayer,arguments = {'mode':'real_imag'})(enh_spec)
    #enh_frame = ifft_Layer(name='ifft_layer')(enh_spec)

    s1_stft = tf.cast(enh_spec[0], tf.complex64) + 1j * tf.cast(enh_spec[1], tf.complex64)
    s1_stft = tf.to_float(s1_stft)
    #enh_frame = tf.nn.conv1d_transpose(enh_spec,filters=kernel,output_shape=(8,1599,400),strides=100,padding="VALID")
    enh_frame = ConviSTFT(400,100,400,win_type='hanning',feature_type='complex')(s1_stft)
    enh_frame = tf.reshape(enh_frame,[8,1599,400])

    enh_frame = enh_frame * self.win

    enh_time = Overlap_addLayer(name='overlayer')(enh_frame)        

    self.model = Model(time_dat,enh_time)
    self.model.summary()

    return self.model

But,i got a error:

ValueError: Depth of output (402) is not a multiple of the number of groups (400) for 'Adam/gradients/convi_stft/conv1d_transpose_grad/Conv2D' (op: 'Conv2D') with input shapes: [8,1,1599,400], [1,400,1,402].

Thanks!

Le-Xiaohuai-speech commented 2 years ago

it looks like the output dimensions of the iSTFT do not match the groups number

panhu commented 2 years ago

Yes,but i changed the size of kernel(filter) and stride is invalid。

panhu commented 2 years ago

Can you help me verify the code of tf.nn.conv1d_transpose(ConviSTFT)? Thanks!

Le-Xiaohuai-speech commented 2 years ago

Ok,I'll get back to you later. Please post the code of ConviSTFT agian? you can send the .py file to xiaohuaile@smail.nju.edu.cn

panhu commented 2 years ago

OK

panhu commented 2 years ago

Hi: When i "load_model" got a new error,this is my code:

modelparh = r"dpcrn_4.h5" model = tf.keras.models.load_model(modelparh,custom_objects={"DprnnBlock":DprnnBlock,"ConvSTFT":ConvSTFT,"MK_M":MK_M,"ConviSTFT":ConviSTFT, "Overlap_addLayer":Overlap_addLayer})

The error is :

ValueError: Unknown loss function:lossFunction

panhu commented 2 years ago

When i use: model = tf.keras.models.load_model(modelparh,custom_objects={"DprnnBlock":DprnnBlock,"ConvSTFT":ConvSTFT,"MK_M":MK_M,"ConviSTFT":ConviSTFT, "Overlap_addLayer":Overlap_addLayer,"lossFunction":DPCRN_model.lossWrapper})

The error is:

TypeError: lossWrapper() takes 1 positional argument but 2 were given