huggingface / speechbox

Apache License 2.0

Restore punctuation for audio not sampled at 16 kHz #4

Closed: bofenghuang closed this issue 1 year ago

bofenghuang commented 1 year ago

Hi @patrickvonplaten 👋,

Thanks for this project!

I think we should support optional audio resampling, since WhisperFeatureExtractor doesn't resample internally.

Below is an updated example. It might be better, though, to handle this inside PunctuationRestorer so that it works out of the box. What's your opinion? I'm willing to make a PR if necessary :)

```python
import string
import re
from datasets import load_dataset
import librosa
from speechbox import PunctuationRestorer

streamed_dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="validation", streaming=True)

# get first sample
sample = next(iter(streamed_dataset))

# print out normalized transcript
print(sample["sentence"])
# => "It is from Westport, above the villages of Murrisk and Lecanvey."
sentence = re.sub(rf"[{re.escape(string.punctuation)}]", "", sample["sentence"]).lower()
print(sentence)
# => "it is from westport above the villages of murrisk and lecanvey"

# load the restoring class
restorer = PunctuationRestorer.from_pretrained("openai/whisper-tiny.en")
restorer.to("cuda")

# resample audio if necessary
model_sample_rate = restorer.processor.feature_extractor.sampling_rate
if sample["audio"]["sampling_rate"] != model_sample_rate:
    sample["audio"]["array"] = librosa.resample(
        sample["audio"]["array"],
        orig_sr=sample["audio"]["sampling_rate"],
        target_sr=model_sample_rate,
        res_type="kaiser_best",
    )

restored_text, log_probs = restorer(sample["audio"]["array"], sentence, sampling_rate=model_sample_rate, num_beams=1)

print("Restored text:\n", restored_text)
# Restored text:
# It is from Westport above the villages of MURRISK and LECANVEY.
```
patrickvonplaten commented 1 year ago

Good idea! Happy to include it directly in the __call__ method of the PunctuationRestorer :-)
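Wiring this into `__call__` mostly means reading the model's expected rate, resampling only when the input rate differs, and then forwarding the *target* rate to the model. A minimal sketch written as a standalone wrapper, so it makes no assumptions about speechbox internals: `ResamplingRestorer` and the `resample_fn(audio, orig_sr, target_sr)` signature are hypothetical. Only the `processor.feature_extractor.sampling_rate` attribute and the `restorer(audio, text, sampling_rate=..., **kwargs)` call shape are taken from the example in the issue.

```python
from typing import Any, Callable

import numpy as np

# Hypothetical resampler signature: (audio, orig_sr, target_sr) -> resampled audio.
ResampleFn = Callable[[np.ndarray, int, int], np.ndarray]


class ResamplingRestorer:
    """Hypothetical wrapper that resamples input audio before calling a restorer."""

    def __init__(self, restorer: Any, resample_fn: ResampleFn):
        self.restorer = restorer
        self.resample_fn = resample_fn

    def __call__(self, audio: np.ndarray, text: str, sampling_rate: int, **kwargs: Any):
        # Same attribute the issue's example reads to find the model's expected rate.
        target_sr = self.restorer.processor.feature_extractor.sampling_rate
        if sampling_rate != target_sr:
            audio = self.resample_fn(audio, sampling_rate, target_sr)
        # Forward the target rate, since the audio is now at that rate.
        return self.restorer(audio, text, sampling_rate=target_sr, **kwargs)
```

Injecting `resample_fn` keeps the choice of librosa, scipy or torchaudio out of the wrapper itself.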

I think adding librosa, scipy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.resample.html), or torchaudio as an optional dependency, as explained here: https://github.com/huggingface/speechbox/blob/main/CONTRIBUTING.md#philosophy, would make a lot of sense (checking the sampling rate exactly like you're doing in the example above :-)).
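Keeping the resampler optional typically means importing it only when resampling is actually needed, and failing with a clear message when it is missing. A minimal sketch of that pattern in plain Python: `is_available` and `require_optional` are hypothetical names, not speechbox helpers, and the install hint assumes the pip package name matches the module name.

```python
import importlib
import importlib.util
from types import ModuleType


def is_available(module_name: str) -> bool:
    """Return True if `module_name` could be imported."""
    try:
        # For dotted names, find_spec imports the parent package first.
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # Raised for dotted names whose parent package is missing.
        return False


def require_optional(module_name: str, purpose: str) -> ModuleType:
    """Import an optional dependency lazily, or fail with an actionable message."""
    if not is_available(module_name):
        raise ImportError(
            f"{purpose} requires the optional dependency '{module_name}'. "
            f"Install it with `pip install {module_name}`."  # assumes pip name == module name
        )
    return importlib.import_module(module_name)
```

Used as `signal = require_optional("scipy.signal", "Resampling")` inside the code path that resamples, the package still imports fine without scipy installed.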

I'd lean slightly towards scipy, as it's pretty lightweight.
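A minimal sketch of the scipy route, assuming a mono waveform given as a 1-D NumPy array. It swaps in `scipy.signal.resample_poly` (polyphase filtering with a built-in anti-aliasing FIR filter) for the FFT-based `scipy.signal.resample` linked above, because integer rate ratios such as 48 kHz to 16 kHz map directly onto its `up`/`down` factors. `maybe_resample` is a hypothetical name.

```python
from math import gcd

import numpy as np
from scipy.signal import resample_poly


def maybe_resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample a 1-D waveform from orig_sr to target_sr, or return it unchanged."""
    if orig_sr == target_sr:
        return audio
    # Reduce the rate ratio, e.g. 48000 -> 16000 becomes up=1, down=3.
    divisor = gcd(orig_sr, target_sr)
    up, down = target_sr // divisor, orig_sr // divisor
    # resample_poly low-pass filters while changing the rate, avoiding aliasing.
    return resample_poly(audio, up, down)
```

With the variables from the issue's example, the librosa block would become `sample["audio"]["array"] = maybe_resample(sample["audio"]["array"], sample["audio"]["sampling_rate"], model_sample_rate)`.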

Would you like to open a PR for it? :-)

bofenghuang commented 1 year ago

Sure. Thanks for the hints :)