YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

AssertionError: choose a window size 400 that is [2, 1] #133

Open GrafKnusprig opened 1 week ago

GrafKnusprig commented 1 week ago

I'm trying to use the feature extractor on my audio files. They are all 16000 Hz and 5 seconds long, so waveform.shape[1] is 80000.

```python
input_values = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt").input_values
```

I get the error `AssertionError: choose a window size 400 that is [2, 1]` and I don't really know what to do with it.

Here is the whole thing:


```python
def preprocess_function(examples):
    audio_files = examples['file_path']
    inputs = {'input_values': []}
    for audio_file in tqdm(audio_files, desc="Preprocessing dataset"):
        waveform, sample_rate = torchaudio.load(audio_file)
        # Ensure sample rate is 16000 Hz
        assert sample_rate == 16000, f"Expected sample rate of 16000 Hz, but got {sample_rate} Hz"
        # Assuming all audio files are 5 seconds long
        max_len = 16000 * 5  # 5 seconds at 16000 Hz
        # Pad or truncate to the maximum length
        print(waveform.shape[1])
        if waveform.shape[1] > max_len:
            waveform = waveform[:, :max_len]
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_len - waveform.shape[1]), "constant", 0)
        input_values = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt").input_values
        inputs['input_values'].append(input_values.squeeze(0))
    return inputs

processed_dataset = dataset_dict.map(preprocess_function, batched=True, remove_columns=['file_path'])
```
YuanGongND commented 2 days ago

16 kHz, 5 seconds should work; our ESC-50 recipe uses this setting. Which line reports this error? Is your audio mono or multi-channel? Check the shape of waveform.

-Yuan

GrafKnusprig commented 1 day ago

First of all, thanks for the answer!

My waveform looks like this:

```
Waveform shape: torch.Size([1, 80000])
Waveform dtype: torch.float32
Number of channels: 1
```

80000 samples, because of the 16000 Hz sample rate and the 5 seconds.

And the error happens in:

```
Cell In[14], line 86, in preprocess_function(examples)
     79     # print(f"Waveform max: {waveform.max()}")
     80     # print(f"Waveform min: {waveform.min()}")
     81     # print(f"Waveform mean: {waveform.mean()}")
     82     # print(f"Waveform std: {waveform.std()}")
     83     # printing the number of channels in the waveform
     84     print(f"Number of channels: {waveform.shape[0]}")
---> 86     input_values = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt").input_values
     87     inputs['input_values'].append(input_values.squeeze(0))
     88 return inputs

File D:\GitLab\ss24-aai-lab\.venv\Lib\site-packages\transformers\models\audio_spectrogram_transformer\feature_extraction_audio_spectrogram_transformer.py:219, in ASTFeatureExtractor.__call__(self, raw_speech, sampling_rate, return_tensors, **kwargs)
    216     raw_speech = [raw_speech]
    218 # extract fbank features and pad/truncate to max_length
--> 219 features = [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in raw_speech]
    221 # convert into BatchFeature
    222 padded_inputs = BatchFeature({"input_values": features})

File D:\GitLab\ss24-aai-lab\.venv\Lib\site-packages\transformers\models\audio_spectrogram_transformer\feature_extraction_audio_spectrogram_transformer.py:219, in <listcomp>(.0)
    216     raw_speech = [raw_speech]
    218 # extract fbank features and pad/truncate to max_length
--> 219 features = [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in raw_speech]
    221 # convert into BatchFeature
    222 padded_inputs = BatchFeature({"input_values": features})

File D:\GitLab\ss24-aai-lab\.venv\Lib\site-packages\transformers\models\audio_spectrogram_transformer\feature_extraction_audio_spectrogram_transformer.py:119, in ASTFeatureExtractor._extract_fbank_features(self, waveform, max_length)
    117 if is_speech_available():
    118     waveform = torch.from_numpy(waveform).unsqueeze(0)
--> 119     fbank = ta_kaldi.fbank(
    120         waveform,
    121         sample_frequency=self.sampling_rate,
    122         window_type="hanning",
    123         num_mel_bins=self.num_mel_bins,
    124     )
    125 else:
    126     waveform = np.squeeze(waveform)

File D:\GitLab\ss24-aai-lab\.venv\Lib\site-packages\torchaudio\compliance\kaldi.py:591, in fbank(waveform, blackman_coeff, channel, dither, energy_floor, frame_length, frame_shift, high_freq, htk_compat, low_freq, min_duration, num_mel_bins, preemphasis_coefficient, raw_energy, remove_dc_offset, round_to_power_of_two, sample_frequency, snip_edges, subtract_mean, use_energy, use_log_fbank, use_power, vtln_high, vtln_low, vtln_warp, window_type)
    542 r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
    543 compute-fbank-feats.
    544 
   (...)
    587     where m is calculated in _get_strided
    588 """
    589 device, dtype = waveform.device, waveform.dtype
--> 591 waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
    592     waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    593 )
    595 if len(waveform) < min_duration * sample_frequency:
    596     # signal is too short
    597     return torch.empty(0, device=device, dtype=dtype)

File D:\GitLab\ss24-aai-lab\.venv\Lib\site-packages\torchaudio\compliance\kaldi.py:142, in _get_waveform_and_window_properties(waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient)
    139 window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    140 padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
--> 142 assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
    143     window_size, len(waveform)
    144 )
    145 assert 0 < window_shift, "`window_shift` must be greater than 0"
    146 assert padded_window_size % 2 == 0, (
    147     "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
    148 )

AssertionError: choose a window size 400 that is [2, 1]
```

I know that's a lot to ask, but do you have any ideas about what could be wrong? I'm lost.

Thanks a lot.

UPDATE: If I run it with stereo files, I get the error `AssertionError: choose a window size 400 that is [2, 2]`.

Am I using the wrong feature extractor?

```python
# Load the model and feature extractor
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
model = ASTForAudioClassification.from_pretrained(model_name)
feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)
```
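
For reference, the assertion raised in torchaudio's `kaldi.fbank` (visible in the traceback above) requires `2 <= window_size <= len(waveform)`. A 25 ms frame at 16 kHz gives `window_size = 400`, and `len(waveform)` resolving to 1 (mono) or 2 (stereo) suggests the extractor ends up measuring the channel dimension rather than the number of samples. Below is a minimal sketch, not a confirmed fix from the maintainers, assuming the Hugging Face `ASTFeatureExtractor` is given a 1-D mono array; the file path is hypothetical.

```python
import torchaudio
from transformers import ASTFeatureExtractor

feature_extractor = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

# Hypothetical example file; any 16 kHz clip would do.
waveform, sample_rate = torchaudio.load("example.wav")  # shape: [channels, samples]
assert sample_rate == 16000

# Downmix to mono and drop the channel dimension so the extractor
# receives a 1-D array of samples rather than a [channels, samples] tensor.
mono = waveform.mean(dim=0)  # shape: [samples]

inputs = feature_extractor(mono.numpy(), sampling_rate=16000, return_tensors="pt")
print(inputs.input_values.shape)  # typically [1, 1024, 128] with default AST settings
```

If this avoids the assertion, the original `preprocess_function` could pass `waveform.squeeze(0).numpy()` (or a channel average for stereo files) to the extractor instead of the 2-D tensor.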