k2-fsa / sherpa-onnx

Speech-to-text, text-to-speech, speaker diarization, and VAD using next-gen Kaldi with onnxruntime, without an Internet connection. Supports embedded systems, Android, iOS, Raspberry Pi, RISC-V, x86_64 servers, WebSocket server/client, C/C++, Python, Kotlin, C#, Go, NodeJS, Java, Swift, Dart, JavaScript, Flutter, Object Pascal, Lazarus, Rust
https://k2-fsa.github.io/sherpa/onnx/index.html
Apache License 2.0

TTS Highlight Current Text #585

Open studionexus-lk opened 9 months ago

studionexus-lk commented 9 months ago

Can't we implement a method to detect the status of TTS, i.e., the text currently being spoken, so that it can be used to highlight that text?

csukuangfj commented 9 months ago

Could you describe how you use the tts function of sherpa-onnx?

If you integrate it into your app, i.e., use the TTS engine service, then I think it is feasible. You can send the text sentence by sentence to sherpa-onnx and highlight the current sentence that is being synthesized.

studionexus-lk commented 9 months ago

I'm referring to the standalone version, without installing it as a TTS engine. Isn't it possible to detect the text currently being spoken, so it can be highlighted or shown in a lyrics-like view?

csukuangfj commented 9 months ago

isn't it possible to detect the text currently being spoken, so it can be highlighted

Yes, I think it is possible. Would you like to contribute?

The idea is (see the sketch below):

  1. Split the text into sentences.
  2. For the current sentence: highlight it, synthesize it, play it, then de-highlight it.
  3. Move on to the next sentence and repeat step 2.
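
A rough Kotlin sketch of this loop, assuming it lives inside the Activity shown further down (`tts` and `track` come from that demo; `highlight()`/`clearHighlight()` are hypothetical UI helpers, not part of sherpa-onnx):

fun speakWithHighlight(fullText: String) {
    // 1. Split the text into sentences (piper models break on ., ?, !).
    val sentences = fullText.split(Regex("(?<=[.?!])\\s+")).filter { it.isNotBlank() }
    Thread {
        for (sentence in sentences) {
            // 2. Highlight the current sentence, synthesize it, play it
            //    (the blocking write paces the loop), then de-highlight it.
            runOnUiThread { highlight(sentence) }
            val audio = tts.generate(text = sentence, sid = 0, speed = 1.0f)
            track.write(audio.samples, 0, audio.samples.size, AudioTrack.WRITE_BLOCKING)
            runOnUiThread { clearHighlight(sentence) }
            // 3. The loop then moves on to the next sentence.
        }
    }.start()
}
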
studionexus-lk commented 9 months ago
package com.k2fsa.sherpa.onnx

import android.content.res.AssetManager

data class OfflineTtsVitsModelConfig(
    var model: String,
    var lexicon: String = "",
    var tokens: String,
    var dataDir: String = "",
    var noiseScale: Float = 0.667f,
    var noiseScaleW: Float = 0.8f,
    var lengthScale: Float = 1.0f,
)

data class OfflineTtsModelConfig(
    var vits: OfflineTtsVitsModelConfig,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)

data class OfflineTtsConfig(
    var model: OfflineTtsModelConfig,
    var ruleFsts: String = "",
    var maxNumSentences: Int = 1,
)

class GeneratedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    fun save(filename: String) =
        saveImpl(filename = filename, samples = samples, sampleRate = sampleRate)

    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int
    ): Boolean
}

class OfflineTts(
    assetManager: AssetManager? = null,
    var config: OfflineTtsConfig,
    var highlightCallback: ((highlightedText: String) -> Unit)? = null // Add a callback property
) {
    private var ptr: Long

    init {
        if (assetManager != null) {
            ptr = new(assetManager, config)
        } else {
            ptr = newFromFile(config)
        }
    }

    fun sampleRate() = getSampleRate(ptr)

    fun numSpeakers() = getNumSpeakers(ptr)

    fun generate(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): GeneratedAudio {
        val objArray = generateImpl(ptr, text = text, sid = sid, speed = speed)
        return GeneratedAudio(
            samples = objArray[0] as FloatArray,
            sampleRate = objArray[1] as Int
        )
    }

    fun generateWithCallback(
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f,
        callback: (samples: FloatArray) -> Unit
    ): GeneratedAudio {
        val objArray = generateWithCallbackImpl(ptr, text = text, sid = sid, speed = speed) { samples ->
            // Invoke the provided callback function
            callback(samples)
            // Invoke the highlight callback with the spoken text
            highlightCallback?.invoke(text)
        }
        return GeneratedAudio(
            samples = objArray[0] as FloatArray,
            sampleRate = objArray[1] as Int
        )
    }

    fun allocate(assetManager: AssetManager? = null) {
        if (ptr == 0L) {
            if (assetManager != null) {
                ptr = new(assetManager, config)
            } else {
                ptr = newFromFile(config)
            }
        }
    }

    fun free() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    protected fun finalize() {
        delete(ptr)
    }

    private external fun new(
        assetManager: AssetManager,
        config: OfflineTtsConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineTtsConfig,
    ): Long

    private external fun delete(ptr: Long)
    private external fun getSampleRate(ptr: Long): Int
    private external fun getNumSpeakers(ptr: Long): Int

    // The returned array has two entries:
    //  - the first entry is an 1-D float array containing audio samples.
    //    Each sample is normalized to the range [-1, 1]
    //  - the second entry is the sample rate
    external fun generateImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f
    ): Array<Any>

    external fun generateWithCallbackImpl(
        ptr: Long,
        text: String,
        sid: Int = 0,
        speed: Float = 1.0f,
        callback: (samples: FloatArray) -> Unit
    ): Array<Any>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }

}

// please refer to
// https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html
// to download models
fun getOfflineTtsConfig(
    modelDir: String,
    modelName: String,
    lexicon: String,
    dataDir: String,
    ruleFsts: String
): OfflineTtsConfig? {
    return OfflineTtsConfig(
        model = OfflineTtsModelConfig(
            vits = OfflineTtsVitsModelConfig(
                model = "$modelDir/$modelName",
                lexicon = "$modelDir/$lexicon",
                tokens = "$modelDir/tokens.txt",
                dataDir = dataDir
            ),
            numThreads = 2,
            debug = true,
            provider = "cpu",
        ),
        ruleFsts = ruleFsts,
    )
}
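
For reference, a usage sketch of the modified class above; the model directory and file names are placeholders (see the pretrained-models page), and on Android the config would normally come from initTts() as in the demo below:

// Usage sketch only; paths and model names are placeholders.
val config = getOfflineTtsConfig(
    modelDir = "vits-piper-en_US-amy-low",
    modelName = "en_US-amy-low.onnx",
    lexicon = "",
    dataDir = "vits-piper-en_US-amy-low/espeak-ng-data",
    ruleFsts = ""
)!!

val tts = OfflineTts(
    config = config,
    // Invoked from generateWithCallback with the text being spoken.
    highlightCallback = { spoken -> Log.i("sherpa-onnx", "speaking: $spoken") }
)

val audio = tts.generateWithCallback(text = "Hello world.", sid = 0, speed = 1.0f) { samples ->
    // Chunks of samples arrive here, typically faster than real time.
}
audio.save("generated.wav")

Note that, as written, highlightCallback receives the whole input text on every chunk rather than the current sentence, so by itself it cannot drive per-sentence highlighting.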

I added a fake highlight based on the audio duration by making some changes:

package com.k2fsa.sherpa.onnx

import android.Manifest
import android.content.Intent
import android.content.pm.PackageManager
import android.content.res.AssetManager
import android.graphics.Color
import android.media.*
import android.net.Uri
import android.os.AsyncTask
import android.os.Bundle
import android.os.Handler
import android.os.Looper
import android.text.Spannable
import android.text.SpannableString
import android.text.style.BackgroundColorSpan
import android.text.style.ForegroundColorSpan
import android.util.Log
import android.widget.Button
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.core.content.ContextCompat
import java.io.File
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.IOException
import java.util.Timer
import java.util.TimerTask

const val TAG = "sherpa-onnx"

class MainActivity : AppCompatActivity() {
    private lateinit var tts: OfflineTts
    private lateinit var text: EditText
    private lateinit var sid: EditText
    private lateinit var speed: EditText
    private lateinit var generate: Button
    private lateinit var play: Button
    private lateinit var export: Button
    private lateinit var useTtsButton: Button
    private lateinit var textView: TextView

    // see
    // https://developer.android.com/reference/kotlin/android/media/AudioTrack
    private lateinit var track: AudioTrack

    private val requestPermissionLauncher =
        registerForActivityResult(ActivityResultContracts.RequestPermission()) { isGranted: Boolean ->
            if (isGranted) {
                exportFile()
            } else {
                Toast.makeText(
                    this,
                    "Permission denied, cannot export file.",
                    Toast.LENGTH_SHORT
                ).show()
            }
        }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        Log.i(TAG, "Start to initialize TTS")
        initTts()
        Log.i(TAG, "Finish initializing TTS")

        Log.i(TAG, "Start to initialize AudioTrack")
        initAudioTrack()
        Log.i(TAG, "Finish initializing AudioTrack")

        text = findViewById(R.id.text)
        sid = findViewById(R.id.sid)
        speed = findViewById(R.id.speed)
        export = findViewById(R.id.export)
        textView = findViewById(R.id.textView)
        generate = findViewById(R.id.generate)
        play = findViewById(R.id.play)

        generate.setOnClickListener { onClickGenerate() }
        play.setOnClickListener { onClickPlay() }

        useTtsButton = findViewById(R.id.use_tts_button)
        useTtsButton.setOnClickListener { onClickUseTts() }

        export.setOnClickListener {
            if (ContextCompat.checkSelfPermission(
                    this,
                    Manifest.permission.WRITE_EXTERNAL_STORAGE
                ) != PackageManager.PERMISSION_GRANTED
            ) {
                // Permission is not granted, request it
                requestPermissionLauncher.launch(Manifest.permission.WRITE_EXTERNAL_STORAGE)
            } else {
                // Permission is granted, export the file
                exportFile()
            }
        }

        sid.setText("0")
        speed.setText("1.0")

        // we will change sampleText here in the CI
        val sampleText = ""
        text.setText(sampleText)

        play.isEnabled = false
    }

    private fun onClickUseTts() {
        // Retrieve text from EditText or use a default text
        val textToSpeak = text.text.toString().trim()
        // NOTE: stub; textToSpeak is retrieved but not used yet

    }

    private fun initAudioTrack() {
        val sampleRate = tts.sampleRate()
        val bufLength = AudioTrack.getMinBufferSize(
            sampleRate,
            AudioFormat.CHANNEL_OUT_MONO,
            AudioFormat.ENCODING_PCM_FLOAT
        )
        Log.i(TAG, "sampleRate: ${sampleRate}, buffLength: ${bufLength}")

        val attr = AudioAttributes.Builder().setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
            .setUsage(AudioAttributes.USAGE_MEDIA)
            .build()

        val format = AudioFormat.Builder()
            .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
            .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
            .setSampleRate(sampleRate)
            .build()

        track = AudioTrack(
            attr, format, bufLength, AudioTrack.MODE_STREAM,
            AudioManager.AUDIO_SESSION_ID_GENERATE
        )
        track.play()
    }

    // this function is called from C++
    private fun callback(samples: FloatArray) {
        track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
    }

    private fun onClickGenerate() {
        val sidInt = sid.text.toString().toIntOrNull()
        if (sidInt == null || sidInt < 0) {
            Toast.makeText(
                applicationContext,
                "Please input a non-negative integer for speaker ID!",
                Toast.LENGTH_SHORT
            ).show()
            return
        }

        val speedFloat = speed.text.toString().toFloatOrNull()
        if (speedFloat == null || speedFloat <= 0) {
            Toast.makeText(
                applicationContext,
                "Please input a positive number for speech speed!",
                Toast.LENGTH_SHORT
            ).show()
            return
        }

        val textStr = text.text.toString().trim()
        if (textStr.isBlank()) {
            Toast.makeText(
                applicationContext,
                "Please input a non-empty text!",
                Toast.LENGTH_SHORT
            ).show()
            return
        }

        track.pause()
        track.flush()

        // Extract metadata from the input text
        val metadata = "Text: $textStr"

        play.isEnabled = false
        export.isEnabled = false
        Thread {
            val audio = tts.generateWithCallback(
                text = metadata, // Pass metadata instead of original text
                sid = sidInt,
                speed = speedFloat,
                callback = this::callback
            )

            val filename = application.filesDir.absolutePath + "/generated.wav"
            val ok = audio.samples.size > 0 && audio.save(filename)
            if (ok) {
                runOnUiThread {
                    play.isEnabled = true
                    export.isEnabled = true
                    track.stop()
                }
            }
        }.start()
    }

    private fun onClickPlay() {
        val spokenText = text.text.toString()
        val filename = application.filesDir.absolutePath + "/generated.wav"
        val mediaPlayer = MediaPlayer.create(applicationContext, Uri.fromFile(File(filename)))

        // Get the duration of the audio in milliseconds
        val audioDuration = mediaPlayer.duration

        // Calculate the total number of words in the spoken text
        val totalWords = spokenText.split("\\s+".toRegex()).size

        // Calculate the delay between highlighting updates based on the audio duration and total words
        val highlightDelay = if (totalWords > 0) {
            (audioDuration / totalWords).toLong()
        } else {
            // Default to a small delay if there are no words
            100L
        }

        // Start playing the audio
        mediaPlayer.start()

        // Start highlighting from the beginning of the text
        var highlightStart = 0

        // Handler to schedule highlighting updates
        val handler = Handler(Looper.getMainLooper()) // the no-arg Handler() constructor is deprecated

        // Runnable for highlighting updates
        val highlightRunnable = object : Runnable {
            override fun run() {
                if (highlightStart < spokenText.length) {
                    // Find the end index of the current word
                    var highlightEnd = highlightStart
                    while (highlightEnd < spokenText.length && spokenText[highlightEnd] != ' ') {
                        highlightEnd++
                    }
                    // Apply highlighting to the current word
                    val spannableString = SpannableString(spokenText)
                    spannableString.setSpan(
                        BackgroundColorSpan(Color.YELLOW), // Set background color for highlighting
                        highlightStart,
                        highlightEnd,
                        Spannable.SPAN_INCLUSIVE_EXCLUSIVE
                    )
                    text.setText(spannableString, TextView.BufferType.SPANNABLE) // Update text with highlighting

                    // Move to the next word
                    highlightStart = highlightEnd + 1 // Add 1 for the space between words
                }

                // Schedule the next highlighting update after the calculated delay
                handler.postDelayed(this, highlightDelay)
            }
        }

        // Start the highlighting process
        handler.post(highlightRunnable)

        // Stop the highlighting process when the audio playback completes
        mediaPlayer.setOnCompletionListener {
            handler.removeCallbacks(highlightRunnable)
            // Optionally, you can reset the text view's appearance after highlighting completes

        }
    }

    private fun exportFile() {
        // Replace this with your actual filename
        val filename = "generated.wav"
        val file = File(filesDir, filename)

        // Check if the file exists
        if (!file.exists()) {
            Toast.makeText(this, "File does not exist.", Toast.LENGTH_SHORT).show()
            return
        }

        val intent = Intent(Intent.ACTION_CREATE_DOCUMENT).apply {
            addCategory(Intent.CATEGORY_OPENABLE)
            type = "audio/wav"
            putExtra(Intent.EXTRA_TITLE, "exported.wav")
        }

        exportActivityResult.launch(intent)
    }

    private val exportActivityResult =
        registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { result ->
            if (result.resultCode == RESULT_OK) {
                val intentData = result.data
                intentData?.let { data ->
                    data.data?.let { uri ->
                        // Start a background task to export the file
                        ExportTask(uri).execute()
                    }
                }
            }
        }

    private inner class ExportTask(private val uri: Uri) :
        AsyncTask<Void, Void, Boolean>() {
        @Suppress("DEPRECATION")
        override fun doInBackground(vararg params: Void?): Boolean {
            val file = File(filesDir, "generated.wav")
            return try {
                val inputStream = FileInputStream(file)
                contentResolver.openOutputStream(uri)?.use { outputStream ->
                    inputStream.copyTo(outputStream)
                }
                true
            } catch (e: Exception) {
                e.printStackTrace()
                false
            }
        }
        @Suppress("DEPRECATION")
        override fun onPostExecute(result: Boolean) {
            super.onPostExecute(result)
            if (result) {
                Toast.makeText(
                    applicationContext,
                    "File exported successfully.",
                    Toast.LENGTH_SHORT
                ).show()
            } else {
                Toast.makeText(
                    applicationContext,
                    "Failed to export file.",
                    Toast.LENGTH_SHORT
                ).show()
            }
        }
    }
    private fun initTts() {
        var modelDir: String?
        var modelName: String?
        var ruleFsts: String?
        var lexicon: String?
        var dataDir: String?
        var assets: AssetManager? = application.assets

        // The purpose of such a design is to make the CI test easier
        // Please see
        // https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/apk/generate-tts-apk-script.py
        modelDir = null
        modelName = null
        ruleFsts = null
        lexicon = null
        dataDir = null

        // Example 1:
        // modelDir = "vits-vctk"
        // modelName = "vits-vctk.onnx"
        // lexicon = "lexicon.txt"

        // Example 2:
        // https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models
        // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
        modelDir = "vits-piper-en_US-kusal-medium"
        modelName = "vits-piper-en_US-kusal.onnx"
        dataDir = "vits-piper-en_US-kusal-medium/espeak-ng-data"

        // Example 3:
        // modelDir = "vits-zh-aishell3"
        // modelName = "vits-aishell3.onnx"
        // ruleFsts = "vits-zh-aishell3/rule.fst"
        // lexicon = "lexicon.txt"

        if (dataDir != null) {
            val newDir = copyDataDir(modelDir)
            modelDir = newDir + "/" + modelDir
            dataDir = newDir + "/" + dataDir
            assets = null
        }

        val config = getOfflineTtsConfig(
            modelDir = modelDir!!, modelName = modelName!!, lexicon = lexicon ?: "",
            dataDir = dataDir ?: "",
            ruleFsts = ruleFsts ?: ""
        )!!

        tts = OfflineTts(assetManager = assets, config = config)

    }

    // Note: this overload is not invoked from C++; only the single-argument
    // callback above is wired to the JNI layer.
    private fun callback(samples: FloatArray, highlightedText: String) {
        // Update the TextView with the highlighted text
        runOnUiThread {
            textView.text = highlightedText // Update the TextView with the highlighted text
            // For example, set the text color of the TextView to indicate highlighting
            textView.setTextColor(Color.RED) // Set text color to red for highlighting
        }
        track.write(samples, 0, samples.size, AudioTrack.WRITE_BLOCKING)
    }

    private fun copyDataDir(dataDir: String): String {
        println("data dir is $dataDir")
        copyAssets(dataDir)

        val newDataDir = application.getExternalFilesDir(null)!!.absolutePath
        println("newDataDir: $newDataDir")
        return newDataDir
    }

    private fun copyAssets(path: String) {
        val assets: Array<String>?
        try {
            assets = application.assets.list(path)
            if (assets!!.isEmpty()) {
                copyFile(path)
            } else {
                val fullPath = "${application.getExternalFilesDir(null)}/$path"
                val dir = File(fullPath)
                dir.mkdirs()
                for (asset in assets.iterator()) {
                    val p: String = if (path == "") "" else path + "/"
                    copyAssets(p + asset)
                }
            }
        } catch (ex: IOException) {
            Log.e(TAG, "Failed to copy $path. ${ex.toString()}")
        }
    }

    private fun copyFile(filename: String) {
        try {
            val istream = application.assets.open(filename)
            val newFilename = application.getExternalFilesDir(null).toString() + "/" + filename
            val ostream = FileOutputStream(newFilename)
            // Log.i(TAG, "Copying $filename to $newFilename")
            val buffer = ByteArray(1024)
            var read = istream.read(buffer)
            while (read != -1) {
                ostream.write(buffer, 0, read)
                read = istream.read(buffer)
            }
            istream.close()
            ostream.flush()
            ostream.close()
        } catch (ex: Exception) {
            Log.e(TAG, "Failed to copy $filename, ${ex.toString()}")
        }
    }
}
studionexus-lk commented 9 months ago

The onClickPlay code above gives a fake text highlight in the edit text based on the length of the audio. Any ideas on how to implement a non-fake highlighter? If so, let me know your suggestions.

studionexus-lk commented 9 months ago

https://github.com/k2-fsa/sherpa-onnx/assets/105704203/bfd451f5-4a5d-470c-b60d-ac30e797c049

csukuangfj commented 9 months ago

any ideas on how to implement a non-fake highlighter

I suggest highlighting the sentence that is being synthesized instead of a word.

The callback is called whenever a sentence is finished synthesizing.

Since the synthesis speed is faster than real time, i.e., faster than the playback speed, you may need to save the samples obtained from the callback in a FIFO.

If the FIFO is not empty, you can dequeue samples from it, play them, and at the same time highlight the sentence corresponding to the samples currently being played.

You need to track which samples in the FIFO correspond to which sentence.

For piper VITS models, the code to break text into sentences is https://github.com/k2-fsa/sherpa-onnx/blob/763a51486ed1aa58c3808e9438832e6cb3c510d2/sherpa-onnx/csrc/piper-phonemize-lexicon.cc#L232

(You can assume it uses ., ?, ! to break text into sentences).
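
A sketch of this producer/consumer scheme (illustrative only: `tts`, `track`, `highlight`, and `clearHighlight` are assumed from the code above, and each sentence is synthesized with `generate` so the sentence-to-samples mapping stays explicit):

import java.util.concurrent.LinkedBlockingQueue

// One FIFO entry: a sentence together with the samples synthesized for it.
data class Chunk(val sentence: String, val samples: FloatArray)

fun speakWithFifo(fullText: String) {
    val fifo = LinkedBlockingQueue<Chunk>()
    // Split on ., ?, ! as the piper lexicon code does.
    val sentences = fullText.split(Regex("(?<=[.?!])\\s+")).filter { it.isNotBlank() }

    // Producer: synthesis is faster than real time, so it stays ahead.
    Thread {
        for (s in sentences) {
            fifo.put(Chunk(s, tts.generate(text = s, sid = 0, speed = 1.0f).samples))
        }
        fifo.put(Chunk("", FloatArray(0))) // end-of-stream marker
    }.start()

    // Consumer: dequeue a chunk, highlight its sentence, play its samples.
    // The blocking write paces this loop at real-time playback speed.
    Thread {
        while (true) {
            val chunk = fifo.take()
            if (chunk.samples.isEmpty()) break
            runOnUiThread { highlight(chunk.sentence) }
            track.write(chunk.samples, 0, chunk.samples.size, AudioTrack.WRITE_BLOCKING)
            runOnUiThread { clearHighlight(chunk.sentence) }
        }
    }.start()
}

To highlight at the word level instead, split on whitespace and treat each word as a "sentence", as suggested below.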


If you want to highlight at the word level, you can pass a single word instead of the whole text for synthesis and treat each word as a sentence.

studionexus-lk commented 9 months ago

I tried to implement those but failed; I researched on ChatGPT and found no solutions. Can someone help me with this?

csukuangfj commented 8 months ago

@studionexus-lk

I just learned it is possible to locate where each input word is in the returned wave. It requires changing the onnx export code and the C++ API to return such information.

GeorgeS2019 commented 7 months ago

@csukuangfj

Any update?

requires changing the onnx export code and the C++ API to return such information.

studionexus-lk commented 3 months ago

@studionexus-lk

I just learned it is possible to locate where each input word is in the returned wave. It requires changing the onnx export code and the C++ API to return such information.

Where can I find the export code and the C++ API?

studionexus-lk commented 1 month ago

@csukuangfj

Any update?

requires changing the onnx export code and the C++ API to return such information.

🥲 No

csukuangfj commented 1 month ago

@studionexus-lk I just learned it is possible to locate where each input word is in the returned wave. It requires changing the onnx export code and the C++ API to return such information.

Where can I find the export code and the C++ API?

You can wait for our next version of TTS models, which can output the duration of each token.
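
For reference, once a model can report per-token durations, highlight timing becomes a lookup instead of a guess. A hypothetical sketch (the durations array, hop size, and their availability are assumptions about the future API, not current sherpa-onnx features):

// Hypothetical: durations[i] is the number of acoustic frames the model
// spent on token i; hopSize and sampleRate are model properties.
fun tokenStartTimesMs(durations: IntArray, hopSize: Int, sampleRate: Int): LongArray {
    val startMs = LongArray(durations.size)
    var frames = 0L
    for (i in durations.indices) {
        startMs[i] = frames * hopSize * 1000L / sampleRate
        frames += durations[i]
    }
    return startMs
}

Each token's highlight could then be scheduled relative to playback start, e.g. with Handler.postDelayed at the corresponding startMs.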

studionexus-lk commented 1 month ago

Any updates?

csukuangfj commented 1 month ago

Sorry, not yet.

studionexus-lk commented 2 weeks ago

@studionexus-lk I just learned it is possible to locate where each input word is in the returned wave. It requires changing the onnx export code and the C++ API to return such information.

Where can I find the export code and the C++ API?

You can wait for our next version of TTS models, which can output the duration of each token.

Umm, is this available now?