rolczynski / Automatic-Speech-Recognition

🎧 Automatic Speech Recognition: DeepSpeech & Seq2Seq (TensorFlow)
GNU Affero General Public License v3.0

Where to add tf.keras.backend.clear_session()? #26

Closed sunlanchang closed 4 years ago

sunlanchang commented 4 years ago

I recently put your model behind my Flask web server. Every time a POST/GET request reaches the server, it executes pipeline.predict([sample]). After a few requests the server runs out of memory, so I would like to add tf.keras.backend.clear_session() to your code. I have tried calling tf.keras.backend.clear_session() in several places, but it does not seem to work. I would greatly appreciate your help in solving this problem.

Some Flask server code:

@app.route('/test', methods=['POST'])
def hello_world():
    ....
    audio_file = request.files['audio']
    audio_file.save(audio_path)
    sample = asr.utils.read_audio(audio_path)
    pipeline = asr.load('deepspeech2', lang='en')
    sentences = pipeline.predict([sample])
    ....
    return sentences

rolczynski commented 4 years ago

Hey @sunlanchang, my first thought is to move the pipeline initialization outside of the POST call. What do you think?

sunlanchang commented 4 years ago

Actually, I have tried putting it in three places:

  1. Outside of the POST call, as you suggested.
  2. In ctc_pipeline.py.
  3. In deepspeech2.py.

None of these works for me. Where do you think I should put it? Or maybe I should try putting clear_session() into every .py file...

sunlanchang commented 4 years ago

Hey @rolczynski, I finally solved the problem. Instead of adding clear_session() somewhere, I just moved pipeline = asr.load('deepspeech2', lang='en') outside of the POST call. This way the computation graph is created only once. Thanks a lot.

pipeline = asr.load('deepspeech2', lang='en')
@app.route('/test', methods=['POST'])
def hello_world():
    ....
    audio_file = request.files['audio']
    audio_file.save(audio_path)
    sample = asr.utils.read_audio(audio_path)
    sentences = pipeline.predict([sample])
    ....
    return sentences