Open panhu opened 2 years ago
initialize the weights of convolutional layers by the basis function of the FFT.
Thanks,This is modified code:
import os import tensorflow as tf import tensorflow.keras as keras from tensorflow.keras.models import Model from tensorflow.keras.layers import Lambda, Input,Conv1D, Conv2D, BatchNormalization, Conv2DTranspose, Concatenate, LayerNormalization, PReLU from tensorflow.keras.callbacks import ReduceLROnPlateau, CSVLogger, EarlyStopping, ModelCheckpoint
import soundfile as sf import librosa from random import seed import numpy as np import tqdm from scipy.signal import get_window
from modules import DprnnBlock from utils import reshape, transpose, ParallelModelCheckpoints from data_loader import *
seed(42)
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): if win_type == 'None' or win_type is None: window = np.ones(win_len) else: window = get_window(win_type, win_len, fftbins=True)**0.5
N = fft_len
fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
real_kernel = np.real(fourier_basis)
imag_kernel = np.imag(fourier_basis)
kernel = np.concatenate([real_kernel, imag_kernel], 1).T
if invers :
kernel = np.linalg.pinv(kernel).T
kernel = kernel*window
kernel = kernel[:, None, :]
return tf.convert_to_tensor(kernel,dtype=tf.float32),tf.convert_to_tensor(window[None,:,None],dtype=tf.float32)
class ConvSTFT(tf.keras.layers.Layer): def init(self,win_len,win_inc,fft_len = None,win_type='hamming',feature_type='real',fix=True): super(ConvSTFT,self).init()
if fft_len == None:
self.fft_len = np.int(2**np.ceil(np.log2(win_len)))
else:
self.fft_len = fft_len
kernel,_ = init_kernels(win_len,win_inc,self.fft_len,win_type)
self.weight = tf.reshape(kernel,[400,1,402])
self.feature_type = feature_type
self.stride = win_inc
self.win_len = win_len
self.dim = self.fft_len
def call(self,inputs):
outputs = tf.nn.conv1d(inputs,self.weight,stride= self.stride,padding='VALID')
if self.feature_type == 'complex':
return outputs
else:
#outputs = tf.reshape(outputs,[1,1,-1])
dim = self.dim//2 + 1
real = outputs[:,:,:dim]
imag = outputs[:,:,dim:]
return real,imag
class ConviSTFT(tf.keras.layers.Layer): def init(self,win_len,win_inc,fft_len=None,win_type='hamming',feature_type='real',fix=True): super(ConviSTFT,self).init() if fft_len == None: self.fft_len = np.int(2**np.ceil(np.log2(win_len))) else: self.fft_len = fft_len
kernel,window = init_kernels(win_len,win_inc,self.fft_len,win_type,invers= True)
self.weight = tf.Variable(kernel,trainable=False)
self.weight = tf.reshape(self.weight,[400,1,402])
self.feature_type = feature_type
self.win_type = win_type
self.win_len = win_len
self.win_inc = win_inc
self.stride = win_inc
self.dim = self.fft_len
def call(self,inputs):
outputs = tf.nn.conv1d_transpose(inputs,filters=self.weight,output_shape=([8,1599,400]),strides=1,padding='VALID')
#outputs = tf.reshape(outputs,[8,1,-1])
return outputs
#t = tf.tile(self.window,[1,1,25597])**2
#t = to_float(t)
#t = tf.reshape(t,[1,25597,400])
#self.enframe = tf.reshape(self.enframe,[400,1,400])
class MK_M(tf.keras.layers.Layer): def init(self,kwargs): super(MK_M,self).init(kwargs)
def call(self,inputs):
[noisy_real,noisy_imag,mask] = inputs
noisy_real = noisy_real[:,:,:,0]
noisy_imag = noisy_imag[:,:,:,0]
mask_real = mask[:,:,:,0]
mask_imag = mask[:,:,:,1]
enh_real = noisy_real*mask_real - noisy_imag*mask_imag
enh_imag = noisy_real * mask_imag + noisy_imag*mask_real
return [enh_real,enh_imag]
class Overlap_addLayer(tf.keras.layers.Layer): def init(self,kwargs): super(Overlap_addLayer,self).init(kwargs)
def call(self,inputs):
return tf.signal.overlap_and_add(inputs,200)
class DPCRN_model(): ''' Class to create and train the DPCRN model '''
def __init__(self, batch_size = 1,
length_in_s = 5,
fs = 16000,
norm = 'iLN',
numUnits = 128,
numDP = 2,
block_len = 400,
block_shift = 200,
max_epochs = 200,
lr = 1e-3):
# defining default cost function
self.cost_function = self.snr_cost
self.model = None
# defining default parameters
self.fs = fs
self.length_in_s = length_in_s
self.batch_size = batch_size
# number of the hidden layer size in the LSTM
self.numUnits = numUnits
# number of the DPRNN modules
self.numDP = numDP
# frame length and hop length in STFT
self.block_len = block_len
self.block_shift = block_shift
self.lr = lr
self.max_epochs = max_epochs
# window for STFT: sine win
win = np.sin(np.arange(.5,self.block_len-.5+1)/self.block_len*np.pi)
#print(win)
self.win = tf.constant(win,dtype = 'float32')
self.L = (16000*length_in_s-self.block_len)//self.block_shift + 1
self.multi_gpu = False
# iLN for instant Layer norm and BN for Batch norm
self.input_norm = norm
@staticmethod
def snr_cost(s_estimate, s_true):
'''
Static Method defining the cost function.
The negative signal to noise ratio is calculated here. The loss is
always calculated over the last dimension.
'''
# calculating the SNR
snr = tf.reduce_mean(tf.math.square(s_true), axis=-1, keepdims=True) / \
(tf.reduce_mean(tf.math.square(s_true-s_estimate), axis=-1, keepdims=True) + 1e-8)
# using some more lines, because TF has no log10
num = tf.math.log(snr + 1e-8)
denom = tf.math.log(tf.constant(10, dtype=num.dtype))
loss = -10*(num / (denom))
return loss
@staticmethod
def sisnr_cost(s_hat, s):
'''
Static Method defining the cost function.
The negative signal to noise ratio is calculated here. The loss is
always calculated over the last dimension.
'''
def norm(x):
return tf.reduce_sum(x**2, axis=-1, keepdims=True)
s_target = tf.reduce_sum(
s_hat * s, axis=-1, keepdims=True) * s / norm(s)
upp = norm(s_target)
low = norm(s_hat - s_target)
return -10 * tf.math.log(upp /low) / tf.math.log(10.0)
def spectrum_loss(self,y_true):
'''
spectrum MSE loss
'''
enh_real = self.enh_real
enh_imag = self.enh_imag
enh_mag = tf.sqrt(enh_real**2 + enh_imag**2 + 1e-8)
true_real,true_imag = self.stftLayer(y_true, mode='real_imag')
true_mag = tf.sqrt(true_real**2 + true_imag**2 + 1e-8)
loss_real = tf.reduce_mean((enh_real - true_real)**2,)
loss_imag = tf.reduce_mean((enh_imag - true_imag)**2,)
loss_mag = tf.reduce_mean((enh_mag - true_mag)**2,)
return loss_real + loss_imag + loss_mag
def spectrum_loss_phasen(self, s_hat,s,gamma = 0.3):
true_real,true_imag = self.stftLayer(s, mode='real_imag')
hat_real,hat_imag = self.stftLayer(s_hat, mode='real_imag')
true_mag = tf.sqrt(true_real**2 + true_imag**2+1e-9)
hat_mag = tf.sqrt(hat_real**2 + hat_imag**2+1e-9)
true_real_cprs = (true_real / true_mag )*true_mag**gamma
true_imag_cprs = (true_imag / true_mag )*true_mag**gamma
hat_real_cprs = (hat_real / hat_mag )* hat_mag**gamma
hat_imag_cprs = (hat_imag / hat_mag )* hat_mag**gamma
loss_mag = tf.reduce_mean((hat_mag**gamma - true_mag**gamma)**2,)
loss_real = tf.reduce_mean((hat_real_cprs - true_real_cprs)**2,)
loss_imag = tf.reduce_mean((hat_imag_cprs - true_imag_cprs)**2,)
return 0.7 * loss_mag + 0.3 * ( loss_imag + loss_real )
def lossWrapper(self):
'''
A wrapper function which returns the loss function. This is done to
to enable additional arguments to the loss function if necessary.
'''
def lossFunction(y_true,y_pred):
# calculating loss and squeezing single dimensions away
loss = tf.squeeze(self.cost_function(y_pred,y_true))
mag_loss = tf.math.log(self.spectrum_loss(y_true) + 1e-8)
# calculate mean over batches
loss = tf.reduce_mean(loss)
return loss + mag_loss
return lossFunction
'''
In the following some helper layers are defined.
'''
def seg2frame(self, x):
'''
split signal x to frames
'''
frames = tf.signal.frame(x, self.block_len, self.block_shift)
if self.win is not None:
frames = self.win*frames
return frames
def stftLayer(self, x, mode ='mag_pha'):
'''
Method for an STFT helper layer used with a Lambda layer
mode: 'mag_pha' return magnitude and phase spectrogram
'real_imag' return real and imaginary parts
'''
# creating frames from the continuous waveform
frames = tf.signal.frame(x, self.block_len, self.block_shift)
if self.win is not None:
frames = self.win*frames
# calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
#print('win.............',type(frames))
#frames = frames.numpy()
#print('win.............',type(frames))
stft_dat = tf.signal.rfft(frames)
stft_dat = tf.convert_to_tensor(stft_dat)
# calculating magnitude and phase from the complex signal
output_list = []
if mode == 'mag_pha':
mag = tf.math.abs(stft_dat)
phase = tf.math.angle(stft_dat)
output_list = [mag, phase]
elif mode == 'real_imag':
real = tf.math.real(stft_dat)
imag = tf.math.imag(stft_dat)
output_list = [real, imag]
# returning magnitude and phase as list
return output_list
def fftLayer(self, x):
'''
Method for an fft helper layer used with a Lambda layer.
The layer calculates the rFFT on the last dimension and returns
the magnitude and phase of the STFT.
'''
# calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
stft_dat = tf.signal.rfft(x)
# calculating magnitude and phase from the complex signal
mag = tf.abs(stft_dat)
phase = tf.math.angle(stft_dat)
# returning magnitude and phase as list
return [mag, phase]
def ifftLayer(self, x,mode = 'mag_pha'):
'''
Method for an inverse FFT layer used with an Lambda layer. This layer
calculates time domain frames from magnitude and phase information.
As input x a list with [mag,phase] is required.
'''
if mode == 'mag_pha':
# calculating the complex representation
s1_stft = (tf.cast(x[0], tf.complex64) *
tf.exp( (1j * tf.cast(x[1], tf.complex64))))
elif mode == 'real_imag':
s1_stft = tf.cast(x[0], tf.complex64) + 1j * tf.cast(x[1], tf.complex64)
# returning the time domain frames
return tf.signal.irfft(s1_stft)
def overlapAddLayer(self, x):
'''
Method for an overlap and add helper layer used with a Lambda layer.
This layer reconstructs the waveform from a framed signal.
'''
# calculating and returning the reconstructed waveform
'''
if self.move_dc:
x = x - tf.expand_dims(tf.reduce_mean(x,axis = -1),2)
'''
return tf.signal.overlap_and_add(x, self.block_shift)
def mk_mask(self,x):
'''
Method for complex ratio mask and add helper layer used with a Lambda layer.
'''
[noisy_real,noisy_imag,mask] = x
noisy_real = noisy_real[:,:,:,0]
noisy_imag = noisy_imag[:,:,:,0]
mask_real = mask[:,:,:,0]
mask_imag = mask[:,:,:,1]
enh_real = noisy_real * mask_real - noisy_imag * mask_imag
enh_imag = noisy_real * mask_imag + noisy_imag * mask_real
return [enh_real,enh_imag]
def build_DPCRN_model(self, name = 'model0'):
# input layer for time signal
time_dat = Input(batch_shape=(8, 320000))
# calculate STFT
time_dat_1 = tf.reshape(time_dat,[8,320000,1])
real,imag = ConvSTFT(400,200,400,win_type='hanning',feature_type='real')(time_dat_1)
print(real.shape)
real = tf.reshape(real,[8,-1,201,1])
imag = tf.reshape(imag,[8,-1,201,1])
input_complex_spec = Concatenate(axis = -1)([real,imag])
'''encoder'''
#print(input_complex_spec.shape)
if self.input_norm == 'iLN':
input_complex_spec = LayerNormalization(axis = [-1,-2], name = 'input_norm')(input_complex_spec)
elif self.input_norm == 'BN':
input_complex_spec =BatchNormalization(name = 'input_norm')(input_complex_spec)
# causal padding [1,0],[0,2]
input_complex_spec = tf.pad(input_complex_spec,[[0,0],[1,0],[0,2],[0,0]])
conv_1 = Conv2D(32, (2,5),(1,2),name = name+'_conv_1',padding = "VALID")(input_complex_spec)
bn_1 = BatchNormalization(name = name+'_bn_1')(conv_1)
out_1 = PReLU(shared_axes=[1,2])(bn_1)
# causal padding [1,0],[0,1]
out_1_1 = tf.pad(out_1,[[0,0],[1,0],[0,1],[0,0]])
conv_2 = Conv2D(32, (2,3),(1,2),name = name+'_conv_2',padding = "VALID")(out_1_1)
bn_2 = BatchNormalization(name = name+'_bn_2')(conv_2)
out_2 = PReLU(shared_axes=[1,2])(bn_2)
# causal padding [1,0],[1,1]
out_2_1 = tf.pad(out_2,[[0,0],[1,0],[1,1],[0,0]])
conv_3 = Conv2D(32, (2,3),(1,1),name = name+'_conv_3',padding = "VALID")(out_2_1)
bn_3 = BatchNormalization(name = name+'_bn_3')(conv_3)
out_3 = PReLU(shared_axes=[1,2])(bn_3)
# causal padding [1,0],[1,1]
out_3_1 = tf.pad(out_3,[[0,0],[1,0],[1,1],[0,0]])
conv_4 = Conv2D(64, (2,3),(1,1),name = name+'_conv_4',padding = "VALID")(out_3_1)
bn_4 = BatchNormalization(name = name+'_bn_4')(conv_4)
out_4 = PReLU(shared_axes=[1,2])(bn_4)
# causal padding [1,0],[1,1]
out_4_1 = tf.pad(out_4,[[0,0],[1,0],[1,1],[0,0]])
conv_5 = Conv2D(128, (2,3),(1,1),name = name+'_conv_5',padding = "VALID")(out_4_1)
bn_5 = BatchNormalization(name = name +'_bn_5')(conv_5)
out_5 = PReLU(shared_axes=[1,2])(bn_5)
dp_in = out_5
print(dp_in.shape)
for i in range(self.numDP):
dp_in = DprnnBlock(numUnits = self.numUnits, batch_size = self.batch_size, L = -1,width = 50,channel = 128, causal=True)(dp_in)#self.DPRNN_kernal(dp_in,str(i),last_dp = 0)
dp_out = dp_in
'''decoder'''
skipcon_1 = Concatenate(axis = -1)([out_5,dp_out])
deconv_1 = Conv2DTranspose(64,(2,3),(1,1),name = name+'_dconv_1',padding = 'same')(skipcon_1)
dbn_1 = BatchNormalization(name = name+'_dbn_1')(deconv_1)
dout_1 = PReLU(shared_axes=[1,2])(dbn_1)
skipcon_2 = Concatenate(axis = -1)([out_4,dout_1])
deconv_2 = Conv2DTranspose(32,(2,3),(1,1),name = name+'_dconv_2',padding = 'same')(skipcon_2)
dbn_2 = BatchNormalization(name = name+'_dbn_2')(deconv_2)
dout_2 = PReLU(shared_axes=[1,2])(dbn_2)
skipcon_3 = Concatenate(axis = -1)([out_3,dout_2])
deconv_3 = Conv2DTranspose(32,(2,3),(1,1),name = name+'_dconv_3',padding = 'same')(skipcon_3)
dbn_3 = BatchNormalization(name = name+'_dbn_3')(deconv_3)
dout_3 = PReLU(shared_axes=[1,2])(dbn_3)
skipcon_4 = Concatenate(axis = -1)([out_2,dout_3])
deconv_4 = Conv2DTranspose(32,(2,3),(1,2),name = name+'_dconv_4',padding = 'same')(skipcon_4)
dbn_4 = BatchNormalization(name = name+'_dbn_4')(deconv_4)
dout_4 = PReLU(shared_axes=[1,2])(dbn_4)
skipcon_5 = Concatenate(axis = -1)([out_1,dout_4])
deconv_5 = Conv2DTranspose(2,(2,5),(1,2),name = name+'_dconv_5',padding = 'valid')(skipcon_5)
'''no activation'''
deconv_5 = deconv_5[:,:-1,:-2]
#output_mask = Activation('tanh')(dbn_5)
output_mask = deconv_5
#enh_spec = Lambda(self.mk_mask)([real,imag,output_mask])
enh_spec = MK_M(name='mask')([real,imag,output_mask])
self.enh_real, self.enh_imag = enh_spec[0],enh_spec[1]
#enh_frame = Lambda(self.ifftLayer,arguments = {'mode':'real_imag'})(enh_spec)
#enh_frame = ifft_Layer(name='ifft_layer')(enh_spec)
s1_stft = tf.cast(enh_spec[0], tf.complex64) + 1j * tf.cast(enh_spec[1], tf.complex64)
s1_stft = tf.to_float(s1_stft)
#enh_frame = tf.nn.conv1d_transpose(enh_spec,filters=kernel,output_shape=(8,1599,400),strides=100,padding="VALID")
enh_frame = ConviSTFT(400,100,400,win_type='hanning',feature_type='complex')(s1_stft)
enh_frame = tf.reshape(enh_frame,[8,1599,400])
enh_frame = enh_frame * self.win
enh_time = Overlap_addLayer(name='overlayer')(enh_frame)
self.model = Model(time_dat,enh_time)
self.model.summary()
return self.model
But,i got a error:
ValueError: Depth of output (402) is not a multiple of the number of groups (400) for 'Adam/gradients/convi_stft/conv1d_transpose_grad/Conv2D' (op: 'Conv2D') with input shapes: [8,1,1599,400], [1,400,1,402].
Thanks!
it looks like the output dimensions of the iSTFT do not match the groups number
Yes,but i changed the size of kernel(filter) and stride is invalid。
Can you help me verify the code of tf.nn.conv1d_transpose(ConviSTFT)? Thanks!
Ok,I'll get back to you later. Please post the code of ConviSTFT agian? you can send the .py file to xiaohuaile@smail.nju.edu.cn
OK
Hi: When i "load_model" got a new error,this is my code:
modelparh = r"dpcrn_4.h5" model = tf.keras.models.load_model(modelparh,custom_objects={"DprnnBlock":DprnnBlock,"ConvSTFT":ConvSTFT,"MK_M":MK_M,"ConviSTFT":ConviSTFT, "Overlap_addLayer":Overlap_addLayer})
The error is :
ValueError: Unknown loss function:lossFunction
When i use: model = tf.keras.models.load_model(modelparh,custom_objects={"DprnnBlock":DprnnBlock,"ConvSTFT":ConvSTFT,"MK_M":MK_M,"ConviSTFT":ConviSTFT, "Overlap_addLayer":Overlap_addLayer,"lossFunction":DPCRN_model.lossWrapper})
The error is:
TypeError: lossWrapper() takes 1 positional argument but 2 were given
Hi,i find a method https://github.com/huyanxin/phasen/blob/master/model/conv_stft.py use conv1d and conv1d_transpose instead stft and istft,but it is pytorch.When i replace tensorflow with pytorch,the result is error.Can i know you code about conv1d and conv1d_transpose instead stft and istft? Because later I want to compress it and move it to the chip。 Thank you vary much!