HeChengHui opened 1 month ago
Hi,

I think you have to do some preprocessing on your WAV file before giving it to the model as input. There is a script for inference in the main repository of the YAMNet model.
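For reference, a minimal sketch (plain NumPy, with synthetic data) of the waveform format the model's inference script expects: a mono float32 waveform scaled to [-1.0, 1.0] at 16 kHz. The helper name here is made up for illustration.

```python
import numpy as np

# YAMNet-style models take a 1-D float32 waveform in [-1.0, 1.0] at 16 kHz.
# Sketch: convert int16 PCM samples (e.g. from scipy.io.wavfile.read) to
# that format. The PCM values below are synthetic.
def pcm16_to_float_waveform(pcm: np.ndarray) -> np.ndarray:
    """Scale int16 PCM to float32 in [-1.0, 1.0]."""
    return (pcm / 32768.0).astype(np.float32)

pcm = np.array([0, 16384, -32768, 32767], dtype=np.int16)
waveform = pcm16_to_float_waveform(pcm)
print(waveform.dtype, waveform.min(), waveform.max())
```

Resampling to 16 kHz still has to happen separately (e.g. with `tfio.audio.resample`, as in the script above).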
@falibabaei
Based on your script, it seems like you have set the input to log_mel_spectrogram using `main_ds = main_ds.map(yamnet_frames_model_transfer1)`. I have also set it to that in my preprocessing:
```python
import tensorflow as tf
import tensorflow_io as tfio

@tf.function
def load_wav_16k_mono(filename):
    """Load a WAV file, convert it to a float tensor, and resample to 16 kHz single-channel audio."""
    file_contents = tf.io.read_file(filename)
    wav, sample_rate = tf.audio.decode_wav(
        file_contents,
        desired_channels=1
    )
    wav = tf.squeeze(wav, axis=-1)
    sample_rate = tf.cast(sample_rate, dtype=tf.int64)
    wav = tfio.audio.resample(wav, rate_in=sample_rate, rate_out=16000)
    # Note: the returned sample_rate is still the rate read from the file;
    # the returned wav has already been resampled to 16 kHz.
    return wav, sample_rate
```
```python
def yamnet_frames_model_transfer1(wav_data):
    # features_lib and yamnet_params come from the YAMNet repository.
    waveform_padded = features_lib.pad_waveform(wav_data, yamnet_params)
    log_mel_spectrogram, features = features_lib.waveform_to_log_mel_spectrogram_patches(
        waveform_padded, yamnet_params)
    print(log_mel_spectrogram.shape)
    return log_mel_spectrogram

# Read in the audio.
wav_file_name = '/home/user/Downloads/speech_whistling2.wav'
# wav_data, sr = sf.read(wav_file_name, dtype=np.int16)
wav_data, sr = load_wav_16k_mono(wav_file_name)
# waveform = wav_data / 32768.0
# print(f"waveform: {waveform}")
print(f"wav_data: {wav_data}")
wav_data = yamnet_frames_model_transfer1(wav_data)
print(f"wav_data: {wav_data}")
```
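To illustrate the shape the patching step produces, here is a hedged NumPy sketch of how a log-mel spectrogram gets framed into model patches. The parameter values (64 mel bands, 96 spectrogram frames per patch, 48-frame hop, i.e. 0.96 s windows with 50% overlap) are YAMNet's documented defaults, but check `yamnet_params` in the repository; the spectrogram here is synthetic.

```python
import numpy as np

# Assumed YAMNet default parameters; verify against yamnet_params.
MEL_BANDS = 64
PATCH_FRAMES = 96   # 0.96 s of 10 ms spectrogram frames
PATCH_HOP = 48      # 50% overlap between patches

def frame_patches(log_mel: np.ndarray) -> np.ndarray:
    """Slice a (num_frames, 64) spectrogram into (num_patches, 96, 64) patches."""
    num_frames = log_mel.shape[0]
    starts = range(0, num_frames - PATCH_FRAMES + 1, PATCH_HOP)
    return np.stack([log_mel[s:s + PATCH_FRAMES] for s in starts])

# Synthetic spectrogram: 1230 frames (~12.3 s of audio) x 64 mel bands.
log_mel = np.zeros((1230, MEL_BANDS), dtype=np.float32)
patches = frame_patches(log_mel)
print(patches.shape)  # (24, 96, 64)
```

The point is that the classifier consumes `(96, 64)` patches, not the full `(num_frames, 64)` spectrogram, which is why feeding the unframed spectrogram can trigger reshape errors.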
However, I am still getting the wrong input shape in the reshape. Could you provide an inference script to use or reference?
@falibabaei
I managed to train a model by setting 3 classes and using one-hot encoding on the targets; the structure of the trained model is shown above. When I tried to run inference following the original visualization example, I got the following error:

```
InvalidArgumentError: Input to reshape is a tensor with 78720 values, but the requested shape has 483655680 [Op:Reshape]
```

So how do I load this model and use it?
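As a hedged reading of the error, based only on the numbers in the message and not on the actual model layers: 78720 values is consistent with a full log-mel spectrogram of 1230 frames x 64 mel bands, while 483655680 = 78720 x 96 x 64 is consistent with a reshape that expects one (96, 64) patch per element, i.e. the model appears to want framed patches rather than the raw spectrogram.

```python
# Arithmetic check on the numbers in the InvalidArgumentError. The
# interpretation (1230 frames x 64 mel bands fed in; (96, 64) patches
# expected) is an assumption inferred from YAMNet's default parameters.
frames, mel_bands = 1230, 64
patch_frames, patch_bands = 96, 64

actual_values = frames * mel_bands                              # tensor that was fed
requested_values = actual_values * patch_frames * patch_bands   # shape requested

print(actual_values)     # 78720
print(requested_values)  # 483655680
```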