phanngoc / english-pronunciation

0 stars 0 forks source link

Research switching to Wav2Vec2ForCTC #1

Open phanngoc opened 2 months ago

phanngoc commented 2 months ago
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import soundfile as sf
import torch
from jiwer import wer

# Evaluate facebook/wav2vec2-large-960h on the LibriSpeech "clean" test split and
# print the corpus-level word error rate.
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on a hard-coded "cuda" device.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h").to(device)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h")


def map_to_pred(batch):
    """Transcribe a batch of examples with the module-level ``model``/``processor``.

    Intended for ``Dataset.map(..., batched=True)``: every column of *batch* is a
    list with one entry per example, and each ``batch["audio"]`` entry is a dict
    holding the decoded waveform (``"array"``) and its ``"sampling_rate"``.

    Decoding is greedy: the arg-max over the last logits dimension is taken and
    ``processor.batch_decode`` turns the resulting ids into text.

    Returns:
        *batch* with an added ``"transcription"`` column (one string per example).
    """
    audio = batch["audio"]
    # In batched mode batch["audio"] is a list of dicts, not a single dict.
    waveforms = [example["array"] for example in audio]
    # Passing the sampling rate lets the feature extractor check it against the
    # rate it expects instead of assuming the input already matches.
    input_values = processor(
        waveforms,
        sampling_rate=audio[0]["sampling_rate"],
        return_tensors="pt",
        padding="longest",
    ).input_values

    with torch.no_grad():
        logits = model(input_values.to(device)).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    batch["transcription"] = processor.batch_decode(predicted_ids)
    return batch


# remove_columns must name a column that exists in the dataset; "audio" is the
# input column consumed by map_to_pred and is not needed after transcription.
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])

# jiwer.wer takes (reference, hypothesis).
print("WER:", wer(result["text"], result["transcription"]))
phanngoc commented 2 months ago
from jiwer import wer, compute_measures
import re

from difflib import SequenceMatcher

# Example reference and hypothesis sentences
reference = "hello world"
hypothesis = "helo world"

# ANSI escape codes used to colour terminal output.
_RED = "\033[31m"
_GREEN = "\033[32m"
_RESET = "\033[0m"


def highlight_text(reference, hypothesis, diffs=None):
    """Merge *reference* and *hypothesis*, colouring their character-level differences.

    The two strings are aligned with ``difflib.SequenceMatcher``. Text the strings
    share is emitted unchanged. Reference text with no counterpart in the
    hypothesis (missing or wrong characters) is wrapped in red. Hypothesis text
    with no counterpart in the reference (what was produced instead) is wrapped
    in green. For a substitution the red reference text comes first, followed by
    the green hypothesis text.

    Args:
        reference: The expected text.
        hypothesis: The text being compared against it.
        diffs: Ignored. Kept only so existing three-argument calls keep working;
            the alignment is computed from the two strings directly.

    Returns:
        A single string containing ANSI colour escape sequences.
    """
    # autojunk=False: the default heuristic treats frequent characters (such as
    # spaces) in inputs of 200+ characters as junk, which distorts alignments.
    matcher = SequenceMatcher(None, reference, hypothesis, autojunk=False)
    pieces = []
    for tag, ref_lo, ref_hi, hyp_lo, hyp_hi in matcher.get_opcodes():
        if tag == "equal":
            pieces.append(reference[ref_lo:ref_hi])
            continue
        # "delete" and "replace" consume reference characters ...
        if ref_hi > ref_lo:
            pieces.append(f"{_RED}{reference[ref_lo:ref_hi]}{_RESET}")
        # ... and "insert" and "replace" consume hypothesis characters.
        if hyp_hi > hyp_lo:
            pieces.append(f"{_GREEN}{hypothesis[hyp_lo:hyp_hi]}{_RESET}")
    return "".join(pieces)


if __name__ == "__main__":
    # Word-level edit counts from jiwer. Their sum is an integer total, not a
    # set of positions, so it is reported rather than passed to highlight_text.
    # NOTE(review): compute_measures is deprecated in newer jiwer releases in
    # favour of process_words; confirm against the installed version.
    measures = compute_measures(reference, hypothesis)
    diffs = measures['substitutions'] + measures['insertions'] + measures['deletions']
    print("Word errors:", diffs)

    highlighted_text = highlight_text(reference, hypothesis)
    print(highlighted_text)