csukuangfj / kaldifeat

Kaldi-compatible online & offline feature extraction with PyTorch, supporting CUDA, batch processing, chunk processing, and autograd. Provides C++ & Python APIs.
https://csukuangfj.github.io/kaldifeat
Other
186 stars 35 forks source link

Support computing features for whisper #82

Closed: csukuangfj closed this issue 10 months ago.

csukuangfj commented 10 months ago

Usage


import torchaudio
import torch
import kaldifeat

def compute_features(filename: str) -> torch.Tensor:
    """Compute Whisper-style filterbank features for one audio file.

    Only the first channel of the file is used. Audio that is not already
    sampled at 16 kHz is resampled to 16 kHz before feature extraction.
    The frame axis is then zero-padded (at the end) or truncated so that it
    holds exactly 3000 frames.

    Args:
      filename:
        Path to an audio file readable by ``torchaudio.load``.
    Returns:
      Return a 3-D float32 tensor of shape (1, 80, 3000) containing the
      features, laid out as (batch, feature bins, frames).
    """
    target_rate = 16000
    target_frames = 3000

    waveform, rate = torchaudio.load(filename)
    samples = waveform[0].contiguous()  # keep the first channel only
    if rate != target_rate:
        samples = torchaudio.functional.resample(
            samples, orig_freq=rate, new_freq=target_rate
        )

    extractor = kaldifeat.WhisperFbank(kaldifeat.WhisperFbankOptions(device="cpu"))
    fbank = extractor(samples)  # (num_frames, 80)

    # Force a fixed frame count: pad short inputs with zeros at the end,
    # cut long inputs down to the first `target_frames` frames.
    shortfall = target_frames - fbank.shape[0]
    fbank = (
        torch.nn.functional.pad(fbank, (0, 0, 0, shortfall), "constant", 0)
        if shortfall > 0
        else fbank[:target_frames]
    )

    # (3000, 80) -> (80, 3000) -> (1, 80, 3000)
    return fbank.t().unsqueeze(0)