breizhn / DTLN

TensorFlow 2.x implementation of the DTLN real-time speech denoising model, with TF-lite, ONNX and real-time audio processing support.
MIT License

Train stateful DTLN model #52

Open chixii opened 2 years ago

chixii commented 2 years ago

Hi @breizhn, I'm trying to train a stateful DTLN model, and I see time_dat = Input(batch_shape=(1, self.blockLen)). Why is self.batchsize = 1? Does this give better performance?

breizhn commented 2 years ago

The stateful model inside the class was only built for exporting the model as a SavedModel.

If you would like to train a stateful model, use the standard model and change line 334 to a defined Input length and batch size.

In the next step, add the argument stateful=True to the calls to self.seperation_kernel in lines 344 and 354.
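
A minimal sketch of those two changes (assuming the standard model's Input at line 334 currently uses undefined dimensions; the names follow the snippets quoted in this thread, and fixed_len is a hypothetical placeholder for the chosen input length):

# line 334: give the Input a defined batch size and input length
time_dat = Input(batch_shape=(self.batchsize, fixed_len))

# line 344 (and analogously line 354): add stateful=True to the
# separation kernel call
mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1),
                                mag_norm, stateful=True)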

StuartIanNaylor commented 2 years ago

@breizhn Nils, after training a model that works great with run_evaluation.py, I am running python convert_weights_to_tf_lite.py -m /name/of/the/model.h5 -t name_target to get the two tflite models, then running real_time_processing_tf_lite.py:

ValueError: Cannot set tensor: Dimension mismatch. Got 257 but expected 512 for dimension 2 of input 0

If you are using tflite, do you have to retrain as stateful with the above, or am I just doing something wrong with convert_weights_to_tf_lite?

https://github.com/breizhn/DTLN/blob/1de1f15a8b5b7e1c44905618ff2ef70ca8277fbc/DTLN_model.py#L344

mask_1 = self.seperation_kernel(self.numLayer, (self.blockLen//2+1), mag_norm, stateful=True)

Is there a preferred Input length and batch size for real-time tflite?

breizhn commented 2 years ago

Hi Stuart (@StuartIanNaylor),

ValueError: Cannot set tensor: Dimension mismatch. Got 257 but expected 512 for dimension 2 of input 0

When converting the model to TF-lite, the inputs and outputs sometimes end up in a random order for some reason. After conversion, you have to check/print the input and output details so that you set the tensors/data on the correct inputs of the model and read the correct outputs back.

https://github.com/breizhn/DTLN/blob/1de1f15a8b5b7e1c44905618ff2ef70ca8277fbc/real_time_processing_tf_lite.py#L34-L38

And for the predefined length: As far as I know, conversion does not work with undefined dimensions.
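
A minimal sketch of that check (assuming the converted files are named DTLN_model_1.tflite and DTLN_model_2.tflite, as the conversion script produces):

import tensorflow.lite as tflite

# load one of the converted models
interpreter_1 = tflite.Interpreter(model_path='DTLN_model_1.tflite')
interpreter_1.allocate_tensors()

# print name, index and shape of every input and output; match tensors
# by shape rather than by position, since the order can vary
for detail in interpreter_1.get_input_details():
    print('input: ', detail['name'], detail['index'], detail['shape'])
for detail in interpreter_1.get_output_details():
    print('output:', detail['name'], detail['index'], detail['shape'])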

StuartIanNaylor commented 2 years ago

I will have a look at the shapes and see if I can work it out. I'm just trying a single KW of "hey marvin" rather than a full language, out of interest to see how well it would work. It seems to work pretty well with the real-time tests, but I really need to get it as fast and as low-load as possible. I had 200k 1-sec "hey marvins" that I concatenated into 31-sec wavs and did the DNS-Challenge aug & splits for 80 hours, but I was just wondering and will do a really long 400-hour run. Google has targeted speech, and it just made me think: could you do targeted KW as well? If the functions are added, then I guess it could also have 'on-device training', where a small model shifts the weights on captured KW.

StuartIanNaylor commented 2 years ago

Got round to having another go so just posting as I go.

python convert_weights_to_tf_lite.py -m ./models_DTLN_model/DTLN_model.h5 -t DTLN_model
2022-05-16 11:24:40.691917: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:936] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-16 11:24:40.734035: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-16 11:24:41.133065: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5879 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050, pci bus id: 0000:01:00.0, compute capability: 8.6
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(1, 512)]           0           []                               

 lambda (Lambda)                [(1, 1, 257),        0           ['input_1[0][0]']                
                                 (1, 1, 257)]                                                     

 lstm (LSTM)                    (1, 1, 128)          197632      ['lambda[0][0]']                 

 dropout (Dropout)              (1, 1, 128)          0           ['lstm[0][0]']                   

 lstm_1 (LSTM)                  (1, 1, 128)          131584      ['dropout[0][0]']                

 dense (Dense)                  (1, 1, 257)          33153       ['lstm_1[0][0]']                 

 activation (Activation)        (1, 1, 257)          0           ['dense[0][0]']                  

 multiply (Multiply)            (1, 1, 257)          0           ['lambda[0][0]',                 
                                                                  'activation[0][0]']             

 lambda_1 (Lambda)              (1, 1, 512)          0           ['multiply[0][0]',               
                                                                  'lambda[0][1]']                 

 conv1d (Conv1D)                (1, 1, 256)          131072      ['lambda_1[0][0]']               

 instant_layer_normalization (I  (1, 1, 256)         512         ['conv1d[0][0]']                 
 nstantLayerNormalization)                                                                        

 lstm_2 (LSTM)                  (1, 1, 128)          197120      ['instant_layer_normalization[0][
                                                                 0]']                             

 dropout_1 (Dropout)            (1, 1, 128)          0           ['lstm_2[0][0]']                 

 lstm_3 (LSTM)                  (1, 1, 128)          131584      ['dropout_1[0][0]']              

 dense_1 (Dense)                (1, 1, 256)          33024       ['lstm_3[0][0]']                 

 activation_1 (Activation)      (1, 1, 256)          0           ['dense_1[0][0]']                

 multiply_1 (Multiply)          (1, 1, 256)          0           ['conv1d[0][0]',                 
                                                                  'activation_1[0][0]']           

 conv1d_1 (Conv1D)              (1, 1, 512)          131072      ['multiply_1[0][0]']             

==================================================================================================
Total params: 986,753
Trainable params: 986,753
Non-trainable params: 0
__________________________________________________________________________________________________
None
WARNING:tensorflow:Layer lstm_4 will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
WARNING:tensorflow:Layer lstm_5 will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
WARNING:tensorflow:Layer lstm_6 will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
WARNING:tensorflow:Layer lstm_7 will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
2022-05-16 11:24:42.280438: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as lstm_cell_4_layer_call_fn, lstm_cell_4_layer_call_and_return_conditional_losses, lstm_cell_5_layer_call_fn, lstm_cell_5_layer_call_and_return_conditional_losses while saving (showing 4 of 4). These functions will not be directly callable after loading.
2022-05-16 11:24:43.167113: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format.
2022-05-16 11:24:43.167139: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency.
2022-05-16 11:24:43.167801: I tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: /tmp/tmp7t02n0xs
2022-05-16 11:24:43.169884: I tensorflow/cc/saved_model/reader.cc:78] Reading meta graph with tags { serve }
2022-05-16 11:24:43.169899: I tensorflow/cc/saved_model/reader.cc:119] Reading SavedModel debug info (if present) from: /tmp/tmp7t02n0xs
2022-05-16 11:24:43.178090: I tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
2022-05-16 11:24:43.224443: I tensorflow/cc/saved_model/loader.cc:212] Running initialization op on SavedModel bundle at path: /tmp/tmp7t02n0xs
2022-05-16 11:24:43.239752: I tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. Took 71952 microseconds.
2022-05-16 11:24:43.270659: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:237] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2022-05-16 11:24:43.320343: I tensorflow/compiler/mlir/lite/flatbuffer_export.cc:1963] Estimated count of arithmetic ops: 0.824 M  ops, equivalently 0.412 M  MACs

Estimated count of arithmetic ops: 0.824 M  ops, equivalently 0.412 M  MACs
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
WARNING:absl:Found untraced functions such as lstm_cell_6_layer_call_fn, lstm_cell_6_layer_call_and_return_conditional_losses, lstm_cell_7_layer_call_fn, lstm_cell_7_layer_call_and_return_conditional_losses while saving (showing 4 of 4). These functions will not be directly callable after loading.
2022-05-16 11:24:44.867283: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format.
2022-05-16 11:24:44.867310: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency.
2022-05-16 11:24:44.867543: I tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: /tmp/tmpfr4beuqa
2022-05-16 11:24:44.870124: I tensorflow/cc/saved_model/reader.cc:78] Reading meta graph with tags { serve }
2022-05-16 11:24:44.870141: I tensorflow/cc/saved_model/reader.cc:119] Reading SavedModel debug info (if present) from: /tmp/tmpfr4beuqa
2022-05-16 11:24:44.878121: I tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
2022-05-16 11:24:44.909595: I tensorflow/cc/saved_model/loader.cc:212] Running initialization op on SavedModel bundle at path: /tmp/tmpfr4beuqa
2022-05-16 11:24:44.927967: I tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. Took 60424 microseconds.
2022-05-16 11:24:45.021548: I tensorflow/compiler/mlir/lite/flatbuffer_export.cc:1963] Estimated count of arithmetic ops: 1.349 M  ops, equivalently 0.674 M  MACs

Estimated count of arithmetic ops: 1.349 M  ops, equivalently 0.674 M  MACs
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
TF lite conversion complete!

https://github.com/breizhn/DTLN/blob/1de1f15a8b5b7e1c44905618ff2ef70ca8277fbc/real_time_processing_tf_lite.py#L34-L38

At line 39, insert:

print(input_details_1, "\n")
print(output_details_1, "\n")
print(input_details_2, "\n")
print(output_details_2, "\n")

They are returned as lists, not tensors? It's been too long and my memory is guessing that's a tflite thing, but I'm fairly sure they are tensors with tensorflow.

python real_time_processing_tf_lite.py
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
[{'name': 'serving_default_input_4:0', 'index': 0, 'shape': array([  1,   1, 512], dtype=int32), 'shape_signature': array([  1,   1, 512], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'serving_default_input_5:0', 'index': 1, 'shape': array([  1,   2, 128,   2], dtype=int32), 'shape_signature': array([  1,   2, 128,   2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 

[{'name': 'StatefulPartitionedCall:1', 'index': 97, 'shape': array([  1,   2, 128,   2], dtype=int32), 'shape_signature': array([  1,   2, 128,   2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:0', 'index': 92, 'shape': array([  1,   1, 512], dtype=int32), 'shape_signature': array([  1,   1, 512], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 

[{'name': 'serving_default_input_3:0', 'index': 0, 'shape': array([  1,   2, 128,   2], dtype=int32), 'shape_signature': array([  1,   2, 128,   2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'serving_default_input_2:0', 'index': 1, 'shape': array([  1,   1, 257], dtype=int32), 'shape_signature': array([  1,   1, 257], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 

[{'name': 'StatefulPartitionedCall:0', 'index': 64, 'shape': array([  1,   1, 257], dtype=int32), 'shape_signature': array([  1,   1, 257], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:1', 'index': 69, 'shape': array([  1,   2, 128,   2], dtype=int32), 'shape_signature': array([  1,   2, 128,   2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 

Traceback (most recent call last):
  File "/home/stuart/DTLN/real_time_processing_tf_lite.py", line 74, in <module>
    interpreter_1.set_tensor(input_details_1[0]['index'], in_mag)
  File "/home/stuart/DTLN/venv/lib/python3.10/site-packages/tensorflow/lite/python/interpreter.py", line 698, in set_tensor
    self._interpreter.SetTensor(tensor_index, value)
ValueError: Cannot set tensor: Dimension mismatch. Got 257 but expected 512 for dimension 2 of input 0.

I thought, ha, this is simple: just reverse the model loads.

# load models
interpreter_1 = tflite.Interpreter(model_path='./DTLN_model_1.tflite')
interpreter_1.allocate_tensors()
interpreter_2 = tflite.Interpreter(model_path='./DTLN_model_2.tflite')
interpreter_2.allocate_tensors()

But

python real_time_processing_tf_lite.py
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
[{'name': 'serving_default_input_3:0', 'index': 0, 'shape': array([  1,   2, 128,   2], dtype=int32), 'shape_signature': array([  1,   2, 128,   2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'serving_default_input_2:0', 'index': 1, 'shape': array([  1,   1, 257], dtype=int32), 'shape_signature': array([  1,   1, 257], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 

[{'name': 'StatefulPartitionedCall:0', 'index': 64, 'shape': array([  1,   1, 257], dtype=int32), 'shape_signature': array([  1,   1, 257], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:1', 'index': 69, 'shape': array([  1,   2, 128,   2], dtype=int32), 'shape_signature': array([  1,   2, 128,   2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 

[{'name': 'serving_default_input_4:0', 'index': 0, 'shape': array([  1,   1, 512], dtype=int32), 'shape_signature': array([  1,   1, 512], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'serving_default_input_5:0', 'index': 1, 'shape': array([  1,   2, 128,   2], dtype=int32), 'shape_signature': array([  1,   2, 128,   2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 

[{'name': 'StatefulPartitionedCall:1', 'index': 97, 'shape': array([  1,   2, 128,   2], dtype=int32), 'shape_signature': array([  1,   2, 128,   2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:0', 'index': 92, 'shape': array([  1,   1, 512], dtype=int32), 'shape_signature': array([  1,   1, 512], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 

Traceback (most recent call last):
  File "/home/stuart/DTLN/real_time_processing_tf_lite.py", line 74, in <module>
    interpreter_1.set_tensor(input_details_1[0]['index'], in_mag)
  File "/home/stuart/DTLN/venv/lib/python3.10/site-packages/tensorflow/lite/python/interpreter.py", line 698, in set_tensor
    self._interpreter.SetTensor(tensor_index, value)
ValueError: Cannot set tensor: Dimension mismatch. Got 3 but expected 4 for input 0.

I have MS and I am having a not-wake-up day; I will have another go when feeling better :)
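
One way to sidestep the ordering problem entirely (a sketch, not something from this thread: it assumes the signal input is always rank 3, e.g. (1, 1, 257), and the state input always rank 4, e.g. (1, 2, 128, 2), as in the details printed above) is to pick tensor indices by rank instead of by position:

def index_by_rank(details, rank):
    # return the index of the tensor whose shape has `rank` dimensions
    for d in details:
        if len(d['shape']) == rank:
            return d['index']
    raise ValueError('no tensor of rank %d found' % rank)

# reusing interpreter_1, in_mag and states_1 from the processing script:
mag_idx = index_by_rank(interpreter_1.get_input_details(), 3)
state_idx = index_by_rank(interpreter_1.get_input_details(), 4)
interpreter_1.set_tensor(state_idx, states_1)
interpreter_1.set_tensor(mag_idx, in_mag)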

StuartIanNaylor commented 2 years ago

Thanks to Nils, but if you're as confused as my befuddled brain, just use this:

"""
This is an example of how to implement real-time processing of the DTLN
TF-lite model in Python.
Please change the name of the .wav file at line 43 before running the script.
For .whl files of the TF-lite runtime go to:
    https://www.tensorflow.org/lite/guide/python

Author: Nils L. Westhausen (nils.westhausen@uol.de)
Version: 30.06.2020
This code is licensed under the terms of the MIT-license.
"""

import soundfile as sf
import numpy as np
import tensorflow.lite as tflite
import time

##########################
# these values are fixed; if you need other values, you have to retrain.
# The sampling rate of 16 kHz is also fixed.
block_len = 512
block_shift = 128
# load models
interpreter_1 = tflite.Interpreter(model_path='DTLN_model_1.tflite')
interpreter_1.allocate_tensors()
interpreter_2 = tflite.Interpreter(model_path='DTLN_model_2.tflite')
interpreter_2.allocate_tensors()

# Get input and output tensors.
input_details_1 = interpreter_1.get_input_details()
output_details_1 = interpreter_1.get_output_details()

input_details_2 = interpreter_2.get_input_details()
output_details_2 = interpreter_2.get_output_details()
print(input_details_1, "\n")
print(output_details_1, "\n")
print(input_details_2, "\n")
print(output_details_2, "\n")
# create states for the lstms
states_1 = np.zeros(input_details_1[0]['shape']).astype('float32')
states_2 = np.zeros(input_details_2[1]['shape']).astype('float32')
# load audio file at 16k fs (please change)
audio,fs = sf.read('fileid_10.wav')
# check for sampling rate
if fs != 16000:
    raise ValueError('This model only supports 16k sampling rate.')
# preallocate output audio
out_file = np.zeros((len(audio)))
# create buffer
in_buffer = np.zeros((block_len)).astype('float32')
out_buffer = np.zeros((block_len)).astype('float32')
# calculate number of blocks
num_blocks = (audio.shape[0] - (block_len-block_shift)) // block_shift
time_array = []
# iterate over the number of blocks
for idx in range(num_blocks):
    start_time = time.time()
    # shift values and write to buffer
    in_buffer[:-block_shift] = in_buffer[block_shift:]
    in_buffer[-block_shift:] = audio[idx*block_shift:(idx*block_shift)+block_shift]
    # calculate fft of input block
    in_block_fft = np.fft.rfft(in_buffer)
    in_mag = np.abs(in_block_fft)
    in_phase = np.angle(in_block_fft)
    # reshape magnitude to input dimensions
    in_mag = np.reshape(in_mag, (1,1,-1)).astype('float32')
    # set tensors to the first model
    print(np.shape(states_1), np.shape(in_mag))
    interpreter_1.set_tensor(input_details_1[0]['index'], states_1)
    interpreter_1.set_tensor(input_details_1[1]['index'], in_mag)
    # run calculation 
    interpreter_1.invoke()
    # get the output of the first block
    out_mask = interpreter_1.get_tensor(output_details_1[0]['index']) 
    states_1 = interpreter_1.get_tensor(output_details_1[1]['index'])
    print(np.shape(out_mask), np.shape(states_1))
    # calculate the ifft
    estimated_complex = in_mag * out_mask * np.exp(1j * in_phase)
    estimated_block = np.fft.irfft(estimated_complex)
    # reshape the time domain block
    estimated_block = np.reshape(estimated_block, (1,1,-1)).astype('float32')
    # set tensors to the second model
    print(np.shape(states_2), np.shape(estimated_block))
    interpreter_2.set_tensor(input_details_2[1]['index'], states_2)
    interpreter_2.set_tensor(input_details_2[0]['index'], estimated_block)
    # run calculation
    interpreter_2.invoke()
    # get output tensors
    out_block = interpreter_2.get_tensor(output_details_2[1]['index']) 
    states_2 = interpreter_2.get_tensor(output_details_2[0]['index'])
    print(np.shape(out_block), np.shape(states_2))

    # shift values and write to buffer
    out_buffer[:-block_shift] = out_buffer[block_shift:]
    out_buffer[-block_shift:] = np.zeros((block_shift))
    out_buffer  += np.squeeze(out_block)
    # write block to output file
    out_file[idx*block_shift:(idx*block_shift)+block_shift] = out_buffer[:block_shift]
    time_array.append(time.time()-start_time)

# write to .wav file 
sf.write('out.wav', out_file, fs) 
print('Processing Time [ms]:')
print(np.mean(np.stack(time_array))*1000)
print('Processing finished.')