jianchang512 / stt

Voice Recognition to Text Tool / 一个离线运行的本地语音识别转文字服务,输出json、srt字幕带时间戳、纯文字格式
https://pyvideotrans.com
GNU General Public License v3.0
2.21k stars 247 forks source link

可以写一个高并发的start.py吗 #84

Open Rosemajor opened 1 week ago

Rosemajor commented 1 week ago

` import concurrent.futures import logging import re import threading import sys import torch from flask import Flask, request, render_template, jsonify, send_from_directory import os from gevent.pywsgi import WSGIServer, WSGIHandler, LoggingLogAdapter from logging.handlers import RotatingFileHandler import warnings from waitress import serve warnings.filterwarnings('ignore') import stslib from stslib import cfg, tool from stslib.cfg import ROOT_DIR from faster_whisper import WhisperModel

# Flask application. NOTE(review): the pasted code had `Flask(name, ...)`,
# which is a NameError — markdown stripped the underscores from `__name__`.
app = Flask(
    __name__,
    static_folder=os.path.join(ROOT_DIR, 'static'),
    static_url_path='/static',
    template_folder=os.path.join(ROOT_DIR, 'templates')
)

# Thread pool shared by all requests; bounds how many transcriptions run
# at once. Tune max_workers to the server's CPU/GPU capacity.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)


class CustomRequestHandler(WSGIHandler):
    """gevent WSGI handler that suppresses per-request access logging."""

    def log_request(self):
        # Intentionally silent: avoid flooding the console under load.
        pass

def process_audio_task(audio_file, model, language, response_format):
    """Convert an uploaded media file to WAV (if needed) and transcribe it
    with faster-whisper.

    Args:
        audio_file: uploaded file object (werkzeug FileStorage) from the form.
        model: faster-whisper model name, or a "user/repo" HuggingFace id.
        language: ISO language code, or 'auto' for auto-detection.
        response_format: 'json', 'text', or anything else for SRT output.

    Returns:
        dict: {"code": 0, "msg": "ok", "data": subtitles} on success;
        {"code": 1, "msg": ...} on conversion/model errors;
        {"code": 2, "msg": ...} on unexpected exceptions.
    """
    try:
        noextname, ext = os.path.splitext(audio_file.filename)
        ext = ext.lower()
        wav_file = os.path.join(cfg.TMP_DIR, f'{noextname}.wav')

        # Only convert when a usable WAV is not already cached in TMP_DIR.
        if not os.path.exists(wav_file) or os.path.getsize(wav_file) == 0:
            if ext in ['.mp4', '.mov', '.avi', '.mkv', '.mpeg', '.mp3', '.flac']:
                video_file = os.path.join(cfg.TMP_DIR, f'{noextname}{ext}')
                audio_file.save(video_file)
                params = ['-i', video_file]
                if ext not in ['.mp3', '.flac']:
                    # Real video containers: drop the video stream.
                    params.append('-vn')
                params.append(wav_file)
                rs = tool.runffmpeg(params)
                if rs != 'ok':
                    return {"code": 1, "msg": rs}
            elif ext == '.wav':
                audio_file.save(wav_file)
            else:
                # Unsupported extension.
                return {"code": 1, "msg": f"{cfg.transobj['lang3']} {ext}"}

        sets = cfg.parse_ini()

        try:
            model_instance = WhisperModel(
                model,
                device=sets.get('devtype'),
                compute_type=sets.get('cuda_com_type'),
                download_root=cfg.ROOT_DIR + "/models",
                # A name containing '/' is a HuggingFace repo id -> allow download;
                # plain names must already exist locally.
                local_files_only=False if model.find('/') > 0 else True
            )
        except Exception as e:
            err = f'从huggingface.co下载模型 {model} 失败,请检查网络连接' if model.find('/') > 0 else ''
            return {"code": 1, "msg": f"{err} {e}"}

        segments, info = model_instance.transcribe(
            wav_file,
            beam_size=sets.get('beam_size'),
            best_of=sets.get('best_of'),
            temperature=0 if sets.get('temperature') == 0 else [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            condition_on_previous_text=sets.get('condition_on_previous_text'),
            vad_filter=sets.get('vad'),
            vad_parameters=dict(min_silence_duration_ms=300, max_speech_duration_s=10.5),
            language=language if language != 'auto' else None,
            initial_prompt=sets.get('initial_prompt_zh') if language == 'zh' else None
        )

        raw_subtitles = []
        for segment in segments:
            start = int(segment.start * 1000)
            end = int(segment.end * 1000)
            startTime = tool.ms_to_time_string(ms=start)
            endTime = tool.ms_to_time_string(ms=end)
            # NOTE(review): the pasted source had replace(''', "'") — a syntax
            # error produced by markdown mangling. Restored as the HTML-entity
            # apostrophe replacement (consistent with the &#\d+; cleanup on the
            # next line) — confirm against the upstream file.
            text = segment.text.strip().replace('&#39;', "'")
            text = re.sub(r'&#\d+;', '', text)

            # Skip empty fragments and punctuation/digit-only noise.
            if not text or re.match(r'^[,。、?‘’“”;:({}【】):;"\'\s \d`!@#$%^&*()_+=.,?/\\-]*$', text) or len(text) <= 1:
                continue

            if response_format == 'json':
                raw_subtitles.append(
                    {"line": len(raw_subtitles) + 1, "start_time": startTime, "end_time": endTime, "text": text})
            elif response_format == 'text':
                raw_subtitles.append(text)
            else:
                # SRT block: index, time range, text.
                raw_subtitles.append(f'{len(raw_subtitles) + 1}\n{startTime} --> {endTime}\n{text}\n')

        if response_format != 'json':
            raw_subtitles = "\n".join(raw_subtitles)

        return {"code": 0, "msg": 'ok', "data": raw_subtitles}
    except Exception as e:
        app.logger.error(f'[api]error: {e}')
        return {'code': 2, 'msg': str(e)}

