jongwook / onsets-and-frames

A PyTorch implementation of Onsets and Frames (Hawthorne et al., 2018)
MIT License

Upload pretrained model to run inference #17

Open greenbech opened 4 years ago

greenbech commented 4 years ago

It would be great if anyone could upload a pretrained model so that we could try this model/project without needing to train the model. It is quite a big commitment to wait a week for training (as mentioned in #10 ) if you primarily just want to check out the performance on some .wav files.

And I would also like to say this repo is very well written and educational. Thanks!

jongwook commented 4 years ago

Hi, please try this one, trained for 500,000 iterations on the MAESTRO dataset.

I haven't touched the model in a while, but torch.load('model-500000.pt') should be able to load the PyTorch model.
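Roughly, something like this minimal sketch should work for loading it for inference (run from the repo root so that the `onsets_and_frames` package, which defines the pickled classes, is importable):

```python
# Minimal loading sketch for the checkpoint mentioned above; the file is a fully
# pickled module, so the repo's package must be importable when unpickling.
import torch

model = torch.load('model-500000.pt', map_location='cpu')  # or map_location='cuda'
model.eval()  # disable dropout etc. for inference
```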

greenbech commented 4 years ago

The provided file works great, thanks a lot! I didn't need to use torch.load('model-500000.pt') since both evaluate.py and transcribe.py take the model file as an argument.
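For anyone else landing here, the invocation is roughly the following (the audio path is just a placeholder and the exact arguments may differ; `python transcribe.py --help` lists the options):

```bash
# sketch of running inference with the downloaded checkpoint
python transcribe.py model-500000.pt path/to/audio.flac
```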

However, I first got this error message when trying to run the scripts:

```bash
Traceback (most recent call last):
  File "transcribe.py", line 101, in <module>
    transcribe_file(**vars(parser.parse_args()))
  File "transcribe.py", line 74, in transcribe_file
    predictions = transcribe(model, audio)
  File "transcribe.py", line 53, in transcribe
    onset_pred, offset_pred, _, frame_pred, velocity_pred = model(mel)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/greenbech/git/onsets-and-frames/onsets_and_frames/transcriber.py", line 87, in forward
    onset_pred = self.onset_stack(mel)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/greenbech/git/onsets-and-frames/onsets_and_frames/lstm.py", line 29, in forward
    output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 558, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'LSTM' object has no attribute '_flat_weights'
```

Downgrading from torch==1.4.0 to torch==1.2.0 fixed it for me.
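For anyone else hitting this, the quick fix is just pinning the older version:

```bash
pip install torch==1.2.0
```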

It is also quite cumbersome to resample the audio file to 16 kHz beforehand, so I added this locally to transcribe.py:

```python
import librosa  # extra import; assumes np and SAMPLE_RATE are already in scope in transcribe.py


def float_samples_to_int16(y):
    """Convert a floating-point numpy array of audio samples to int16."""
    # From https://github.com/tensorflow/magenta/blob/671501934ff6783a7912cc3e0e628fd0ea2dc609/magenta/music/audio_io.py#L48
    if not issubclass(y.dtype.type, np.floating):
        raise ValueError('input samples not floating-point')
    return (y * np.iinfo(np.int16).max).astype(np.int16)


def load_and_process_audio(flac_path, sequence_length, device):
    random = np.random.RandomState(seed=42)

    # librosa resamples to SAMPLE_RATE on load; the samples are then converted
    # to int16, which is what the rest of the function expects
    audio, sr = librosa.load(flac_path, sr=SAMPLE_RATE)
    audio = float_samples_to_int16(audio)

    assert sr == SAMPLE_RATE
    assert audio.dtype == 'int16'
    ...
```

There might be more elegant ways of doing this, but I was not able to convert to int16 with librosa alone or to resample with soundfile.read.
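If someone prefers to keep soundfile for reading, an untested alternative sketch would be to resample separately with librosa and convert to int16 the same way (the file path here is just a placeholder):

```python
import numpy as np
import librosa
import soundfile as sf

# Untested sketch: soundfile.read does the decoding, librosa.resample brings the
# audio to SAMPLE_RATE (16 kHz in this repo), and the result is cast to int16.
audio, sr = sf.read('example.flac', dtype='float32')   # placeholder path
if audio.ndim > 1:
    audio = audio.mean(axis=1)                         # down-mix to mono
if sr != SAMPLE_RATE:
    audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
audio = (audio * np.iinfo(np.int16).max).astype(np.int16)
```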

I also think the model you provided should be mentioned in the README so others can try it out without digging through this issue. I was thinking either checking it in directly under ./data/pretrained, which is the easiest setup but increases the repo size unnecessarily, or linking the Drive URL you provided.

Would you be open to a PR for this?

jongwook commented 4 years ago

Yeah! It'll need some housekeeping to make the checkpoint work cross-version. A PR is welcome! Thanks :D
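The usual trick for making a pickled checkpoint portable is to re-save just the weights from an environment where the pickle still loads, along these lines (a rough sketch; the constructor arguments below are an assumption and should be checked against train.py):

```python
# Rough sketch, not something the repo does yet: convert the pickled module into
# a plain state_dict so it can be restored under newer PyTorch versions.
import torch

# 1) In an environment where the pickle loads (e.g. torch==1.2.0, as noted above):
model = torch.load('model-500000.pt', map_location='cpu')
torch.save(model.state_dict(), 'model-500000-state-dict.pt')  # hypothetical filename

# 2) In any environment: rebuild the model and load the weights.
#    Constructor arguments are an assumption here; check how train.py builds the model.
# from onsets_and_frames import OnsetsAndFrames
# model = OnsetsAndFrames(229, 88)  # (input_features, output_features) -- assumed values
# model.load_state_dict(torch.load('model-500000-state-dict.pt'))
# model.eval()
```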