Klangio / dtmst

Dual Task Monophonic Singing Transcription (DTMST) https://www.aes.org/e-lib/browse.cfm?elib=22025
Apache License 2.0
4 stars 0 forks source link

Missing transcribe.py and train.py? #1

Open RiuHDuo opened 1 year ago

RiuHDuo commented 1 year ago

Missing transcribe.py and train.py? I cannot find transcribe.py or train.py.

SutirthaChakraborty commented 1 year ago
"""Test Script for DTMST Model."""

# system imports
import os
import sys
import argparse
import pretty_midi
# additional imports
import torch
import librosa
# from dtmst.model import DTMSTModel
import note_seq 

# pylint: disable=C0413
from dtmst import DTMSTModel, data, infer_utils

import pretty_midi

def create_note(start, end, pitch, velocity=100, instrument_name='Acoustic Grand Piano'):
    """Build and return a single ``pretty_midi.Note``.

    Args:
        start: Start time of the note, forwarded unchanged to ``pretty_midi.Note``.
        end: End time of the note, forwarded unchanged to ``pretty_midi.Note``.
        pitch: MIDI pitch of the note.
        velocity: MIDI velocity of the note (100 unless overridden).
        instrument_name: Accepted for caller compatibility but currently
            unused -- the returned note carries no instrument information.

    Returns:
        The newly constructed ``pretty_midi.Note``.
    """
    return pretty_midi.Note(start=start, end=end, pitch=pitch, velocity=velocity)

def save_midi(sequence, file_path):
    """
    Save MIDI content to a file and report where it was written.

    Parameters:
    sequence (note_seq.NoteSequence or pretty_midi.PrettyMIDI): The MIDI
        content to save. A ``pretty_midi.PrettyMIDI`` object is written with
        its own ``write`` method; anything else is treated as a NoteSequence
        proto and written with ``note_seq.sequence_proto_to_midi_file``.
    file_path (str): Path to the output MIDI file
    """
    # The object produced by infer_utils.transcribe is defined outside this
    # script, so accept both common MIDI containers rather than failing on a
    # PrettyMIDI object that note_seq cannot serialise.
    if isinstance(sequence, pretty_midi.PrettyMIDI):
        sequence.write(file_path)
    else:
        note_seq.sequence_proto_to_midi_file(sequence, file_path)
    print(f"MIDI file saved to {file_path}")

fn_checkpoint = 'dtmst/checkpoints/checkpoint_sinsy_logmel.pth'
# The default path above is relative to the current working directory. If it
# does not exist there, fall back to the same relative path next to this
# script, so the script also works when launched from another directory.
# When the CWD-relative file exists, behaviour is unchanged.
if not os.path.exists(fn_checkpoint):
    _script_relative_checkpoint = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), fn_checkpoint)
    if os.path.exists(_script_relative_checkpoint):
        fn_checkpoint = _script_relative_checkpoint

# Load the model once at import time; checkpoint tensors are mapped to CPU.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(fn_checkpoint, map_location=torch.device('cpu'))
# The checkpoint is expected to hold the model hyper-parameters under
# 'hparams' and the weights under 'model_state'.
hparams = checkpoint['hparams']
model = DTMSTModel(hparams)
model.load_state_dict(checkpoint['model_state'])
# Put the module in evaluation mode before running inference.
model.eval()

def transcribe_dtmst(fn_audio: str):
    """Transcribe an audio file with the module-level DTMST model.

    Pipeline: load the audio at ``hparams['sample_rate']`` with librosa,
    normalise it with ``librosa.util.normalize``, convert it with
    ``data.wav_to_mel`` and run ``infer_utils.transcribe`` on the result.

    Args:
        fn_audio: Path to the input audio file.

    Returns:
        Whatever ``infer_utils.transcribe`` returns; its type is defined in
        the dtmst package (``main`` passes it straight to ``save_midi``).
    """
    x, _ = librosa.load(fn_audio, sr=hparams['sample_rate'])
    x = librosa.util.normalize(x)

    mel = data.wav_to_mel(x, hparams)

    # NOTE(review): unsqueeze(2) inserts a singleton axis at dim 2; the exact
    # layout infer_utils.transcribe expects is defined in dtmst -- confirm there.
    example = {'spectrogram': torch.from_numpy(mel).float().unsqueeze(2), 'audio': x}
    # Inference only: disable autograd so no computation graph is built. This
    # saves memory and keeps output tensors free of gradient history.
    with torch.no_grad():
        midi_data = infer_utils.transcribe(example, model, hparams)
    return midi_data

def main(audio_file: str, output_path: str = 'output.mid'):
    """Transcribe ``audio_file`` with DTMST and write the result as MIDI.

    Args:
        audio_file: Path to the input audio file.
        output_path: Destination of the MIDI file. Defaults to
            ``'output.mid'`` (relative to the current working directory),
            matching the previous hard-coded behaviour.
    """
    midi_data = transcribe_dtmst(audio_file)
    save_midi(midi_data, output_path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Test DTMST Model")
    parser.add_argument('audio_file', type=str, help='Path to the audio file')
    parser.add_argument('-o', '--output', type=str, default='output.mid',
                        help='Path of the MIDI file to write (default: output.mid)')
    args = parser.parse_args()

    main(args.audio_file, args.output)