spotify / basic-pitch

A lightweight yet powerful audio-to-MIDI converter with pitch bend detection
https://basicpitch.io
Apache License 2.0
3.44k stars, 272 forks

How to get trained model in keras format instead of saved model? #64

Open DamRsn opened 1 year ago

DamRsn commented 1 year ago

Thanks for this amazing work!

I'm currently working on integrating basic-pitch in an audio plugin. For that, I would need to have the trained basic pitch model as a Keras model in order to convert it into the format that I'm going to use in C++ for inference.

Unfortunately, the ICASSP2022 model is saved as a TensorFlow SavedModel, even though it is originally a Keras model that can be built with the model() function from model.py.

I tried to transfer the weights from the trained model to an untrained keras model but I did not find a straightforward way to do it.

I could gather all the layers' weights manually from the saved_model and apply them one by one to an untrained keras model created with model(), but before getting there: do you have a trained keras model that you could share? Or do you know a way to get it from the trained tensorflow saved_model?

Also, similar to #62 , is there a more recent model I could use, perhaps the one used in the basic-pitch-ts repo for https://basicpitch.spotify.com/ ?

Thanks for your help!

achimmihca commented 1 year ago

I successfully converted the saved_model format of basic pitch to onnx format using tf2onnx.

You could try to further convert this onnx model to keras using onnx2keras.
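For reference, the SavedModel-to-ONNX step can be sketched as below. This is a minimal sketch, not the exact invocation used above: it assumes tf2onnx is installed in the current environment and wraps its command-line entry point (`python -m tf2onnx.convert` with `--saved-model`, `--output`, and `--opset`). The helper names, the paths, and the opset value are placeholders chosen for illustration.

```python
import subprocess
import sys


def build_tf2onnx_command(saved_model_dir, output_path, opset=13):
    """Build the tf2onnx CLI call that converts a SavedModel directory to ONNX.

    The opset default is an assumption; pick the one your runtime supports.
    """
    return [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", str(saved_model_dir),
        "--output", str(output_path),
        "--opset", str(opset),
    ]


def convert_saved_model_to_onnx(saved_model_dir, output_path, opset=13):
    """Run the conversion; raises CalledProcessError if tf2onnx fails."""
    cmd = build_tf2onnx_command(saved_model_dir, output_path, opset)
    subprocess.run(cmd, check=True)
    return output_path
```

Calling `convert_saved_model_to_onnx("path/to/saved_model_dir", "basic_pitch.onnx")` would write the ONNX file, with both paths being hypothetical.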

DamRsn commented 1 year ago

On my side, I managed to apply the weights of the tfjs model (the model used on the basic-pitch website) to the Keras model (created with model()). I did it by manually exporting the weights of each layer to .npy using the Netron app and then applying them to each layer with set_weights().
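The per-layer approach described above could look roughly like the following. This is a sketch under assumptions, not the exact script that was used: it assumes each exported array was saved as `<layer name>_<index>.npy`, with the index following the order of `layer.get_weights()` (typically kernel, then bias). That naming scheme and the function name `apply_npy_weights` are hypothetical. The function only relies on the standard Keras layer methods `get_weights()` and `set_weights()`, so the model is passed in rather than built here.

```python
from pathlib import Path

import numpy as np


def apply_npy_weights(model, weights_dir):
    """Copy weights stored as .npy files into the layers of a Keras model.

    Assumed (hypothetical) file naming: <layer.name>_<index>.npy, where index
    follows the order returned by layer.get_weights().
    Returns the names of the layers that were updated.
    """
    weights_dir = Path(weights_dir)
    updated = []
    for layer in model.layers:
        current = layer.get_weights()
        if not current:
            continue  # layer has no weights (activation, reshape, ...)
        loaded = []
        for i, ref in enumerate(current):
            path = weights_dir / f"{layer.name}_{i}.npy"
            if not path.exists():
                raise FileNotFoundError(f"missing weight file: {path}")
            arr = np.load(path)
            # Fail early on layout mismatches instead of silently loading
            # weights into the wrong tensor.
            if arr.shape != ref.shape:
                raise ValueError(
                    f"{path.name}: shape {arr.shape} does not match "
                    f"layer weight shape {ref.shape}"
                )
            loaded.append(arr.astype(ref.dtype))
        layer.set_weights(loaded)
        updated.append(layer.name)
    return updated
```

The shape check matters here because weights exported from another framework may use a different axis order; a mismatch then raises an error and has to be resolved by transposing the array before saving it.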

Converting to ONNX definitely helps, but I don't think you'll get the same structure for the Keras model as the one you get when you create it with model(); it will be much messier. You also need a way to separate the CQT computation part of the network from the actual CNN.

All those issues would be solved if the model were saved directly with the Keras API after training, but I don't know whether that is easy to do once the model has been trained and saved as a SavedModel.

dpl123 commented 1 year ago

Agree, amazing for sure! @DamRsn very cool plugin too!!

Another vote for getting some kind of more usable model format. Any assistance with this would be appreciated.