mllam / neural-lam

Research Software for Neural Weather Prediction for Limited Area Modeling
https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ
MIT License

Add support for mlflow #77

Status: Open. khintz opened this pull request 1 month ago.

khintz commented 1 month ago

Describe your changes

Add support for the mlflow logger by utilising pytorch_lightning.loggers. The native wandb module is replaced with the pytorch_lightning wandb logger, and the pytorch_lightning mlflow logger is introduced. See the Lightning Logger interface: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loggers/logger.py

This will allow people to choose between wandb and mlflow.
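A minimal sketch of how the choice could be exposed, assuming a hypothetical string option (for example a --logger command-line flag) that maps onto the two Lightning logger classes this PR uses, WandbLogger and MLFlowLogger from pytorch_lightning.loggers. make_logger and LOGGER_CLASSES are illustrative names, not part of neural-lam:

```python
import importlib
from typing import Any

# Hypothetical mapping from a user-facing choice to the Lightning logger class.
LOGGER_CLASSES = {
    "wandb": ("pytorch_lightning.loggers", "WandbLogger"),
    "mlflow": ("pytorch_lightning.loggers", "MLFlowLogger"),
}


def make_logger(choice: str, **kwargs: Any) -> Any:
    """Instantiate the Lightning logger selected by `choice`.

    pytorch_lightning is only imported once a valid logger is requested,
    and all keyword arguments are passed straight to the logger class.
    """
    try:
        module_name, class_name = LOGGER_CLASSES[choice]
    except KeyError:
        raise ValueError(
            f"Unknown logger {choice!r}; choose one of {sorted(LOGGER_CLASSES)}"
        ) from None
    logger_cls = getattr(importlib.import_module(module_name), class_name)
    return logger_cls(**kwargs)
```

For example, make_logger("mlflow", experiment_name="neural-lam", tracking_uri="http://localhost:5000") or make_logger("wandb", project="neural-lam") could then be passed as pl.Trainer(logger=...); the experiment name and tracking URI here are placeholders.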

Builds upon https://github.com/mllam/neural-lam/pull/66. That PR is not strictly necessary for this change, but I need it to work with our dataset.

Issue Link

Closes https://github.com/mllam/neural-lam/issues/76


khintz commented 1 month ago

WIP: the mlflow logger is still not working, but I got wandb working with the pytorch_lightning wandb logger. This depends on https://github.com/mllam/neural-lam/pull/66 being merged first.

khintz commented 1 month ago

The current issue is that in ar_model.py each figure is logged under the same name in every epoch:

if self.trainer.is_global_zero and not self.trainer.sanity_checking:
    for key, figure in log_dict.items():
        self.logger.log_image(key=key, images=[figure])

    plt.close("all")  # Close all figs

This works fine for wandb, but in mlflow the figure gets overwritten, so only the last one is kept. I am looking for some iterator that I can pass to the custom mlflow log_image function.
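One way to avoid the overwriting is to make the artifact name unique per epoch. The sketch below assumes the values in log_dict are matplotlib figures and, by default, uses mlflow.log_figure(figure, artifact_file); log_figures_per_epoch and epoch_artifact_file are hypothetical helpers, and the logging function can be swapped out (e.g. for testing):

```python
from typing import Any, Callable, Dict, List, Optional


def epoch_artifact_file(key: str, epoch: int) -> str:
    """Artifact path that is unique for each (key, epoch) pair."""
    return f"{key}_epoch{epoch:03d}.png"


def log_figures_per_epoch(
    log_dict: Dict[str, Any],
    epoch: int,
    log_figure: Optional[Callable[[Any, str], None]] = None,
) -> List[str]:
    """Log every figure in log_dict under an epoch-specific artifact name.

    Returns the artifact paths that were used.
    """
    if log_figure is None:
        # Assumes mlflow is installed and an mlflow run is active
        import mlflow

        log_figure = mlflow.log_figure  # called as log_figure(figure, artifact_file)
    paths = []
    for key, figure in log_dict.items():
        path = epoch_artifact_file(key, epoch)
        log_figure(figure, path)
        paths.append(path)
    return paths
```

Inside the LightningModule this could be called as log_figures_per_epoch(log_dict, self.current_epoch), followed by plt.close("all") as in the snippet above, so each epoch's figure is stored as a separate artifact rather than replacing the previous one.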

khintz commented 1 month ago

I now have model metrics, system metrics and artifact logging (including model logging) working for mlflow. See e.g.: https://mlflow.dmidev.org/#/experiments/2/runs/aceb8c6c94844736844dc7d1c12aa57f

However I get this warning:

2024/10/07 12:04:38 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.

I am calling a log_model function after trainer.fit

training_logger.log_model(model)

which is

def log_model(self, model):
    mlflow.pytorch.log_model(model, "model")

But I need to set the signature. The MLflow documentation (https://mlflow.org/docs/latest/model/signatures.html) states:

In MLflow, a model signature precisely defines the schema for model inputs, outputs, and any additional parameters required for effective model operation.

It should be possible to use infer_signature() from mlflow (https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.infer_signature), but it infers the schema from data, so one needs to pass example inputs (and optionally the corresponding outputs), e.g. signature = infer_signature(training_data, model_output). But the training dataset is probably too big to parse, and I am not sure I can access it from train_model.py. Could we define a signature manually, or should we skip the signature altogether?
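Since infer_signature() only inspects the data it is given, a single batch should be enough; the full training set does not need to be parsed. A minimal sketch, assuming the batches come from an iterable such as a DataLoader yielding tuples of CPU tensors or arrays, and a hypothetical predict callable that maps the named inputs to a model output. first_batch_as_numpy and signature_from_first_batch are illustrative helpers:

```python
from typing import Any, Callable, Dict, Iterable, Sequence

import numpy as np


def first_batch_as_numpy(
    batches: Iterable[Sequence[Any]], names: Sequence[str]
) -> Dict[str, np.ndarray]:
    """Read only the first batch and return it as a dict of NumPy arrays."""
    first = next(iter(batches))  # later batches are never requested
    if len(first) != len(names):
        raise ValueError(f"expected {len(names)} inputs, got {len(first)}")
    # np.asarray also converts CPU torch tensors via their __array__ method
    return {name: np.asarray(value) for name, value in zip(names, first)}


def signature_from_first_batch(
    batches: Iterable[Sequence[Any]],
    names: Sequence[str],
    predict: Callable[[Dict[str, np.ndarray]], Any],
):
    """Infer an MLflow model signature from one example batch (needs mlflow)."""
    from mlflow.models import infer_signature  # imported lazily

    example = first_batch_as_numpy(batches, names)
    output = np.asarray(predict(example))
    return infer_signature(example, output)
```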

Any thoughts @joeloskarsson, @sadamov, @TomasLandelius ?

sadamov commented 1 month ago

@khintz Thanks for adding mlflow to the list of loggers; it's nice to give the user more choice. And clearly you have already done most of the work 🚀. About the warning you are seeing: I don't think manually specifying the signatures is a good idea, as it is too error prone. How long would it take to use a single example to infer the signature passed to mlflow, with something like this:

Modify CustomMLFlowLogger:

import mlflow
import mlflow.pytorch
import pytorch_lightning as pl
import torch
from mlflow.models import infer_signature
from PIL import Image


class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
    def __init__(self, experiment_name, tracking_uri, data_module):
        super().__init__(experiment_name=experiment_name, tracking_uri=tracking_uri)
        # Attach to the run created by MLFlowLogger and record system metrics
        mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
        mlflow.log_param("run_id", self.run_id)
        self.data_module = data_module

    def log_image(self, key, images):
        # Save the matplotlib figure to a PNG file and log it as an image
        temporary_image = f"{key}.png"
        images[0].savefig(temporary_image)
        mlflow.log_image(Image.open(temporary_image), f"{key}.png")

    def log_model(self, model):
        input_example = self.create_input_example()
        with torch.no_grad():
            # Assumes forward() takes these tensors positionally
            model_output = model(*input_example)

        # TODO: Are we sure we can hardcode the input names?
        input_names = ["init_states", "target_states", "forcing", "target_times"]
        # mlflow expects NumPy data (not torch tensors) for signature and example
        numpy_example = {
            name: tensor.cpu().numpy()
            for name, tensor in zip(input_names, input_example)
        }
        signature = infer_signature(numpy_example, model_output.cpu().numpy())

        mlflow.pytorch.log_model(
            model,
            "model",
            input_example=numpy_example,
            signature=signature,
        )

    def create_input_example(self):
        if self.data_module.val_dataset is None:
            self.data_module.setup(stage="fit")
        # Dataset items are single samples; add a batch dimension of size 1
        return tuple(t.unsqueeze(0) for t in self.data_module.val_dataset[0])
joeloskarsson commented 1 month ago

"But the training dataset is probably too big to parse"

From my understanding, you don't need to feed the whole dataset to the model to infer this signature, only one example batch. Going by this, something like what @sadamov proposed should work. However:

"I don't think manually specifying the signatures is a good idea, as it is too error prone."

I agree. Ideally we would even get rid of the hard-coded argument names in the zip in @sadamov's code (but I don't have an immediate idea how to do that).
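One possible way to avoid hard-coding the names, assuming the model's forward() declares its inputs as ordinary positional parameters, is to read them from the Python signature with the standard-library inspect module. name_inputs is a hypothetical helper; if forward() instead takes a single batch tuple, this would only recover that one name:

```python
import inspect
from typing import Any, Callable, Dict, Sequence


def name_inputs(forward: Callable[..., Any], example: Sequence[Any]) -> Dict[str, Any]:
    """Pair each positional input with the parameter name declared by forward()."""
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    # For a bound method, inspect.signature already omits `self`
    positional = [
        param
        for param in inspect.signature(forward).parameters.values()
        if param.kind in positional_kinds
    ]
    if len(example) > len(positional):
        raise ValueError(
            f"forward() declares {len(positional)} positional inputs, "
            f"but the example has {len(example)}"
        )
    return {param.name: value for param, value in zip(positional, example)}
```

In the logger code above this would stand in for the zip over a fixed list, e.g. name_inputs(model.forward, input_example), so the names follow the model definition instead of being maintained separately.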

Something else to consider here is that there are additional important inputs that are necessary to make a forecast with the model (that do not enter as arguments when calling the model() function). These include in particular:

  1. Static inputs (grid static features) https://github.com/mllam/neural-lam/blob/de27e9a9676dbf3115ed7e2691493c73aa265fc6/neural_lam/models/ar_model.py#L48-L55
  2. The graph parts (edge_index + static graph features) https://github.com/mllam/neural-lam/blob/de27e9a9676dbf3115ed7e2691493c73aa265fc6/neural_lam/models/base_graph_model.py#L23-L31

I don't know if these (or rather their shape) should be considered for the third part of the model signature ("Parameters (params)"), or somehow also viewed as part of the input. But I also fear that including these might just make this complex enough that this signature is no longer particularly useful. I think we should be motivated by how useful we actually find this signature to be. If we just want to get rid of the warning maybe we don't have to worry about these.
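If the static inputs and graph parts are to be recorded at all, one lightweight option, rather than folding them into the signature, is to log only their shapes and dtypes as a separate artifact. A sketch, assuming the tensors are gathered beforehand and moved to the CPU (for instance from model.named_buffers(), if they are registered as buffers) and that an mlflow run is active; describe_static_inputs, log_static_input_summary and the example names are hypothetical:

```python
from typing import Any, Dict, Mapping

import numpy as np


def describe_static_inputs(tensors: Mapping[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Record the shape and dtype of each named static input.

    Only metadata is returned, so the result stays small and
    JSON-serialisable even when the arrays themselves are large.
    """
    summary = {}
    for name, value in tensors.items():
        array = np.asarray(value)
        summary[name] = {"shape": list(array.shape), "dtype": str(array.dtype)}
    return summary


def log_static_input_summary(
    tensors: Mapping[str, Any], artifact_file: str = "static_inputs.json"
) -> None:
    """Write the summary as a JSON artifact of the active mlflow run."""
    import mlflow  # imported lazily; requires an active run

    mlflow.log_dict(describe_static_inputs(tensors), artifact_file)
```

This keeps the model signature limited to the arguments of model(), while still leaving a record of which static inputs a logged model was trained with.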