CemOezcan / hyper-graph-nets

Implementation of different remote message passing strategies in graph neural networks for mesh-based physical simulation.
MIT License

HyperGraphNet

Learned physics simulators based on Graph Neural Networks (GNNs) achieve faster inference and generalize better than classical physics simulators. However, the ability of GNNs to capture long-range dependencies is limited by their fixed number of GNN layers (message passing layers): after K layers, information can travel at most K hops through the graph. To overcome this issue, we designed HyperGraphNets, a GNN framework that uses remote message passing to model long-range dependencies.
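To make this limitation concrete, here is a minimal Python sketch (not code from this repository) on a hypothetical 20-node chain mesh. With K message passing layers a node only receives information from its K-hop neighbourhood; adding one hyper node per cluster, with the hyper nodes connected to each other, lets three layers reach every mesh node. The helpers `hops_reached`, `chain_mesh` and `add_hyper_nodes`, as well as the fully connected hyper-node layer, are illustrative assumptions and not necessarily how HyperGraphNets builds its remote edges.

```python
from collections import deque


def hops_reached(adjacency, source, num_layers):
    """Nodes whose features can reach `source` after `num_layers` rounds of
    message passing, i.e. its num_layers-hop neighbourhood (breadth-first search)."""
    seen = {source}
    frontier = deque([(source, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == num_layers:
            continue  # messages cannot travel further with this many layers
        for neighbour in adjacency[node]:
            if neighbour not in seen:
                seen.add(neighbour)
                frontier.append((neighbour, depth + 1))
    return seen


def chain_mesh(num_nodes):
    """A 1-D 'mesh': node i is connected to nodes i-1 and i+1."""
    adjacency = {i: set() for i in range(num_nodes)}
    for i in range(num_nodes - 1):
        adjacency[i].add(i + 1)
        adjacency[i + 1].add(i)
    return adjacency


def add_hyper_nodes(adjacency, clusters):
    """Hypothetical remote message passing scheme: add one hyper node per
    cluster, connect it to every member, and connect all hyper nodes to each other.
    Returns a new adjacency dict; the input mesh is not modified."""
    extended = {node: set(nbrs) for node, nbrs in adjacency.items()}
    next_id = max(adjacency) + 1
    hyper_ids = []
    for members in clusters:
        hyper = next_id
        next_id += 1
        hyper_ids.append(hyper)
        extended[hyper] = set(members)
        for member in members:
            extended[member].add(hyper)
    for h in hyper_ids:
        extended[h].update(x for x in hyper_ids if x != h)
    return extended


if __name__ == "__main__":
    mesh = chain_mesh(20)
    print(sorted(hops_reached(mesh, 0, 3)))  # [0, 1, 2, 3]
    hyper = add_hyper_nodes(mesh, [range(0, 10), range(10, 20)])
    print(len(hops_reached(hyper, 0, 3) & set(mesh)))  # 20
```

On the plain chain, three layers reach only nodes 0 to 3; with the two hyper nodes, every mesh node is at most three hops (node, own hyper node, other hyper node, member) away from node 0.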

This repository contains implementations of several GNN architectures used as learned physics simulators.

Examples

The following simulations and error curves illustrate the benefit of remote message passing in HyperGraphNets (left) compared with the baseline, MeshGraphNets (right). MP denotes the number of message passing layers.

Deforming Plate

Flag Simple

Setting up the environment

This project uses pip to install packages and dependencies from PyPI. To get started, we recommend creating a new virtual environment and installing the required dependencies from the repository root using pip install -r requirements.txt.

Recording

Metrics and visualizations are logged to Weights & Biases (W&B). This requires logging in to W&B via wandb login (for more information, see the W&B quickstart guide).

Downloading datasets

Download each dataset into its respective subfolder of the data directory ./data; for example, the deforming_plate dataset belongs in ./data/deforming_plate.
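A minimal sketch of preparing this layout in Python, assuming the task names and the input/output split shown in the folder structure of this README, and assuming that downloaded raw files go into the input folder. `prepare_task_dirs` is a hypothetical helper; the repository's own download.sh script is not reproduced here.

```python
from pathlib import Path

# Task names and the input/output split follow the folder structure of this README.
TASKS = ("cylinder_flow", "deforming_plate", "flag_minimal", "flag_simple")


def prepare_task_dirs(task, data_root="./data"):
    """Hypothetical helper: create <data_root>/<task>/input and .../output
    if they are missing and return the input folder for the raw dataset files."""
    if task not in TASKS:
        raise ValueError(f"unknown task {task!r}; expected one of {TASKS}")
    task_dir = Path(data_root) / task
    for sub in ("input", "output"):
        (task_dir / sub).mkdir(parents=True, exist_ok=True)
    return task_dir / "input"
```

For example, prepare_task_dirs("deforming_plate") returns the path data/deforming_plate/input, relative to the current working directory.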

Creating an experiment

The configs folder contains a number of .yaml files, each describing the configuration of a task to run. To run an experiment from one of these files on a local machine, run python main.py "${CONFIG}", where ${CONFIG} is the name of the config file without the .yaml suffix.
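The naming convention can be sketched as follows. `resolve_config` and `parse_args` are hypothetical helpers, and the `configs` directory default simply follows the sentence above; the actual main.py may parse its argument and resolve config names differently.

```python
import argparse
from pathlib import Path


def resolve_config(name, config_dir="configs"):
    """Hypothetical helper: map a config name given without its suffix to
    <config_dir>/<name>.yaml, raising FileNotFoundError if the file is missing."""
    path = Path(config_dir) / f"{name}.yaml"
    if not path.is_file():
        raise FileNotFoundError(f"config file not found: {path}")
    return path


def parse_args(argv=None):
    """Parse the single positional CONFIG argument of: python main.py "${CONFIG}"."""
    parser = argparse.ArgumentParser(description="Run an experiment from a .yaml config.")
    parser.add_argument("config", help="name of the config file, without the .yaml suffix")
    return parser.parse_args(argv)
```

An entry point following this convention would call resolve_config(parse_args().config) and then load the returned .yaml file.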

Folder structure

.
├── config                    # Config files for setting up experiments
├── data                      # Datasets, models and plots
    ├── cylinder_flow         # Input and output data for the cylinder_flow task
        ├── input 
        ├── output 
    ├── deforming_plate       # Input and output data for the deforming_plate task
        ├── input 
        ├── output 
    ├── flag_minimal          # Input and output data for the flag_minimal task
        ├── input 
        ├── output 
    ├── flag_simple           # Input and output data for the flag_simple task
        ├── input 
        ├── output
├── src                       # Source code
    ├── algorithms            # Training and evaluation of GNN-based physics simulators
    ├── data                  # Loading and preprocessing raw datasets
    ├── graph_balancer        # Graph balancing algorithms for remote message passing
    ├── migration             # PyTorch implementation of MeshGraphNets and its extension, HyperGraphNets
    ├── model                 # Task specific Graph Neural Networks
    ├── rmp                   # Graph clustering algorithms for remote message passing
    ├── tasks                 # Wrapper methods for training, evaluating and visualizing 
├── wandb                     # Recording via W&B
├── .gitignore                 
├── requirements.txt           
├── main.py
├── download.sh 
├── LICENSE
└── README.md
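The src/rmp folder holds graph clustering algorithms for remote message passing, but this README does not say which ones. As an assumed stand-in, the sketch below uses plain k-means (Lloyd's algorithm) on mesh node positions to assign each node to a cluster, then groups node indices per cluster; such a partition is the kind of structure a remote message passing scheme can attach hyper nodes to. The function names are hypothetical, and the repository's clustering may work differently.

```python
import math
import random


def kmeans(points, k, iterations=20, seed=0):
    """Lloyd's k-means on mesh node positions (sequences of 2-D or 3-D coordinates).
    Returns one integer cluster label per point."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]  # initialise from k distinct nodes
    labels = [0] * len(points)
    for _ in range(iterations):
        # Assignment step: each node joins its nearest centroid (Euclidean distance).
        labels = [min(range(k), key=lambda c: math.dist(p, centroids[c])) for p in points]
        # Update step: move each centroid to the mean position of its members.
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # keep the previous centroid if a cluster becomes empty
                centroids[c] = [sum(axis) / len(members) for axis in zip(*members)]
    return labels


def clusters_from_labels(labels):
    """Group node indices by cluster label, ordered by label."""
    groups = {}
    for node, label in enumerate(labels):
        groups.setdefault(label, []).append(node)
    return [groups[label] for label in sorted(groups)]
```

For two well-separated groups of nodes, kmeans(points, 2) places each group in its own cluster, and clusters_from_labels turns the labels into lists of node indices, one list per cluster.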