Kitt-AI / snowboy

Future versions with model training module will be maintained through a forked version here: https://github.com/seasalt-ai/snowboy
Other
3.04k stars 995 forks source link

Hotword detection: a .pmdl downloaded with the REST API seems to recognize every word as my hotword #521

Open Orma07 opened 5 years ago

Orma07 commented 5 years ago

Hi, I downloaded the Android example. I would like to add a call to the REST API to create my own .pmdl, so I added:

` /**
 * JSON request body for the training call made by Trainer.train (KittApi.train).
 *
 * Property names are serialized verbatim by Gson (the Gson instance built in Trainer sets
 * no field-naming policy), so they must keep the snake_case spelling used in the request JSON.
 */
class TrainingModel {
    var name: String? = null
    var language: String? = null
    var age_group: String? = null
    var gender: String? = null
    var microphone: String? = "android microphone"
    // NOTE(review): this token is part of the request body, so any body-level HTTP logging
    // will print it.
    var token: String = Constants.REST_KITT.TOKEN
    // Samples to train on, one WaveModel per recording. A bare "ArrayList?" has no type
    // argument and does not compile.
    var voice_samples: ArrayList<WaveModel>? = null
}

// Holder for a single voice sample; presumably the element type of
// TrainingModel.voice_samples (confirm). The "wave" payload format is not shown here
// (likely encoded audio; verify against the training API docs).
class WaveModel {
    var wave: String? = null
} `

and

` /**
 * Requests a personal hotword model from the Kitt REST training API (KittApi.train) and
 * saves the returned model file to disk.
 *
 * @param listener notified exactly once per [train] call with the outcome.
 * @param isFirst selects the output file: BASTARD_WORD_ACTIVATE_PMDL_REC when true,
 *   BASTARD_WORD_DISACTIVATE_PMDL_REC otherwise.
 */
class Trainer(private val listener: ITrainerListener, val isFirst: Boolean) {

    private val api: KittApi

    init {
        val gson = GsonBuilder()
            .setLenient()
            .create()

        // NOTE(review): Level.BODY logs the full request (including the API token and every
        // voice sample) and the binary model in the response; consider a lower level for
        // release builds.
        val logging = HttpLoggingInterceptor().apply {
            level = HttpLoggingInterceptor.Level.BODY
        }

        val httpClient = OkHttpClient.Builder()
            .addInterceptor { chain ->
                val request = chain.request().newBuilder()
                    .header("Content-Type", "application/json")
                    .build()
                chain.proceed(request)
            }
            .addInterceptor(logging)
            .build()

        val retrofit = Retrofit.Builder()
            .baseUrl(Constants.REST_KITT.BASE_URL)
            .client(httpClient)
            .addConverterFactory(GsonConverterFactory.create(gson))
            .build()

        api = retrofit.create(KittApi::class.java)
    }

    /**
     * Uploads [model] asynchronously and writes the returned model file to disk.
     *
     * On success ITrainerListener.OnSucces is called. On an HTTP error, an empty body, a
     * network failure or a failed write, ITrainerListener.OnError is called instead, so the
     * listener always hears back.
     */
    fun train(model: TrainingModel) {
        api.train(model).enqueue(object : Callback<ResponseBody> {
            override fun onResponse(call: retrofit2.Call<ResponseBody>, response: Response<ResponseBody>) {
                // body() is nullable, so check it before handing it to the writer.
                val body = response.body()
                when {
                    !response.isSuccessful ->
                        reportError("trainer response return error code ${response.code()}")
                    body == null -> reportError("trainer response has no body")
                    // NOTE(review): Retrofit's default Android callback executor is presumably
                    // the main thread, so this file write likely runs on the UI thread; confirm.
                    writeResponseBodyToDisk(body) -> listener.OnSucces()
                    else -> reportError("could not save trained model to disk")
                }
            }

            override fun onFailure(call: retrofit2.Call<ResponseBody>, t: Throwable) {
                reportError("ex ${t.message}")
            }
        })
    }

    /**
     * Streams [body] into the model file selected by [isFirst] and closes the body.
     *
     * @return true when the whole body was written; false on an I/O error, in which case
     *   the partially written file is deleted so a truncated model is not left behind.
     */
    private fun writeResponseBodyToDisk(body: ResponseBody): Boolean {
        val fileName = if (isFirst) {
            Constants.REST_KITT.BASTARD_WORD_ACTIVATE_PMDL_REC
        } else {
            Constants.REST_KITT.BASTARD_WORD_DISACTIVATE_PMDL_REC
        }
        val modelFile = File(fileName)
        val expectedSize = body.contentLength()

        return try {
            // use {} closes every stream (and the response body) even if the copy or an
            // inner close throws.
            val written = body.use { source ->
                source.byteStream().use { input ->
                    FileOutputStream(modelFile).use { output ->
                        input.copyTo(output, bufferSize = 4096)
                    }
                }
            }
            Log.d(TAG, "file download: $written of $expectedSize")
            true
        } catch (e: IOException) {
            Log.e(TAG, "failed to write $fileName", e)
            modelFile.delete()
            false
        }
    }

    // Logs the message and forwards it to the listener as an error.
    private fun reportError(message: String) {
        Log.e(TAG, message)
        listener.OnError(message)
    }

    companion object {
        val TAG = "Trainer"
    }
} `

This is how I run hotword detection:

` /**
 * Creates the Snowboy detector for the personal model produced by the trainer.
 *
 * @param sensitivity value passed to SnowboyDetect.SetSensitivity. Higher values trigger
 *   more easily. With "1" the detector reported almost any word as the hotword, so the
 *   default is "0.5"; values around 0.5-0.7 are a reasonable starting point.
 */
private fun initDetector(sensitivity: String = "0.5") {
    AppResCopy.copyResFromAssetsToSD(fragment?.context?.applicationContext)

    // Same model file that Trainer writes: the "activate" model when trainer.isFirst is
    // true, the "deactivate" model otherwise.
    val recModel = if (trainer.isFirst) {
        Constants.REST_KITT.BASTARD_WORD_ACTIVATE_PMDL_REC
    } else {
        Constants.REST_KITT.BASTARD_WORD_DISACTIVATE_PMDL_REC
    }
    detector = SnowboyDetect(commonRes, recModel)
    detector?.SetSensitivity(sensitivity)
    detector?.SetAudioGain(1f)
    detector?.ApplyFrontend(false)
}

/**
 * Starts a background thread that records mono 16-bit PCM in 0.1 s chunks and feeds each
 * chunk to runDetection until shouldContinue becomes false.
 */
private fun startDetect() {
    shouldContinue = true
    // NOTE(review): shouldContinue is written here, by runDetection on the recording thread,
    // and read by the loop below; make sure it is declared @Volatile.
    Thread {
        android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO)
        val sampleRate = ai.kitt.snowboy.Constants.SAMPLE_RATE // 16000
        // Bytes in 0.1 s of 16-bit mono audio.
        val chunkBytes = (sampleRate * 0.1 * 2).toInt()
        val audioBuffer = ByteArray(chunkBytes)

        // AudioRecord fails to initialise when its buffer is below the device minimum, so
        // request at least that much. A negative error code from getMinBufferSize is
        // simply ignored by maxOf.
        val minBufferBytes = AudioRecord.getMinBufferSize(sampleRate, CHANNEL_IN_MONO, ENCODING_PCM_16BIT)
        val record = AudioRecord(
            MediaRecorder.AudioSource.DEFAULT,
            sampleRate,
            CHANNEL_IN_MONO,
            ENCODING_PCM_16BIT,
            maxOf(chunkBytes, minBufferBytes)
        )

        if (record.state != AudioRecord.STATE_INITIALIZED) {
            android.util.Log.e("startDetect", "AudioRecord failed to initialize")
            record.release()
            shouldContinue = false
        } else {
            try {
                record.startRecording()
                detector?.Reset()
                while (shouldContinue) {
                    val read = record.read(audioBuffer, 0, audioBuffer.size)
                    if (read < 0) {
                        // Negative values are AudioRecord error codes; retrying would just spin.
                        android.util.Log.e("startDetect", "AudioRecord.read failed: $read")
                        shouldContinue = false
                    } else if (read > 0) {
                        // Only analyse the bytes actually filled by this read.
                        runDetection(if (read == audioBuffer.size) audioBuffer else audioBuffer.copyOf(read))
                    }
                }
            } finally {
                record.stop()
                record.release()
            }
        }
    }.start()
}

/**
 * Runs Snowboy on one chunk of recorded audio and stops recording once the hotword
 * (result 1) is detected.
 */
private fun runDetection(audioBuffer: ByteArray) {
    // Reinterpret the little-endian 16-bit PCM bytes as shorts for RunDetection.
    val samples = ShortArray(audioBuffer.size / 2).also { out ->
        ByteBuffer.wrap(audioBuffer)
            .order(ByteOrder.LITTLE_ENDIAN)
            .asShortBuffer()
            .get(out)
    }

    // A null result (no detector) matches no branch and is ignored.
    when (detector?.RunDetection(samples, samples.size)) {
        -2 -> Unit // no speech
        -1 -> Unit // detector error, currently ignored
        0 -> Unit // speech that is not the hotword
        1 -> {
            fragment?.setRecCompleted(END_OF_ATTEMPTS)
            shouldContinue = false
            recordIterate++
        }
    }
}

`

Any word I pronounce gives result == 1. For example, if I record "hello" three times and then say "banana", I still get result == 1. Any idea why?

naveensingh commented 5 years ago

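You have set the sensitivity to 1; change it to something lower, such as 0.5 or 0.7. A higher sensitivity makes the detector trigger more easily, so at 1 almost any speech is reported as the hotword.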