mllam / weather-model-graphs

Tooling for creating, visualising and storing data-driven weather-model graphs
https://join.slack.com/t/ml-lam/shared_invite/zt-2t112zvm8-Vt6aBvhX7nYa6Kbj_LkCBQ

Make save function more universal to accept any number of 1D or 2D node or edge features #31

Closed: maxiimilian closed this pull request 1 month ago

maxiimilian commented 1 month ago

Describe your changes

This pull request contains three changes: one major change and two minor ones.

The major change modifies how edge and node features are converted from pyg.data.Data objects into torch.Tensors before saving. Assuming that all node and edge features (hereafter "features") are numeric, converting them to float32 will always work. The proposed change further assumes that each feature is either 1D (e.g. a vector of edge lengths) or 2D (e.g. node positions, or the vdiff vector between source and target nodes). On this basis, I propose a single universal function that reshapes any 1D vector into a 2D column vector, so that any number of features can be concatenated in a loop.
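A minimal sketch of this idea follows. It assumes the features are available as a dict of tensors standing in for the fields of a pyg.data.Data object; the helper names to_2d_float and concat_features are hypothetical and are not taken from the repository.

```python
import torch


def to_2d_float(feature) -> torch.Tensor:
    """Convert a numeric 1D or 2D feature to a float32 tensor with two dimensions.

    Shape (N,) becomes the column vector (N, 1); shape (N, D) is unchanged.
    """
    tensor = torch.as_tensor(feature).to(torch.float32)
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(1)  # (N,) -> (N, 1)
    elif tensor.dim() != 2:
        raise ValueError(f"expected a 1D or 2D feature, got {tensor.dim()}D")
    return tensor


def concat_features(attrs: dict, feature_names: list) -> torch.Tensor:
    """Concatenate the named features column-wise into one (N, total_width) tensor."""
    columns = []
    for name in feature_names:
        columns.append(to_2d_float(attrs[name]))
    return torch.cat(columns, dim=1)


# Hypothetical attributes for 3 edges, standing in for a pyg.data.Data object.
edge_attrs = {
    "len": torch.tensor([1.0, 2.0, 3.0]),  # 1D, shape (3,)
    "vdiff": torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),  # 2D, shape (3, 2)
}
edge_features = concat_features(edge_attrs, ["len", "vdiff"])
print(edge_features.shape)  # torch.Size([3, 3])
```

Because every feature ends up as an (N, k) float32 tensor, torch.cat along dim=1 works regardless of how many features are listed or whether each one is 1D or 2D.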

With the new function, saving len as an edge feature is also made explicit by including it in the default edge_features list.
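The sketch below shows how an explicit default could look. It assumes a hypothetical save_edge_features function with a default of ("len", "vdiff"); the actual function name, default list, and file format in weather-model-graphs may differ. Here reshape(n, -1) performs the same 1D-to-column conversion in a single call.

```python
import os
import tempfile

import torch

# Hypothetical default list: "len" is named explicitly alongside "vdiff".
DEFAULT_EDGE_FEATURES = ("len", "vdiff")


def save_edge_features(edge_attrs: dict, path: str,
                       edge_features=DEFAULT_EDGE_FEATURES) -> torch.Tensor:
    """Concatenate the selected edge features as float32 columns and save them."""
    columns = []
    for name in edge_features:
        feature = torch.as_tensor(edge_attrs[name], dtype=torch.float32)
        # reshape(n, -1) turns (N,) into (N, 1) and leaves (N, D) unchanged
        columns.append(feature.reshape(feature.shape[0], -1))
    stacked = torch.cat(columns, dim=1)
    torch.save(stacked, path)
    return stacked


edge_attrs = {
    "len": torch.tensor([1.0, 2.0, 3.0]),
    "vdiff": torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
}
with tempfile.TemporaryDirectory() as tmp_dir:
    # Default call: both "len" and "vdiff" are saved.
    all_features = save_edge_features(edge_attrs, os.path.join(tmp_dir, "edges.pt"))
    # Callers can still leave out "len" by passing their own list.
    vdiff_only = save_edge_features(
        edge_attrs, os.path.join(tmp_dir, "vdiff.pt"), edge_features=["vdiff"]
    )
print(all_features.shape, vdiff_only.shape)  # torch.Size([3, 3]) torch.Size([3, 2])
```

Listing len in the default, rather than appending it implicitly inside the function, makes it visible in the signature which columns end up in the saved tensor and in what order.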

Two minor bugs are fixed:

Type of change

Checklist before requesting a review

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

Author checklist after completed review

Checklist for assignee