huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Getting time offsets of beginning and end of each word in Wav2Vec2 #11307

Open theainerd opened 3 years ago

theainerd commented 3 years ago

🚀 Feature request

Hello, I was thinking it would be of great help if I could get the time offsets of the start and end of each word.

Motivation

I was going through the Google Speech-to-Text documentation and found this feature, and I thought it would be really amazing to have something similar here.

Your contribution

I could really use some help with this task and would love to implement something similar.

theainerd commented 3 years ago

@patrickvonplaten @patil-suraj @sgugger

patrickvonplaten commented 3 years ago

This sounds like a nice feature, but I sadly won't have time to work on it - let's see if someone in the community could be interested :-)

theainerd commented 3 years ago

There is something like this which may help: https://github.com/lumaku/espnet/blob/espnet2_ctc_segmentation/espnet2/bin/asr_align.py

I need some help integrating it with Wav2Vec2 in Hugging Face.

Muktan commented 3 years ago

@theainerd are you working on this feature?

MerryOscar commented 3 years ago

I would also really like to see this feature.

@theainerd I'd be happy to help in any way I can, although I'm not too familiar with the Wav2Vec2 model.

@patrickvonplaten do you think you could write out a brief outline of what you think the steps required would be?

yushao2 commented 3 years ago

Hi there!

I'm very new to collaborating on open-source projects, as well as to using huggingface/transformers in general, so I'm not confident I can come up with a solution for this issue. However, I did some poking around in tutorials on Wav2Vec2 and was thinking about how this might be implemented:

It seems like Wav2Vec2FeatureExtractor does most of the heavy lifting of converting the raw audio array to suitable input values.

-> These input values are then fed into the model to obtain the logits (the output dimension is observed to drop considerably here).

-> After applying argmax to obtain the IDs, these IDs are fed back into the Wav2Vec2CTCTokenizer decode/batch_decode function to obtain the transcription.

Perhaps the sampling-rate information should be stored within the tokenizer class, so that during decode it can use it to determine the timestamps? Or it might be possible to store it within the Wav2Vec2Processor class and have some wrapper functions take care of determining the timestamps and including them during the decode step.

A relation between the input values' dimensions and the output logits' dimensions would be needed for this, which I don't have the expertise to figure out at the moment (a rough sketch of the pipeline is below).
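For concreteness, here is a minimal, untested sketch of the pipeline described above (the checkpoint name and audio path are just placeholders), printing the shapes to see how the input length maps to the number of logit frames:

import torch
import soundfile as sf
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

model_name = "facebook/wav2vec2-base-960h"  # example checkpoint
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

speech, sample_rate = sf.read("example.wav")  # hypothetical 16 kHz mono file
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values

with torch.no_grad():
    logits = model(input_values).logits

print(input_values.shape)  # (1, n_samples)
print(logits.shape)        # (1, n_frames, vocab_size) -- n_frames is much smaller than n_samples

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]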

CC: @theainerd @MerryOscar @patrickvonplaten

Sources I've been referring to: https://www.kdnuggets.com/2021/03/speech-text-wav2vec.html (I realise this is outdated; it uses the old tokenizer class, which seems to perform feature extraction as well)

https://huggingface.co/blog/fine-tune-wav2vec2-english

krrishdholakia commented 3 years ago

+1 on this, I'd really appreciate timestamped words as well. Datasets like TIMIT seem to have this info, but I guess that's part of their test set, not an output from the model itself.

krrishdholakia commented 3 years ago

Here's what I've found so far: if the speech length is 480,000 samples, then the input_values length is 480,000 and the logits length is 1499.

This was for a 30 s audio file, where model is a Wav2Vec2ForCTC and processor is a Wav2Vec2Processor:

    input_values = processor(speech, return_tensors="pt").input_values
    logits = model(input_values).logits

yushao2 commented 3 years ago

> Here's what I've found so far: if the speech length is 480,000 samples, then the input_values length is 480,000 and the logits length is 1499. This was for a 30 s audio file.

Thanks for investigating this -- while I think it may be possible to just use the ratio and the sampling rate to derive the timestamps, what I'm afraid of is that this ratio might just be a "magic number" that differs if the configuration of the Wav2Vec2 model varies.

The current ratio from the input_values size to the logits size seems to be around 320.

E.g.: Does the ratio change if the hyperparameters of the model are changed?

Is this ratio constant for varying sizes of audio? (Experiment with WAV clips of different lengths and check the ratio; a quick check is sketched below.)
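To check that second question empirically, here is a quick, untested sketch that feeds dummy audio of different lengths through a model and prints the input-to-logits ratio (the checkpoint name is just an example):

import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

model_name = "facebook/wav2vec2-base-960h"  # example checkpoint
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

for seconds in (1, 5, 30):
    dummy_speech = np.zeros(16_000 * seconds, dtype=np.float32)  # silence at 16 kHz
    input_values = processor(dummy_speech, sampling_rate=16_000, return_tensors="pt").input_values
    with torch.no_grad():
        logits = model(input_values).logits
    print(seconds, input_values.shape[1] / logits.shape[1])  # ratio should stay around 320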

yushao2 commented 3 years ago

Maybe @patrickvonplaten could shed some light on whether we're going in the right direction with this (if it's not too much trouble) 😓 🙏

krrishdholakia commented 3 years ago

> The current ratio from the input_values size to the logits size seems to be around 320.

Hey @yushao2, what ratio are you referring to here? Sorry, not too familiar with audio processing.

krrishdholakia commented 3 years ago

@patrickvonplaten @yushao2 following up on this

yushao2 commented 3 years ago

> @patrickvonplaten @yushao2 following up on this

Hi there! Sorry for not being responsive here.

The ratio here refers to the number you get when you divide the size of input_values by the size of logits.

In this case, you mentioned

input_values length - 480,000, logits length - 1499

so the ratio would be 480000/1499, which is approximately 320.
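For what it's worth, that 320 isn't arbitrary for the standard architecture: it is the product of the feature encoder's convolutional strides (5·2·2·2·2·2·2 = 320), i.e. roughly one logit frame per 320 input samples, or 20 ms at 16 kHz, up to edge effects from the conv kernels. A minimal sketch (the checkpoint name is just an example) that reads it from the config instead of hard-coding it:

import numpy as np
from transformers import Wav2Vec2Config

config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base-960h")  # example checkpoint
samples_per_frame = int(np.prod(config.conv_stride))   # 320 for the default config
seconds_per_frame = samples_per_frame / 16_000         # 0.02 s per logit frame, assuming 16 kHz audio
print(samples_per_frame, seconds_per_frame)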

theainerd commented 3 years ago

Hello all,

There is something I have found which may serve as a good starting point. Basically, it returns the time offsets along with the textual data:

https://github.com/lumaku/ctc-segmentation


import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re

from ctc_segmentation import ctc_segmentation
from ctc_segmentation import CtcSegmentationParameters
from ctc_segmentation import determine_utterance_segments
from ctc_segmentation import prepare_text

# Get the Wav2Vec2 model and the predicted text
test_dataset = load_dataset("common_voice", "hi", split="test")
wer = load_metric("wer")

processor = Wav2Vec2Processor.from_pretrained("theainerd/Wav2Vec2-large-xlsr-hindi")
model = Wav2Vec2ForCTC.from_pretrained("theainerd/Wav2Vec2-large-xlsr-hindi")
model.to("cuda")

chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“]'

resampler = torchaudio.transforms.Resample(48_000, 16_000)

# Preprocessing the datasets.
# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
  batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
  speech_array, sampling_rate = torchaudio.load(batch["path"])
  batch["speech"] = resampler(speech_array).squeeze().numpy()
  return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

input_values = processor(test_dataset["speech"][0], return_tensors="pt").input_values  # Batch size 1
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])

softmax = torch.nn.Softmax(dim = -1)

# apply configuration
config = CtcSegmentationParameters()

with torch.no_grad():
    # Apply ctc layer to obtain log character probabilities
    lpz = softmax(logits)[0].cpu().numpy()

char_dict = {"न": 0, "च": 1, "थ": 2, "ी": 3, "ऐ": 4, "ृ": 5, "ध": 6, "य": 7, "ह": 8, "ऊ": 9, "म": 10, "ण": 11, "ै": 13, "ौ": 14, "ा": 15, "ल": 16, "त": 17, "इ": 18, "ढ़": 19, "ष": 20, "भ": 21, "ग़": 22, "ख": 23, "ड़": 24, "ए": 25, "व": 26, "ु": 27, "ओ": 28, "र": 29, "श": 30, "औ": 31, "ट": 32, "आ": 33, "ो": 34, "ढ": 35, "झ": 36, "ग": 37, "ज़": 38, "अ": 39, "े": 40, "प": 41, "घ": 42, "द": 43, "ई": 44, "फ़": 45, "ब": 46, "ड": 47, "ँ": 48, "छ": 49, "ू": 50, "फ": 51, "ि": 52, "स": 53, "्": 54, "क": 55, "उ": 56, "ठ": 57, "ं": 58, "़": 59, "ज": 60, "क़": 61, "|": 12, "[UNK]": 62, "[PAD]": 63}
char_list = list(char_dict.keys())

# Prepare the text for aligning
ground_truth_mat, utt_begin_indices = prepare_text(config, transcription, char_list)
# Align using CTC segmentation
timings, char_probs, state_list = ctc_segmentation(config, lpz, ground_truth_mat)

# Obtain list of utterances with time intervals and confidence score
segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, transcription)
# Sample output: 0.26 1.73 -0.0154 THE SALE OF THE HOTELS  (an example taken from the ctc_segmentation repo)

Now I have the time offsets, but how do I get them for each and every word rather than for segments? Please don't take this as an absolute solution, as I am not sure this is a good direction to go, but something is better than nothing. Please share your thoughts (one possible per-word variant is sketched below).
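One possible way to get per-word rather than per-segment offsets with the same library might be to split the transcription into words and align each word as its own utterance. This is a hedged sketch, not verified against the library; it assumes prepare_text / determine_utterance_segments accept a list of utterance strings, and it reuses config, lpz, char_list and transcription from the snippet above:

# split the predicted transcription into words and treat each word as one "utterance"
words = transcription.split()

ground_truth_mat, utt_begin_indices = prepare_text(config, words, char_list)
timings, char_probs, state_list = ctc_segmentation(config, lpz, ground_truth_mat)
segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, words)

for word, (start, end, score) in zip(words, segments):
    print(f"{start:.2f}s - {end:.2f}s  (score {score:.4f})  {word}")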

KB-g commented 3 years ago

Hi everyone, here is some sample code I created to get the word-level start and end timestamps. It's surely a bit hacky, and I could imagine some special cases where it might break, but for the cases I have tried it worked great.

from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf

##############
# load model & audio and run audio through model
##############
model_name = 'facebook/wav2vec2-large-960h-lv60-self'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).cuda()

audio_filepath = ''
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.cuda()

logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()

##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate

ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
# remove entries which are just "padding" (i.e. no characters are recognized)
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
# now split the ids into groups of ids where each group represents a word
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]

assert len(split_ids_w_time) == len(words)  # make sure that there are the same number of id-groups as words. Otherwise something is wrong

word_start_times = []
word_end_times = []
for cur_ids_w_time, cur_word in zip(split_ids_w_time, words):
    _times = [_time for _time, _id in cur_ids_w_time]
    word_start_times.append(min(_times))
    word_end_times.append(max(_times))

words, word_start_times, word_end_times

doublex commented 3 years ago

@KB-g Congrats! Is there a chance to also extract the "per word probability"?

doublex commented 3 years ago

@KB-g The assert len() == len() triggers with this audio: assert.zip. Test case:

from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf

model_name = 'DewiBrynJones/wav2vec2-large-xlsr-welsh'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

audio_filepath = '/tmp/assert.wav'
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()

##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate
ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]
# make sure that there are the same number of id-groups as words. Otherwise something is wrong
assert len(split_ids_w_time) == len(words), (len(split_ids_w_time), len(words))

abhirooptalasila commented 2 years ago

> @KB-g Congrats! Is there a chance to also extract the "per word probability"?

Hey @KB-g Any success on this?

KB-g commented 2 years ago

Hi @doublex , @abhirooptalasila, I haven't tried to get the per-word probability. If you come up with a solution, it would be great if you could let me know. I'd also be interested in a solution :)

jcsilva commented 2 years ago

Hi @KB-g, @doublex and @abhirooptalasila,

maybe this tutorial helps to find a way to calculate a "per-word probability". In the merge_words function, the author calculates scores for each word based on the token probabilities and their durations.
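Building on that idea, here is a rough, untested sketch of one way to approximate a per-word probability with KB-g's snippet above: average the softmax probability of the argmax token over the frames grouped into each word. It assumes logits, predicted_ids (the Python list form), processor and groupby from that snippet are still in scope:

probs = torch.softmax(logits, dim=-1)[0]        # (n_frames, vocab_size)
best_probs = probs.max(dim=-1).values.tolist()  # probability of the argmax token per frame

# group frame indices by word, mirroring the word-splitting logic above
ids_w_index = [(i, _id) for i, _id in enumerate(predicted_ids)]
ids_w_index = [x for x in ids_w_index if x[1] != processor.tokenizer.pad_token_id]
word_groups = [list(g) for k, g
               in groupby(ids_w_index, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
               if not k]

# mean frame-level probability per word (a heuristic confidence, not a true word posterior)
word_probs = [sum(best_probs[i] for i, _ in g) / len(g) for g in word_groups]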

patrickvonplaten commented 2 years ago

We need to document the time stamp retrieval a bit better here, I think.
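For reference, more recent releases of transformers expose word offsets directly from the CTC tokenizer via output_word_offsets=True in decode/batch_decode (worth checking the docs of your installed version). A rough sketch, assuming a recent release and that logits, processor and model are defined as in the snippets above, with 16 kHz input audio:

import numpy as np
import torch

pred_ids = torch.argmax(logits, dim=-1)[0]
outputs = processor.tokenizer.decode(pred_ids, output_word_offsets=True)

# word offsets are expressed in logit frames; convert to seconds via the
# feature-encoder stride (assumes 16 kHz input audio)
time_per_frame = np.prod(model.config.conv_stride) / 16_000

word_times = [
    {
        "word": w["word"],
        "start": round(w["start_offset"] * time_per_frame, 2),
        "end": round(w["end_offset"] * time_per_frame, 2),
    }
    for w in outputs.word_offsets
]
print(word_times)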

Ap1075 commented 2 years ago

@KB-g Thanks for the code snippet, really useful. I made a small addition (a no_grad context) for inference, which should help people facing OOM errors:

from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf

##############
# load model & audio and run audio through model
##############
model_name = 'facebook/wav2vec2-large-960h-lv60-self'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).cuda()

audio_filepath = ''
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.cuda()

with torch.no_grad():
    logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()

##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate

ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
# remove entries which are just "padding" (i.e. no characters are recognized)
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
# now split the ids into groups of ids where each group represents a word
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]

assert len(split_ids_w_time) == len(words)  # make sure that there are the same number of id-groups as words. Otherwise something is wrong

word_start_times = []
word_end_times = []
for cur_ids_w_time, cur_word in zip(split_ids_w_time, words):
    _times = [_time for _time, _id in cur_ids_w_time]
    word_start_times.append(min(_times))
    word_end_times.append(max(_times))

words, word_start_times, word_end_times

samuelbradshaw commented 1 year ago

@Ap1075, thank you for the example you provided above. I'm having a hard time figuring out where/how to pass in transcribed text so it can be aligned with the audio. Is passing in pre-transcribed text possible, or am I misunderstanding how it works?

jmealo commented 1 year ago

I'm trying to get word timings for karaoke; I have the lyrics... Would this be possible? 🤔

hegdeadithyak commented 6 months ago

Hi there @patrickvonplaten ,

I'd like to take a look at this issue and see if I can help fix it. Please let me know if it's already assigned to someone or if there's anything specific I should keep in mind while working on it.

Thanks, Adithya Hegde Kota