pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Bug in the ctc_forced_alignment_api_tutorial #3526

Closed · eyalcohen308 closed this issue 1 year ago

eyalcohen308 commented 1 year ago

🐛 Bug description

The merge_words function in the ctc_forced_alignment_api_tutorial contains a bug: if the last word is a single character (for example, the token "*" or the word "I"), that word is dropped from the output. This happens, for example, when the transcript is "YOU|AND|I".

# Obtain word alignments from token alignments
def merge_words(transcript, segments, separator=" "):
    words = []
    i1, i2, i3 = 0, 0, 0
    while i3 < len(transcript):
        if i3 == len(transcript) - 1 or transcript[i3] == separator:
            if i1 != i2:
                if i3 == len(transcript) - 1:
                    i2 += 1
                if separator == "|":
                    # s is the number of separators (counted as a valid modeling unit) we've seen
                    s = len(words)
                else:
                    s = 0
                segs = segments[i1 + s : i2 + s]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs)
                words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
            i1 = i2
        else:
            i2 += 1
        i3 += 1
    return words
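Tracing the loop on "I|HAD|*" shows where the last word is lost: the `i2 += 1` that accounts for the final character only runs inside the `if i1 != i2` branch. When the last word has a single character, `i1 == i2` on reaching it, so the whole branch is skipped and the word is dropped. A minimal patch, sketched below, advances `i2` for a final non-separator character before the empty-word check. The `Segment` dataclass here is a hypothetical stand-in for the tutorial's class, `merge_words_patched` is a hypothetical name, and the frame numbers in the usage lines are synthetic.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Segment:
    # Hypothetical stand-in for the tutorial's Segment: label, [start, end) frames, score
    label: str
    start: int
    end: int
    score: float

    def __len__(self) -> int:
        return self.end - self.start


def merge_words_patched(transcript: str, segments: List[Segment], separator: str = " ") -> List[Segment]:
    words = []
    i1, i2, i3 = 0, 0, 0
    while i3 < len(transcript):
        is_last = i3 == len(transcript) - 1
        if is_last and transcript[i3] != separator:
            # Count the final character *before* the empty-word check,
            # so a one-character last word (e.g. "*" or "I") is not skipped.
            i2 += 1
        if is_last or transcript[i3] == separator:
            if i1 != i2:
                # With "|", separators have their own segments; skip the ones already passed.
                s = len(words) if separator == "|" else 0
                segs = segments[i1 + s : i2 + s]
                word = "".join(seg.label for seg in segs)
                score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs)
                words.append(Segment(word, segs[0].start, segs[-1].end, score))
            i1 = i2
        else:
            i2 += 1
        i3 += 1
    return words


transcript = "I|HAD|*"
# Synthetic one-frame segments, one per character (separators included)
segments = [Segment(c, 2 * i, 2 * i + 1, 1.0) for i, c in enumerate(transcript)]
print([w.label for w in merge_words_patched(transcript, segments, "|")])
# ['I', 'HAD', '*']
```

Multi-character last words were already handled by the original loop, because `i1 != i2` holds for them by the time the last character is reached; only the one-character case needed the reordering.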

Input examples:

TRANSCRIPT = "I|HAD|*"
merge_words(TRANSCRIPT, segments, "|")

The output doesn't contain the star token segment '*':

[I  (1.00): [   30,    31), HAD (0.97): [   36,    41)]

TRANSCRIPT = "YOU|AND|I"
merge_words(TRANSCRIPT, segments, "|")

The output doesn't contain the word segment 'I':

[YOU    (0.25): [   30,    61), AND (0.33): [   97,   111)]
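To reproduce the expected behaviour without running a real alignment, one can build synthetic per-character segments and compare against a simple reference merger. The sketch below is hypothetical: the `Segment` dataclass stands in for the tutorial's class, the frame numbers are made up, and `reference_words` uses `itertools.groupby` to split on separator runs, a different technique from the tutorial's index loop. It assumes one segment per transcript character, separators included, as when "|" is a modeling unit.

```python
from dataclasses import dataclass
from itertools import groupby
from typing import List


@dataclass
class Segment:
    # Hypothetical stand-in for the tutorial's Segment class
    label: str
    start: int
    end: int
    score: float

    def __len__(self) -> int:
        return self.end - self.start


def synthetic_segments(transcript: str) -> List[Segment]:
    # One made-up 1-frame segment per character, separators included
    return [Segment(c, 2 * i, 2 * i + 1, 1.0) for i, c in enumerate(transcript)]


def reference_words(transcript: str, segments: List[Segment], separator: str = "|") -> List[Segment]:
    # Assumes segments[k] is the segment for transcript[k]
    if len(transcript) != len(segments):
        raise ValueError("expected one segment per transcript character")
    words = []
    pairs = zip(transcript, segments)
    for is_sep, group in groupby(pairs, key=lambda pair: pair[0] == separator):
        if is_sep:
            continue  # skip separator runs; the next group is a new word
        segs = [seg for _, seg in group]
        total = sum(len(seg) for seg in segs)
        score = sum(seg.score * len(seg) for seg in segs) / total
        words.append(Segment("".join(seg.label for seg in segs), segs[0].start, segs[-1].end, score))
    return words


for transcript in ["I|HAD|*", "YOU|AND|I"]:
    print([w.label for w in reference_words(transcript, synthetic_segments(transcript))])
# ['I', 'HAD', '*']
# ['YOU', 'AND', 'I']
```

Because every non-separator group becomes a word regardless of its length or position, the one-character last word is kept, which makes this usable as a test oracle for a fix.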

My solution:

from typing import List, Optional

# Segment is the dataclass defined in the tutorial.


def merge_words(
    transcript: str,
    segments: List[Segment],
    separator: Optional[str] = " ",
    include_separator: bool = False,
) -> List[Segment]:
    # Assumes one segment per transcript character, separators included
    # (true when "|" is a modeling unit), so transcript indices index `segments` directly.
    word_segments = []
    separator_idx_lst = [
        idx for idx, char in enumerate(transcript) if char == separator
    ]
    # append the end of the transcript so the last word is also emitted
    separator_idx_lst.append(len(transcript))

    start_idx = 0  # start index for the first word

    for end_idx in separator_idx_lst:
        if start_idx != end_idx:  # skip empty words (e.g. consecutive separators)
            chars_segs_in_word = segments[start_idx:end_idx]
            word_str = "".join(segment.label for segment in chars_segs_in_word)
            total_len = sum(len(segment) for segment in chars_segs_in_word)
            avg_score = sum(seg.score * len(seg) for seg in chars_segs_in_word) / total_len
            word_segments.append(
                Segment(
                    word_str,
                    chars_segs_in_word[0].start,
                    chars_segs_in_word[-1].end,
                    avg_score,
                )
            )
        # include the separator segment if the flag is set and this is not the end of the transcript
        if include_separator and end_idx != len(transcript):
            word_segments.append(segments[end_idx])
        start_idx = end_idx + 1  # update the start index for the next word

    return word_segments

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230719
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.35

Python version: 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-53-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A40
GPU 1: NVIDIA A40
GPU 2: NVIDIA A40
GPU 3: NVIDIA A40
GPU 4: NVIDIA A40
GPU 5: NVIDIA A40
GPU 6: NVIDIA A40
GPU 7: NVIDIA A40

Nvidia driver version: 520.61.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6338 CPU @ 2.00GHz
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 6
CPU max MHz: 3200.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 80 MiB (64 instances)
L3 cache: 96 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.1.0.dev20230719
[pip3] torchaudio==2.1.0.dev20230719
[pip3] torchvision==0.16.0.dev20230719
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] brotlipy 0.7.0 py311h9bf148f_1002 pytorch-nightly
[conda] cffi 1.15.1 py311h9bf148f_3 pytorch-nightly
[conda] cryptography 38.0.4 py311h46ebde7_0 pytorch-nightly
[conda] filelock 3.9.0 py311_0 pytorch-nightly
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py311h9bf148f_0 pytorch-nightly
[conda] mkl_fft 1.3.1 py311hc796f24_0 pytorch-nightly
[conda] mkl_random 1.2.2 py311hbba84a0_0 pytorch-nightly
[conda] mpmath 1.2.1 py311_0 pytorch-nightly
[conda] numpy 1.24.3 py311hc206e33_0
[conda] numpy-base 1.24.3 py311hfd5febd_0
[conda] pillow 9.3.0 py311h3fd9d12_2 pytorch-nightly
[conda] pysocks 1.7.1 py311_0 pytorch-nightly
[conda] pytorch 2.1.0.dev20230719 py3.11_cuda11.8_cudnn8.7.0_0 pytorch-nightly
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] requests 2.28.1 py311_0 pytorch-nightly
[conda] torchaudio 2.1.0.dev20230719 py311_cu118 pytorch-nightly
[conda] torchtriton 2.1.0+9e3e10c5ed py311 pytorch-nightly
[conda] torchvision 0.16.0.dev20230719 py311_cu118 pytorch-nightly
[conda] urllib3 1.26.14 py311_0 pytorch-nightly

mthrok commented 1 year ago

Hi @eyalcohen308

Thanks for checking out the code. I think this is related to the fact that the conventional wav2vec2 model includes the word boundary token "|", while the new MMS model does not. This notebook was originally written for MMS and later changed to use the original wav2vec2 model, and along the way the helper function got somewhat messed up.

I am in the process of rewriting the tutorial and will eliminate the word boundary special case from the function. I will come back soon.
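One way to avoid the separator special case altogether, sketched here as an illustration rather than the rewritten tutorial's actual code: keep the transcript as a list of words, align only the non-separator tokens as one flat sequence, and then split the flat list of per-token results back into words by each word's length. The helper `split_by_lengths` is hypothetical, and plain characters stand in for the per-token spans an aligner would return.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_by_lengths(items: Sequence[T], lengths: Sequence[int]) -> List[List[T]]:
    """Split a flat sequence into consecutive chunks of the given sizes."""
    if sum(lengths) != len(items):
        raise ValueError(f"lengths sum to {sum(lengths)}, but there are {len(items)} items")
    chunks: List[List[T]] = []
    offset = 0
    for n in lengths:
        chunks.append(list(items[offset : offset + n]))
        offset += n
    return chunks


# Words are kept as a list, so no separator token ever enters the alignment.
words = ["I", "HAD", "*"]
flat_tokens = [ch for word in words for ch in word]  # stand-in for per-token spans
print(split_by_lengths(flat_tokens, [len(w) for w in words]))
# [['I'], ['H', 'A', 'D'], ['*']]
```

Because word boundaries come from the word list rather than from a separator token, a one-character final word such as "I" or "*" needs no special handling, and the same code applies whether or not the model has a "|" token.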

mthrok commented 1 year ago

Addressed by #3542.

I introduced the merge_tokens API, which is unit-tested and should not have this issue.
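For context on what token merging does: CTC forced alignment yields one token id per frame, and merging collapses consecutive repeats and drops the blank, so that each remaining run becomes one token span with a start frame, an end frame, and an averaged score. The pure-Python sketch below illustrates that idea only; it is not torchaudio's implementation (the library function is `torchaudio.functional.merge_tokens`), the names `Span` and `merge_frames` are hypothetical, and blank id 0 is an assumption.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Span:
    # Hypothetical token span: token id, [start, end) frame range, mean frame score
    token: int
    start: int
    end: int
    score: float


def merge_frames(tokens: Sequence[int], scores: Sequence[float], blank: int = 0) -> List[Span]:
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    spans: List[Span] = []
    t = 0
    while t < len(tokens):
        end = t + 1
        # Extend the run while the same token id repeats on consecutive frames
        while end < len(tokens) and tokens[end] == tokens[t]:
            end += 1
        if tokens[t] != blank:
            run = scores[t:end]
            spans.append(Span(tokens[t], t, end, sum(run) / len(run)))
        t = end
    return spans


tokens = [0, 5, 5, 0, 0, 7, 0, 7]  # made-up frame-level ids, 0 = blank
scores = [0.9, 0.8, 0.6, 0.9, 0.9, 0.5, 0.9, 0.7]
for span in merge_frames(tokens, scores):
    print(span)  # the two 7s stay separate spans because a blank sits between them
```

Since spans are produced purely from frame runs, every token, including a single-character last word, yields its own span; grouping spans into words is then a separate step.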