def _is_model_exists(model): if model.find('/')>0: return True if not model.startswith('distil') and not os.path.exists(os.path.join(cfg.MODEL_DIR, f'models--Systran--faster-whisper-{model}/snapshots/')): return False if model.startswith('distil') and not os.path.exists(os.path.join(cfg.MODEL_DIR, f'models--Systran--faster-{model}/snapshots/')): return False

return True

@app.route('/api', methods=['GET', 'POST'])
def api():
    """HTTP endpoint: accept an uploaded media file and return its transcription."""
    try:
        # Uploaded file and form fields.
        audio_file = request.files['file']
        model = request.form.get("model")
        language = request.form.get("language")
        response_format = request.form.get("response_format", 'srt')

        # Reject requests for models that are neither HF repos nor
        # already downloaded locally.
        if _is_model_exists(model) is not True:
            return jsonify({"code": 1, "msg": f"{model} {cfg.transobj['lang4']}"})

        # Submit to the shared thread pool. NOTE(review): future.result()
        # blocks this handler until the task finishes, so each request is
        # still synchronous — the pool only bounds how many transcriptions
        # run simultaneously across concurrent requests.
        future = executor.submit(process_audio_task, audio_file, model, language, response_format)
        result = future.result()
        return jsonify(result)

    except Exception as e:
        app.logger.error(f'[api]error: {e}')
        return jsonify({'code': 2, 'msg': str(e)})

if __name__ == '__main__':
    # NOTE(review): the pasted code had `if name == 'main':`, which never
    # runs — markdown stripped the dunder underscores; restored the
    # standard entry-point guard.
    http_server = None
    try:
        # Fire-and-forget update check in the background.
        threading.Thread(target=tool.checkupdate).start()
        try:
            if cfg.devtype == 'cpu':
                # Hint for CUDA users to enable GPU inference via set.ini.
                print(
                    '\n如果设备使用英伟达显卡并且CUDA环境已正确安装,可修改set.ini中\ndevtype=cpu 为 devtype=cuda, 然后重新启动以加快识别速度\n')
            # cfg.web_address is "host:port".
            host = cfg.web_address.split(':')
            http_server = WSGIServer((host[0], int(host[1])), app, handler_class=CustomRequestHandler)
            # Open the browser once the server address is known.
            threading.Thread(target=tool.openweb, args=(cfg.web_address,)).start()
            http_server.serve_forever()
        finally:
            if http_server:
                http_server.stop()
    except Exception as e:
        if http_server:
            http_server.stop()
        app.logger.error(f"[app]start error:{str(e)}")
        print("error:" + str(e))

` 上面是我想改成多进程的,但是失败了,还是单线程的,作者大大可以加一个并发处理的嘛,并发量可以根据自己电脑配置改动