m2lines / gz21_ocean_momentum

Stochastic-Deep Learning Parameterization of Ocean Momentum Forcing
MIT License

Model loading #118

Open. Etienne-Meunier opened this issue 4 months ago.

Etienne-Meunier commented 4 months ago

Hi,

Thank you for your great work! I am trying to load your model to run inference in Python on new fields, using the weights trained for the paper, but I am running into a few errors.

When I try to use src/gz21_ocean_momentum/__init__.py, it cannot find the file final_transformation.pth.

Similarly, if I run:

import mlflow

MODEL_RUN_ID = "dc74cea68a7f4c7e98f9228649a97135"
client = mlflow.tracking.MlflowClient()
# Download the trained weights logged as an artifact of this MLflow run
model_file = client.download_artifacts(MODEL_RUN_ID, "models/trained_model.pth")

I get:

MlflowException: Run 'dc74cea68a7f4c7e98f9228649a97135' not found

Thanks !

raehik commented 3 months ago

That snippet is trying to load model weights from an MLflow run (dc74...) that doesn't exist in the MLflow tracking store on your machine.
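To make the failure concrete: an MLflow run ID is only a key into whichever tracking store recorded the run. The sketch below assumes MLflow's local file-based store (for example the ./mlruns directory MLflow has traditionally used when no tracking URI is set), where a run's artifacts live under mlruns/<experiment_id>/<run_id>/artifacts/. find_run_artifact is a hypothetical helper, not part of MLflow or this repository. If the run's directory is absent, there is nothing to resolve the ID against, which is what the "Run ... not found" error reports. A server- or database-backed tracking URI stores runs differently, so this check would not apply there.

```python
from pathlib import Path
from typing import Optional


def find_run_artifact(mlruns_dir: str, run_id: str, artifact_path: str) -> Optional[Path]:
    """Return the local path of an artifact logged under run_id, or None.

    Hypothetical helper. Assumes the MLflow file-store layout
    <mlruns_dir>/<experiment_id>/<run_id>/artifacts/<artifact_path>.
    """
    root = Path(mlruns_dir)
    if not root.is_dir():
        # No local tracking store at all, so no run ID can be found.
        return None
    # The experiment ID is not known in advance, so check every experiment directory.
    for experiment_dir in root.iterdir():
        candidate = experiment_dir / run_id / "artifacts" / artifact_path
        if candidate.is_file():
            return candidate
    return None


if __name__ == "__main__":
    path = find_run_artifact(
        "mlruns", "dc74cea68a7f4c7e98f9228649a97135", "models/trained_model.pth"
    )
    if path is None:
        print("Run not present in this local tracking store")
    else:
        print(f"Found weights at {path}")
```

On a machine that never ran the training, the lookup returns None, mirroring the MlflowException above; the run only resolves where its mlruns directory (or the tracking server that recorded it) is available.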

I don't believe I was still using __init__.py in my later changes. Try using cli/infer.py instead, as described in the "Predicting using the trained model" section of the README. (I think I kept __init__.py around because some Jupyter notebooks used it, and I wasn't able to refactor them before leaving.)

There's a good chance you would need to do more work, or edit some of the code, to infer forcings for different fields. I'm not active on this project anymore, but if you can add a bit more detail on what you're trying to do, I or someone else might be able to help further.