qiuqiangkong / piano_transcription_inference

Script does not respect loading custom checkpoint for Regress_onset_offset_frame_velocity_CRNN #17

Open almostimplemented opened 1 year ago

almostimplemented commented 1 year ago

Hi @qiuqiangkong ,

Thanks for this tool.

I want to raise a bug and offer a solution.

In my code I do the following:

transcriber = hr.PianoTranscription(
    model_type='Regress_onset_offset_frame_velocity_CRNN',
    checkpoint_path='/homes/ace01/transcription_checkpoints/checkpoint_aug.pth',
    device='cuda'
)

But the code then downloads the published model anyway and overwrites my checkpoint:

Checkpoint path: /homes/ace01/transcription_checkpoints/checkpoint_aug.pth
Total size: ~165 MB
--2023-05-01 02:07:51--  https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1
Resolving zenodo.org (zenodo.org)... 188.185.124.72
Connecting to zenodo.org (zenodo.org)|188.185.124.72|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 171966578 (164M) [application/octet-stream]
Saving to: ‘/homes/ace01/transcription_checkpoints/checkpoint_aug.pth’

100%[========================================================================================================================================>] 171,966,578 28.6MB/s   in 6.0s

2023-05-01 02:07:58 (27.4 MB/s) - ‘/homes/ace01/transcription_checkpoints/checkpoint_aug.pth’ saved [171966578/171966578]

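This happens because the following logic is quite brittle. The download runs whenever the file is missing or smaller than 1.6e8 bytes, whether or not checkpoint_path was passed explicitly: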

        if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
            create_folder(os.path.dirname(checkpoint_path))
            print('Total size: ~165 MB')
            zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'
            os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))

At the very least, the code should only execute this if the argument checkpoint_path was None to begin with.
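A minimal sketch of that guard, written as a standalone helper rather than a patch to the library. The names resolve_checkpoint, default_path, and download are hypothetical and used only for illustration; the 1.6e8 size threshold is copied from the snippet above. The download step is passed in as a callable so the sketch does not depend on wget.

```python
import os

# Size threshold copied from the snippet above (~160 MB).
MIN_CHECKPOINT_BYTES = int(1.6e8)


def resolve_checkpoint(checkpoint_path, default_path, download,
                       min_bytes=MIN_CHECKPOINT_BYTES):
    """Return the checkpoint path to load.

    Hypothetical helper: `download` is any callable that writes the
    published checkpoint to the path it is given.
    """
    if checkpoint_path is not None:
        # A user-supplied checkpoint is never downloaded over.
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(checkpoint_path)
        return checkpoint_path

    # No path given: use the default location and fetch only if the
    # file is missing or looks truncated.
    if (not os.path.exists(default_path)
            or os.path.getsize(default_path) < min_bytes):
        parent = os.path.dirname(default_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        download(default_path)
    return default_path
```

With this shape, a custom checkpoint of any size is left untouched, and a wrong path raises an error instead of silently triggering a download to that location.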

It's good that I have copies of the checkpoint, but it would be a shame for a user to point this script at their checkpoint only to have your code overwrite the weights!

Shinning-Zhou commented 9 months ago

Hi, I ran into a similar problem and read your comments, but I still haven't been able to solve it. Can you help me? In my code I do the following:

from piano_transcription_inference import PianoTranscription, sample_rate, load_audio

# Load audio
(audio, _) = load_audio('一生等你钢琴.mp3', sr=sample_rate, mono=True)

# Transcriptor
transcriptor = PianoTranscription(device='cpu')    # 'cuda' | 'cpu'

# Transcribe and write out to MIDI file
transcribed_dict = transcriptor.transcribe(audio, 'cut_liszt.mid')

And I get the following error:

[screenshot of the error; image not preserved]

I wonder what's wrong with the checkpoint. TAT