falibabaei / yamnet_finetun

By incorporating these scripts into the YAMNet code (https://github.com/tensorflow/models/tree/master/research/audioset/yamnet), you can fine-tune the model.

inference #2

Open HeChengHui opened 1 month ago

HeChengHui commented 1 month ago

@falibabaei

I managed to train a model by setting 3 classes and using one-hot encoding on the targets.

The image above shows the structure of the trained model.

When I tried to do the following based on the original visualization, I got this error: `InvalidArgumentError: Input to reshape is a tensor with 78720 values, but the requested shape has 483655680 [Op:Reshape]`. So how do I load this model and use it?

falibabaei commented 1 month ago

Hi,

I think you have to do some preprocessing on your wav file before giving it to the model as input. There is an inference script in the main repository of the YAMNet model.
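For reference, YAMNet's preprocessing turns the waveform into 0.96 s log-mel patches (96 STFT frames × 64 mel bands, hopped every 0.48 s, i.e. 48 frames), and a classifier fine-tuned on top of it expects those patches rather than the raw spectrogram. Here is a minimal numpy sketch of just the framing step (shapes only; the actual mel computation lives in the repo's `features_lib`, and `frame_spectrogram` is a hypothetical helper for illustration):

```python
import numpy as np

def frame_spectrogram(spec, patch_frames=96, patch_hop=48):
    """Slice a (num_frames, 64) log-mel spectrogram into overlapping
    (patch_frames, 64) patches, mirroring the framing performed by
    YAMNet's waveform_to_log_mel_spectrogram_patches."""
    num_frames, num_bands = spec.shape
    num_patches = 1 + (num_frames - patch_frames) // patch_hop
    patches = np.stack([
        spec[i * patch_hop : i * patch_hop + patch_frames]
        for i in range(num_patches)
    ])
    return patches  # shape: (num_patches, patch_frames, num_bands)

# 1230 frames x 64 bands = 78720 values, the count in the error message.
spec = np.zeros((1230, 64), dtype=np.float32)
patches = frame_spectrogram(spec)
print(patches.shape)  # (24, 96, 64)
```

A tensor shaped like `patches` here, `(num_patches, 96, 64)`, is what the fine-tuned classifier head would consume, one prediction per patch.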

HeChengHui commented 1 month ago

@falibabaei Based on your script, it seems like you set the input to `log_mel_spectrogram` using `main_ds = main_ds.map(yamnet_frames_model_transfer1)`. I have also set it to that in my preprocessing:

import tensorflow as tf
import tensorflow_io as tfio
# These two modules come from the YAMNet repo (research/audioset/yamnet):
import features as features_lib
import params as params_lib

yamnet_params = params_lib.Params()

@tf.function
def load_wav_16k_mono(filename):
    """ Load a WAV file, convert it to a float tensor, resample to 16 kHz single-channel audio. """
    file_contents = tf.io.read_file(filename)
    wav, sample_rate = tf.audio.decode_wav(
          file_contents,
          desired_channels=1
          )

    wav = tf.squeeze(wav, axis=-1)
    sample_rate = tf.cast(sample_rate, dtype=tf.int64)
    wav = tfio.audio.resample(wav, rate_in=sample_rate, rate_out=16000)

    return wav, sample_rate
def yamnet_frames_model_transfer1(wav_data):

    waveform_padded = features_lib.pad_waveform(wav_data, yamnet_params)
    log_mel_spectrogram, features = features_lib.waveform_to_log_mel_spectrogram_patches(
        waveform_padded, yamnet_params)
    print(log_mel_spectrogram.shape)

    return log_mel_spectrogram

# Read in the audio.
wav_file_name = '/home/user/Downloads/speech_whistling2.wav'
# wav_data, sr = sf.read(wav_file_name, dtype=np.int16)
wav_data, sr = load_wav_16k_mono(wav_file_name)
# waveform = wav_data / 32768.0
# print(f"waveform: {waveform}")
print(f"wav_data: {wav_data}")
wav_data = yamnet_frames_model_transfer1(wav_data)
print(f"wav_data: {wav_data}")

However, I am still getting a wrong input shape for the reshape. I wonder if you can provide an inference script to use as a reference?
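One observation about the numbers in the error message, assuming they reflect the tensors involved: 78720 is exactly 1230 frames × 64 mel bands (the whole log-mel spectrogram), and the requested 483655680 is exactly 78720 × 96 × 64, i.e. one full 96 × 64 patch per spectrogram value. That pattern is consistent with feeding the spectrogram where the model expects the patch tensor, so it may be worth returning `features` (the patches) from `waveform_to_log_mel_spectrogram_patches` instead of `log_mel_spectrogram`. A quick arithmetic check:

```python
# Values actually supplied: a 1230-frame x 64-band log-mel spectrogram.
supplied = 1230 * 64
# Values the reshape asked for: one (96, 64) patch per supplied element,
# which is what happens if the patch framing is applied to the wrong tensor.
requested = supplied * 96 * 64
print(supplied, requested)  # 78720 483655680
```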