kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0

RuntimeError: stft requires the return_complex parameter be given for real inputs #38

Open loukasilias opened 7 months ago

loukasilias commented 7 months ago

Hello! I am using the following code:

from hear21passt.base import get_basic_model, get_model_passt
import torch
# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel) # Extracts mel spectrogram from raw waveforms.
print(model.net) # the transformer network.

# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
    # audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k
    # example audio_wave of batch=3 and 10 seconds
    audio = torch.ones((3, 32000 * 10))*0.5
    audio_wave = audio.cuda()
    logits = model(audio_wave)

I am getting the following error:

RuntimeError: stft requires the return_complex parameter be given for real inputs, and will further require that return_complex=True in a future PyTorch release.

How can I solve this issue, please? Thank you!
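For context, the error is raised by torch.stft: recent PyTorch releases require `return_complex` to be passed explicitly for real inputs and plan to require `return_complex=True`. The sketch below is a generic illustration, not the hear21passt code or its actual fix. It requests a complex STFT and takes the squared magnitude, which equals the older `return_complex=False` pattern of summing the squared real and imaginary channels. The helper name `power_spectrogram` and the values `n_fft=1024` and `hop_length=320` are illustrative assumptions.

```python
import torch


def power_spectrogram(wave: torch.Tensor, n_fft: int = 1024, hop_length: int = 320) -> torch.Tensor:
    """Power spectrogram of real waveforms shaped [batch, samples].

    Hypothetical helper; n_fft and hop_length are illustrative, not hear21passt's settings.
    """
    window = torch.hann_window(n_fft, device=wave.device, dtype=wave.dtype)
    # Passing return_complex=True explicitly avoids the RuntimeError for real inputs.
    spec = torch.stft(
        wave,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=True,
        return_complex=True,
    )  # complex tensor of shape [batch, n_fft // 2 + 1, frames]
    # |X|^2 equals real**2 + imag**2 from the old trailing size-2 real/imag dimension.
    return spec.abs() ** 2


if __name__ == "__main__":
    audio = torch.ones((3, 32000)) * 0.5  # 3 clips of 1 s at 32 kHz
    print(power_spectrogram(audio).shape)  # torch.Size([3, 513, 101])
```

If older code relied on the real-valued layout, torch.view_as_real(spec) recovers the trailing [..., 2] dimension from the complex result.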

kkoutini commented 7 months ago

Hi! Are you using the latest release? This issue should have been fixed in https://github.com/kkoutini/passt_hear21/commit/dce83183674e559162b49924d666c0a916dc967a

Try uninstalling your current passt_hear21 package and reinstalling it from the 0.0.25 tag:

pip install -e 'git+https://github.com/kkoutini/passt_hear21@0.0.25#egg=hear21passt' 
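Before reinstalling, it can help to check which version is actually installed. A minimal sketch using only the standard library, assuming the distribution is named hear21passt (the egg= name in the pip command above); `installed_version` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "hear21passt") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(installed_version() or "hear21passt is not installed")
```

If this reports a version older than 0.0.25, uninstalling it first (assuming the same distribution name, e.g. pip uninstall hear21passt) avoids mixing the old and new installs.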

loukasilias commented 7 months ago

Thank you! It has been solved. I have some additional questions.

I have several audio files of variable duration (40 s, 1 min, etc.). Is it possible to use your library with them? I am currently loading each file as follows:

import librosa

path = '/kaggle/input/audio-files/audio_files/S004.wav'
x, sr = librosa.load(path, sr=32000)

What can I do next? Thank you again for your help!

kkoutini commented 7 months ago

Hi, the model is compatible with the HEAR API. Here is an example using the base model:

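import librosa
import torch
from hear21passt.base import load_model, get_scene_embeddings, get_timestamp_embeddings

model = load_model().cuda()
path = '/kaggle/input/audio-files/audio_files/S004.wav'
audio, sr = librosa.load(path, sr=32000)  # mono float32 numpy array, resampled to 32 kHz
# The HEAR API expects a torch tensor of shape [n_sounds, n_samples] on the model's device
audio = torch.from_numpy(audio).unsqueeze(0).cuda()

embed, time_stamps = get_timestamp_embeddings(audio, model)
print(embed.shape)  # [n_sounds, n_timestamps, embedding_size]
embed = get_scene_embeddings(audio, model)
print(embed.shape)  # [n_sounds, embedding_size]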
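Following the HEAR API conventions, get_timestamp_embeddings returns embeddings of shape [n_sounds, n_timestamps, embedding_size] and timestamps in milliseconds of shape [n_sounds, n_timestamps]. For long recordings you may want one vector per region (for example seconds 5 to 15) rather than for the whole file. A minimal NumPy sketch that mean-pools the timestamp embeddings falling inside a time range; `pool_segment` is a hypothetical helper, and mean pooling is an illustrative choice, not necessarily how get_scene_embeddings summarizes a clip.

```python
import numpy as np


def pool_segment(embeddings: np.ndarray, timestamps_ms: np.ndarray,
                 start_s: float, end_s: float) -> np.ndarray:
    """Mean of the timestamp embeddings whose timestamp lies in [start_s, end_s).

    embeddings: [n_timestamps, dim] for one sound.
    timestamps_ms: [n_timestamps], in milliseconds (HEAR convention).
    """
    t = np.asarray(timestamps_ms, dtype=np.float64) / 1000.0  # ms -> s
    mask = (t >= start_s) & (t < end_s)
    if not mask.any():
        raise ValueError("no timestamps fall in the requested range")
    return embeddings[mask].mean(axis=0)
```

Called right after the get_timestamp_embeddings line, something like pool_segment(embed[0].cpu().numpy(), time_stamps[0].cpu().numpy(), 5.0, 15.0) would pool the first file's embeddings between 5 s and 15 s.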

If you need more control, take a look at hear21passt.base, where these methods are implemented.
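Another option for files longer than the 10-second clips in the first example is to cut each recording into fixed-length windows and run them as one batch through the get_basic_model(mode="logits") wrapper. The sketch below only prepares that batch with NumPy: it splits a mono 32 kHz waveform, such as the one returned by librosa.load(path, sr=32000), into non-overlapping 10 s windows and zero-pads the last one. `split_into_windows` is hypothetical, and the 10 s window length is an assumption taken from that example rather than a documented limit.

```python
import numpy as np

SAMPLE_RATE = 32000  # PaSST examples in this thread use 32 kHz audio


def split_into_windows(wave: np.ndarray, window_seconds: float = 10.0,
                       sr: int = SAMPLE_RATE) -> np.ndarray:
    """Split a 1-D waveform into non-overlapping windows, zero-padding the last one.

    Returns an array of shape [n_windows, window_seconds * sr].
    """
    if wave.ndim != 1:
        raise ValueError("expected a 1-D mono waveform")
    win = int(round(window_seconds * sr))
    n_windows = max(1, -(-len(wave) // win))  # ceiling division, at least one window
    padded = np.zeros(n_windows * win, dtype=wave.dtype)
    padded[: len(wave)] = wave
    return padded.reshape(n_windows, win)
```

The result could then be passed as model(torch.from_numpy(windows).cuda()) and the per-window logits averaged or max-pooled. Whether padding the final window with silence is acceptable depends on your data.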