openly-jp / voiscribe

1 star · 0 forks · source link

introduce `ModelManager` class #282

Open · opened by ooyamatakehisa 1 year ago

ooyamatakehisa commented 1 year ago

reference: https://github.com/openly-jp/voiscribe/pull/279#discussion_r1193271631

 import Foundation

 class ModelManager: ObservableObject {
    /// Every known model, keyed first by `Size` and then by `ModelLanguage`.
    /// `init` registers one `WhisperModel` for every (size, language) pair.
    private let models: [Size: [ModelLanguage: WhisperModel]]

    /// Progress of in-flight downloads, keyed by `WhisperModel.name`.
    /// An entry exists only while that model is downloading: it is set to 0 when
    /// a download starts, updated from the repository's `update` callback
    /// (presumably a fraction in 0...1 — confirm against WhisperModelRepository),
    /// and removed when the download finishes (success or failure).
    /// Because this is `@Published`, it is only mutated on the main thread
    /// (the initial write assumes `downloadModel` is called from main).
    @Published var downloadProgressValue: [String: Float] = [:]

    init() {
        var tmpModels: [Size: [ModelLanguage: WhisperModel]] = [:]
        for size in Size.allCases {
            var modelsForSize: [ModelLanguage: WhisperModel] = [:]
            for modelLanguage in ModelLanguage.allCases {
                modelsForSize[modelLanguage] = WhisperModel(
                    size: size,
                    modelLanguage: modelLanguage
                )
            }
            tmpModels[size] = modelsForSize
        }
        models = tmpModels
    }

    /// Returns the registered model for `size` that serves `recognitionLanguage`.
    ///
    /// - Parameters:
    ///   - size: The model size.
    ///   - recognitionLanguage: Mapped to a `ModelLanguage` via
    ///     `WhisperModel.getModelLanguage`.
    /// - Returns: The shared `WhisperModel` instance created in `init`.
    func getModel(size: Size, recognitionLanguage: RecognitionLanguage) -> WhisperModel {
        let modelLanguage = WhisperModel.getModelLanguage(recognitionLanguage)
        guard let model = models[size]?[modelLanguage] else {
            // `init` registers every (Size, ModelLanguage) combination, so a miss
            // means that invariant was broken — a programmer error.
            preconditionFailure("No model registered for size \(size) and language \(modelLanguage)")
        }
        return model
    }

    /// Downloads `model` and reports the outcome through `completeCallback`.
    ///
    /// - Parameters:
    ///   - model: The model to download.
    ///   - completeCallback: Called exactly once with `nil` on success or an error
    ///     otherwise. If the model is already downloaded or already downloading,
    ///     it is called synchronously on the caller's thread; otherwise it is
    ///     called on the main thread after `model.isDownloaded` and
    ///     `downloadProgressValue` have been updated.
    func downloadModel(
        model: WhisperModel,
        completeCallback: @escaping (Error?) -> Void
    ) {
        if model.isDownloaded {
            completeCallback(ModelManager.makeError("The model is already downloaded."))
            return
        }

        let progressKey = model.name
        // A live progress entry means a download for this model is in flight;
        // starting a second one would race on the same destination.
        if downloadProgressValue[progressKey] != nil {
            completeCallback(ModelManager.makeError("The model is already being downloaded."))
            return
        }

        downloadProgressValue[progressKey] = 0

        // NOTE(review): assumes WhisperModel exposes `size`, `modelLanguage`
        // (mirroring its initializer labels) and `localPath` — confirm.
        WhisperModelRepository.fetchWhisperModel(
            size: model.size,
            language: model.modelLanguage,
            update: { [weak self] progress in
                // `@Published` state must only change on the main thread.
                DispatchQueue.main.async {
                    guard let self = self,
                          // Ignore late progress events that arrive after completion
                          // already cleared the entry.
                          self.downloadProgressValue[progressKey] != nil
                    else { return }
                    self.downloadProgressValue[progressKey] = progress
                }
            },
            destinationURL: model.localPath
        ) { [weak self] result in
            // Apply all state changes on main, then notify the caller, so the
            // callback always observes the final `isDownloaded` value.
            DispatchQueue.main.async {
                // Clear progress on both success and failure so the UI never
                // stays stuck on a stale value.
                self?.downloadProgressValue[progressKey] = nil

                switch result {
                case .success:
                    model.isDownloaded = true
                    completeCallback(nil)
                case let .failure(error):
                    model.isDownloaded = false
                    completeCallback(ModelManager.makeError(
                        "Failed to download model: \(error.localizedDescription)"
                    ))
                }
            }
        }
    }

    /// Misspelled original name, kept so existing callers still compile.
    @available(*, deprecated, renamed: "downloadModel(model:completeCallback:)")
    func donwloadMoel(
        model: WhisperModel,
        completeCallback: @escaping (Error?) -> Void
    ) {
        downloadModel(model: model, completeCallback: completeCallback)
    }

    /// Builds the `NSError` this manager reports through completion callbacks.
    /// The message stays in `domain` (as in the original call sites) and is also
    /// set as `NSLocalizedDescriptionKey` so `localizedDescription` is readable.
    private static func makeError(_ message: String) -> NSError {
        NSError(
            domain: message,
            code: -1,
            userInfo: [NSLocalizedDescriptionKey: message]
        )
    }

 }