Open issue — opened by RiuHDuo 1 year ago
"""Test Script for DTMST Model."""
# system imports
import os
import sys
import argparse
import pretty_midi
# additional imports
import torch
import librosa
# from dtmst.model import DTMSTModel
import note_seq
# pylint: disable=C0413
from dtmst import DTMSTModel, data, infer_utils
import pretty_midi
def create_note(start, end, pitch, velocity=100, instrument_name='Acoustic Grand Piano'):
    """Build a single ``pretty_midi.Note``.

    Args:
        start: Note-on time handed to ``pretty_midi.Note`` (pretty_midi expects seconds).
        end: Note-off time handed to ``pretty_midi.Note``.
        pitch: MIDI pitch number.
        velocity: MIDI velocity; defaults to 100.
        instrument_name: Currently ignored. The returned note is not attached
            to any instrument; the parameter is kept for call compatibility.

    Returns:
        The newly created ``pretty_midi.Note``.
    """
    return pretty_midi.Note(velocity=velocity, pitch=pitch, start=start, end=end)
def save_midi(sequence, file_path):
    """Save transcribed MIDI data to a MIDI file.

    Parameters:
        sequence (note_seq.NoteSequence | pretty_midi.PrettyMIDI): The MIDI
            data to save. A ``pretty_midi.PrettyMIDI`` object is written with
            its own ``write`` method; anything else is treated as a
            NoteSequence proto and written with
            ``note_seq.sequence_proto_to_midi_file``.
        file_path (str): Path to the output MIDI file.
    """
    if isinstance(sequence, pretty_midi.PrettyMIDI):
        # NOTE(review): the type returned by infer_utils.transcribe is not
        # visible in this script; accepting PrettyMIDI as well keeps this
        # helper working whichever representation it produces. Passing a
        # PrettyMIDI to sequence_proto_to_midi_file would fail because it is
        # not a NoteSequence proto.
        sequence.write(file_path)
    else:
        note_seq.sequence_proto_to_midi_file(sequence, file_path)
    print(f"MIDI file saved to {file_path}")
# Pretrained DTMST checkpoint; this relative path is resolved against the
# current working directory, not against this script's location.
fn_checkpoint = 'dtmst/checkpoints/checkpoint_sinsy_logmel.pth'

# Map every stored tensor onto the CPU so loading works without a GPU.
# NOTE(review): torch.load deserializes with pickle (unrestricted when
# weights_only=False, the default before torch 2.6); only load trusted files.
checkpoint = torch.load(fn_checkpoint, map_location='cpu')
hparams = checkpoint['hparams']

# Rebuild the network from its saved hyperparameters, restore the trained
# weights, and switch to evaluation mode for inference.
model = DTMSTModel(hparams)
model.load_state_dict(checkpoint['model_state'])
model.eval()
def transcribe_dtmst(fn_audio: str):
    """Transcribe an audio file to MIDI using the module-level DTMST model.

    The audio is loaded at ``hparams['sample_rate']``, normalized with
    ``librosa.util.normalize``, and converted to a mel spectrogram via
    ``data.wav_to_mel``. That spectrogram and the waveform are then passed
    to ``infer_utils.transcribe``.

    Args:
        fn_audio: Path to the input audio file.

    Returns:
        Whatever ``infer_utils.transcribe`` returns for this example, passed
        through unchanged.
    """
    waveform, _ = librosa.load(fn_audio, sr=hparams['sample_rate'])
    waveform = librosa.util.normalize(waveform)
    mel = data.wav_to_mel(waveform, hparams)
    # unsqueeze(2) inserts a new axis at position 2; presumably the layout the
    # model expects - TODO confirm against data.wav_to_mel's output shape.
    example = {
        'spectrogram': torch.from_numpy(mel).float().unsqueeze(2),
        'audio': waveform,
    }
    # Pure inference: disable autograd so no computation graph is recorded,
    # which saves memory and time.
    with torch.no_grad():
        midi_data = infer_utils.transcribe(example, model, hparams)
    return midi_data
def main(audio_file: str, output_file: str = 'output.mid'):
    """Transcribe ``audio_file`` with DTMST and write the result as MIDI.

    Args:
        audio_file: Path to the input audio file.
        output_file: Destination MIDI path. Defaults to ``'output.mid'`` in
            the current working directory, the previously hard-coded value.
    """
    midi_data = transcribe_dtmst(audio_file)
    save_midi(midi_data, output_file)
if __name__ == '__main__':
    # Command-line entry point: one positional argument naming the audio file.
    cli = argparse.ArgumentParser(description="Test DTMST Model")
    cli.add_argument('audio_file', type=str, help='Path to the audio file')
    cli_args = cli.parse_args()
    main(cli_args.audio_file)
Missing transcribe.py and train.py? I cannot find transcribe.py or train.py.