mllam / neural-lam

Neural Weather Prediction for Limited Area Modeling
MIT License

Idea for how to define neural-networks operating on graph #35

Open leifdenby opened 1 month ago

leifdenby commented 1 month ago

I've been thinking about how we could construct the different neural networks which update the graph. The ideas below are very tentative, and I may also have fundamentally misunderstood parts of how this is supposed to work, so all feedback is very welcome.

To my mind there are two types of neural networks we use: 1) MLPs for embedding node/edge features into a latent space of common size across all embeddings, and 2) networks for updating node embeddings using message passing.

What I would like to achieve is code that:

  * is closer to the mathematical notation used in the neural-lam publication for the expressions that do the embedding and message-passing, i.e. shows which nodes/edges are being operated on
  * allows me to easily see in one place the complete set of neural networks used in a given architecture
  * is flexible by making it easy to create new message-passing operations

I think there are basically three steps to this:

  1. Define which embedding networks to create; the number of embedders will depend on whether nodes/edges share the same features or not
  2. Define the message-passing operations, i.e. which nodes the message passing communicates between (giving each operation a unique identifier)
  3. Define the order in which the message-passing operations are applied

Below is a code example that tries to encapsulate all the information I think is needed to later actually instantiate the neural-network models that do the numerical operations.

import pytorch_lightning as pl

class NewGraphModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        n_mesh_node_features = 3
        n_grid_node_features = 18
        n_edge_features = 3
        n_hidden_features = 64

        # create node and edge feature embedding networks, all will project
        # into an embedding space of size n_hidden_features
        embedders = [
            dict(
                node=dict(component="mesh"),
                n_input_features=n_mesh_node_features,
            ),
            dict(
                node=dict(component="grid"),
                n_input_features=n_grid_node_features,
            ),
            # use the same edge embedding network for all edges, assuming
            # that all edges have the same set of features
            dict(edge=dict(), n_input_features=n_edge_features),
        ]

        # create message-passing networks that update node embeddings,
        # these all assume that the embedding vectors are of the same size
        message_passers = dict(
            g2m=dict(src=dict(component="grid"), dst=dict(component="mesh")),
            m2m_up_1to2=dict(
                src=dict(component="mesh", level=1),
                dst=dict(component="mesh", level=2),
            ),
            m2m_down_2to1=dict(
                src=dict(component="mesh", level=2),
                dst=dict(component="mesh", level=1),
            ),
            m2m_inlevel_1=dict(
                src=dict(component="mesh", level=1),
                dst=dict(component="mesh", level=1),
            ),
            m2m_inlevel_2=dict(
                src=dict(component="mesh", level=2),
                dst=dict(component="mesh", level=2),
            ),
            m2g=dict(src=dict(component="mesh"), dst=dict(component="grid")),
        )

        # define the order in which messages are passed
        # here we do the up/down twice before decoding back to the grid
        message_passing_order = [
            "g2m",
            # m2m pass 1
            "m2m_up_1to2",
            "m2m_inlevel_2",
            "m2m_down_2to1",
            "m2m_inlevel_1",
            # m2m pass 2
            "m2m_up_1to2",
            "m2m_inlevel_2",
            "m2m_down_2to1",
            "m2m_inlevel_1",
            "m2g",
        ]
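To make the intent concrete, here is a minimal sketch of how a `message_passing_order` list like the one above could drive the forward pass. The stub callables below just stand in for real GNN layers, so nothing here is actual neural-lam code:

```python
# Minimal sketch (stub callables, not real GNN layers): the
# message_passing_order list drives the forward pass by looking up each
# operation by its identifier and applying it in sequence.
def run_message_passing(message_passers, message_passing_order, state):
    for identifier in message_passing_order:
        state = message_passers[identifier](state)
    return state

# stub "networks" that just record which operation ran
trace = []
message_passers = {
    name: (lambda n: lambda s: (trace.append(n), s)[1])(name)
    for name in ["g2m", "m2m_up_1to2", "m2m_inlevel_2",
                 "m2m_down_2to1", "m2m_inlevel_1", "m2g"]
}
order = ["g2m", "m2m_up_1to2", "m2m_inlevel_2",
         "m2m_down_2to1", "m2m_inlevel_1", "m2g"]
run_message_passing(message_passers, order, state=0)
```

With the real message-passing networks substituted in, the same loop would implement the two up/down sweeps followed by the decode back to the grid.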

A few notes on this to explain what is going on:

I hope this isn't total nonsense @joeloskarsson and @sadamov :laughing: just trying to get the ball rolling

joeloskarsson commented 3 weeks ago

So there's a few things that you bring up here, and some are almost orthogonal.

W.r.t embedders: I agree that this could be a bit more structured. In some code I've been working with I have a function embedd_all that handles all embedding, one node-set/edge-set at a time, and that gives a good overview.
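A hypothetical sketch of what an embedd_all-style helper could look like (the actual function is linked further down in this thread, so the signature and names here are only illustrative):

```python
# Hypothetical sketch (not the real embedd_all): apply every embedding
# network to its matching node-set/edge-set in one place, keyed by identifier.
def embedd_all(embedders, features):
    """embedders: identifier -> callable, features: identifier -> feature array"""
    return {name: embedders[name](feat) for name, feat in features.items()}

# toy callables standing in for the embedding MLPs
embedded = embedd_all(
    {"grid_node": lambda x: [2 * v for v in x], "mesh_node": lambda x: x},
    {"grid_node": [1, 2, 3], "mesh_node": [4, 5]},
)
```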

I struggle a bit to understand the other parts, so would be grateful for some clarifications.

is closer to the mathematical notation used in the neural-lam publication for the expressions that do the embedding and message-passing, i.e. shows which nodes/edges are being operated on

Could you explain how the current implementation looks different from the equations in the paper? I look for example at something like https://github.com/mllam/neural-lam/blob/879cfec1b49d0255ed963f44b3a9f55d42c9920a/neural_lam/models/base_graph_model.py#L147-L149 and to me this tells exactly which node sets messages are passed between and what edges to use. Is the idea that you want the logic of the forward pass of the model to be defined in the init function? To me it makes sense to follow the practice of instantiating the network blocks in init and using them in forward. However, something I think we should change is to break things up into separate nn.Modules (e.g. the encoder, processor and decoder parts), so that the forward pass actually sits in a forward function, rather than in something like predict_step.
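A minimal sketch of what such a split could look like, assuming encoder, processor and decoder become separate nn.Modules (class and argument names here are made up for illustration, not the neural-lam API):

```python
import torch
from torch import nn

# Hypothetical sketch of the suggested split: encoder, processor and decoder
# as separate nn.Modules, so the forward pass sits in forward() rather than
# in something like predict_step.
class EncodeProcessDecode(nn.Module):
    def __init__(self, encoder, processor, decoder):
        super().__init__()
        self.encoder = encoder
        self.processor = processor
        self.decoder = decoder

    def forward(self, grid_features):
        mesh_rep = self.encoder(grid_features)   # grid -> mesh embedding
        mesh_rep = self.processor(mesh_rep)      # message passing on the mesh
        return self.decoder(mesh_rep)            # mesh -> grid prediction

# identity stand-ins just to show the wiring
model = EncodeProcessDecode(nn.Identity(), nn.Identity(), nn.Identity())
out = model(torch.zeros(2, 3))
```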

allows me to easily see in one place the complete set of neural networks used in a given architecture

Have you thought about how this should work with the class hierarchy? I think that is something I fail to see. It makes sense when you write out the NewGraphModel here, but in practice you will never have a class like this where all the GNNs and MLPs are defined in the same init function. As a concrete example: the ARModel does not know if it is working with a hierarchical graph or not. Therefore it will not know if it should create 1 mesh edge embedder or $L$ (= number of levels) such embedders.

Connected to above, my understanding of the NewGraphModel example is that you create lists and dicts that describe all network components and then you use these as a blueprint to instantiate those components. Is that correct? Why would this be less convoluted and more understandable than just instantiating the components directly? Do you have an idea of how this would work with the class hierarchy? Would all subclasses append to e.g. self.embedders and what class is responsible for actually triggering the instantiation based on this blueprint?

is flexible by making it easy to create new message-passing operations

I interpret this as being able to create new GNN layers, is that correct? I think that is very important and something we should think about. If we keep the current function signature from the forward of the InteractionNetworks this only comes down to changing the instantiation of GNN layers. I have done a little bit of this in our probabilistic modelling and it is very easy to swap out different GNN layer classes.
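To illustrate what "only changing the instantiation" could mean in practice, here is a pure-Python sketch (class names are illustrative placeholders, not the neural-lam API): if every GNN layer keeps the same forward signature, the choice of layer reduces to which class gets instantiated.

```python
# Hypothetical sketch: all layer classes share one forward signature, so
# swapping GNN layers only changes which class is looked up and instantiated.
class InteractionNetLayer:
    def forward(self, send_rep, rec_rep, edge_rep):
        return rec_rep  # placeholder; a real layer aggregates messages

class AttentionGNNLayer:
    def forward(self, send_rep, rec_rep, edge_rep):
        return rec_rep  # placeholder with the same signature

GNN_LAYER_CLASSES = {
    "interaction": InteractionNetLayer,
    "attention": AttentionGNNLayer,
}

def make_gnn_layer(kind):
    # the only point where the choice of GNN layer enters
    return GNN_LAYER_CLASSES[kind]()

layer = make_gnn_layer("attention")
```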

leifdenby commented 3 weeks ago

W.r.t embedders: I agree that this could be a bit more structured. In some code I've been working with I have a function embedd_all that handles all embedding, one node-set/edge-set at a time, and that gives a good overview.

Ok, could you point me to this? Just so we're on the same page: what I thought would be nice to have is a single point where we, in a sense, register which embedding networks will be constructed. This could be achieved in different ways, for example:

1) have a single collection of embedding "blueprints" that define all embedding networks to construct. I hope I've understood this term the way you use it, but my use here would be a definition of the number of input and output features, the type of the embedding (number of layers, width, or maybe we just come up with some common set of names to describe this) and an identifier for each (e.g. "graph_nodes")

2) define a collection for the constructed embedding networks to be stored in; this could simply be a dictionary with each key being the identifier for a given embedding network, and we encourage people to put their embedding networks in this dict

I would prefer option 1, as this would allow us to easily print which embedding networks are initiated, make it easy to understand how to add more, and enforce that they work identically.

In a container type for these graph-based models this could be implemented with something like:

from torch import nn

class EncodeProcessDecodeGraph:
    def __init__(self):
        self._embedding_blueprints = {}
        self._embedding_networks = {}

    def _register_embedder(
        self, identifier, n_features_in, n_features_out, kind
    ):
        self._embedding_blueprints[identifier] = dict(
            n_features_in=n_features_in,
            n_features_out=n_features_out,
            kind=kind,
        )

    def _construct_embedders(self):
        for identifier, blueprint in self._embedding_blueprints.items():
            n_in = blueprint["n_features_in"]
            n_out = blueprint["n_features_out"]
            if blueprint["kind"] == "linear_single":
                self._embedding_networks[identifier] = nn.Linear(n_in, n_out)
            else:
                raise ValueError(f"Unknown kind: {blueprint['kind']}")

class KeislerGraph(EncodeProcessDecodeGraph):
    def __init__(
        self, hidden_dim_size=512, n_grid_features=10, n_edge_features=2
    ):
        super().__init__()

        self._register_embedder(
            identifier="grid_node",
            n_features_in=n_grid_features,
            n_features_out=hidden_dim_size,
            kind="linear_single",
        )
        self._register_embedder(
            identifier="g2m_and_m2g_edge",
            n_features_in=n_edge_features,
            n_features_out=hidden_dim_size,
            kind="linear_single",
        )
        self._register_embedder(
            identifier="m2m_edge",
            n_features_in=n_edge_features,
            n_features_out=hidden_dim_size,
            kind="linear_single",
        )

        self._construct_embedders()

I will create a separate comment for the message-passing networks.

joeloskarsson commented 3 weeks ago

The embedd_all function that I mentioned is more about how the embedders are then applied to the input, so quite orthogonal to this. It is in the code for the ensemble model so not on Github yet. But I hope to have that pushed today, then I'll link it here.

Overall I think this looks nice. This could act almost as a wrapper for utils.make_mlp, but specific to MLPs that are embedders, so we can better keep track of them and make sure they are constructed in similar ways. There will be some work needed to make this fit into the model class hierarchy (existing or proposed new one), but I think it should be doable.

One thing I am wondering though: could we not just immediately create each embedder, instead of registering them first and then having a separate call for constructing them? If we have something like _create_embedder it could both create the embedder and store it in self._embedding_networks immediately. That way we still keep track of all embedders in the same way, but avoid an additional function call that requires some understanding of the whole register->construct setup. I think that would also play more nicely with the model class hierarchy, as there will be no question about which class is responsible for calling self._construct_embedders().
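A sketch of what that _create_embedder variant could look like (this is a hypothetical reworking of the EncodeProcessDecodeGraph example above, not existing code):

```python
from torch import nn

# Hypothetical sketch of the suggested _create_embedder variant: each
# embedder is created and stored immediately, with no separate
# register -> construct step.
class EncodeProcessDecodeGraph:
    def __init__(self):
        self._embedding_networks = {}

    def _create_embedder(
        self, identifier, n_features_in, n_features_out, kind="linear_single"
    ):
        if kind == "linear_single":
            network = nn.Linear(n_features_in, n_features_out)
        else:
            raise ValueError(f"Unknown kind: {kind}")
        self._embedding_networks[identifier] = network
        return network

graph = EncodeProcessDecodeGraph()
graph._create_embedder("grid_node", n_features_in=10, n_features_out=512)
```

Subclasses would then call _create_embedder directly from their own init, so no class needs to own a final construction step.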

joeloskarsson commented 2 weeks ago

Here is the embedd_all function: https://github.com/mllam/neural-lam/blob/89c5ce938e5f05a31a81422c4198066a55ebaeaf/neural_lam/models/graph_efm.py#L295-L372

But again, that is a bit orthogonal since it is about how the embedders are applied to grid-input + graph, rather than how they are created. Perhaps of more interest is how the embedders are created in that class:

https://github.com/mllam/neural-lam/blob/89c5ce938e5f05a31a81422c4198066a55ebaeaf/neural_lam/models/graph_efm.py#L49-L138

This at least collects everything in one place, but it could be even nicer with something like what's proposed above.