LouisSerrano / coral

MIT License

0. Official Code

Official PyTorch implementation of CORAL, accepted at NeurIPS 2023 (arXiv: https://arxiv.org/abs/2306.07266).

To cite our work:

@article{serrano2023operator,
      title={Operator Learning with Neural Fields: Tackling PDEs on General Geometries}, 
      author={Louis Serrano and Lise Le Boudec and Armand Kassaï Koupaï and Thomas X Wang and Yuan Yin and Jean-Noël Vittaut and Patrick Gallinari},
      journal={37th Conference on Neural Information Processing Systems (NeurIPS 2023)},
      year={2023},
      url={https://arxiv.org/abs/2306.07266}
}

1. Code installation and setup

CORAL installation

conda create -n coral python=3.9.0
conda activate coral
pip install -e .

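Install torch_geometric and its extensions (the wheel index below is for torch 2.0.0 with CUDA 11.7; change the URL to match your torch and CUDA versions)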

pip install torch_geometric
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu117.html
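
After the two pip commands, a quick way to confirm that the PyG packages actually load is to import each one. The script below is a minimal sketch, not part of the CORAL repository; the import names are assumed to match the pip package names above. It imports the modules rather than only locating them, because a wheel built for a different torch or CUDA version typically fails only at import time.

```python
import importlib

# Import names of the packages installed by the two pip commands above
# (assumed to match the pip package names).
PYG_MODULES = [
    "torch_geometric",
    "pyg_lib",
    "torch_scatter",
    "torch_sparse",
    "torch_cluster",
    "torch_spline_conv",
]


def check_modules(names):
    """Try to import each module; map its name to None on success or to the error text."""
    results = {}
    for name in names:
        try:
            importlib.import_module(name)
            results[name] = None
        except Exception as exc:  # ImportError, or OSError from a wheel built for another torch/CUDA
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    for name, error in check_modules(PYG_MODULES).items():
        print(f"{name:20s} {'ok' if error is None else error}")
```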

Set up wandb (example config)

Add the following to your ~/.bashrc:

export WANDB_API_TOKEN=your_key
export WANDB_DIR=your_dir
export WANDB_CACHE_DIR=your_cache_dir
export WANDB_CONFIG_DIR="${WANDB_DIR}/config/wandb"
export MINICONDA_PATH=your_anaconda_path
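
Before launching a run, it can help to confirm that these variables are actually exported. The helpers below (missing_vars and expected_config_dir are hypothetical names, not part of CORAL) are a minimal sketch: the first lists variables that are unset or empty, the second builds the wandb config directory with an explicit separator, so the result is the same whether or not WANDB_DIR ends with a slash.

```python
import os
import posixpath

# Variables exported in the ~/.bashrc example above.
REQUIRED_VARS = ("WANDB_API_TOKEN", "WANDB_DIR", "WANDB_CACHE_DIR", "WANDB_CONFIG_DIR", "MINICONDA_PATH")


def missing_vars(env=None, required=REQUIRED_VARS):
    """Return the names in `required` that are unset or empty in `env` (default: os.environ)."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


def expected_config_dir(wandb_dir):
    """Join <WANDB_DIR>/config/wandb; identical with or without a trailing slash."""
    return posixpath.join(wandb_dir, "config", "wandb")


if __name__ == "__main__":
    missing = missing_vars()
    print("missing:", ", ".join(missing) if missing else "none")
```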

2. Data

3. Run experiments


The code runs only on GPU. We provide sbatch configuration files for the training scripts in bash_static and bash_dynamics. Training proceeds in two steps, and we expect wandb to be installed in your environment to keep track of them. For all tasks, the first step trains the INR with an inr.py script (for example static/train/design_inr.py); the weights of the INR are saved automatically under the run_name of that run. The second step trains the dynamics or inference model: pass the run_name from the first step to the config (inr.run_name=<run_name>) so that the trained INR is loaded. Example commands for each task are given below.
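
The second step only needs the run_name produced by the first. As a hedged sketch (step2_command is a hypothetical helper, not part of the repository), the command line for the second step can be assembled from that run_name plus any extra overrides, then launched with subprocess.run:

```python
import shlex


def step2_command(script, dataset_name, inr_run_name, overrides=None):
    """Return argv for a second-step script that loads the INR saved under inr_run_name."""
    argv = ["python3", script, f"data.dataset_name={dataset_name}", f"inr.run_name={inr_run_name}"]
    for key, value in (overrides or {}).items():
        argv.append(f"{key}={value}")
    return argv


if __name__ == "__main__":
    # inr_run_name is the run_name of the step-1 (INR) run, e.g. as displayed by wandb.
    cmd = step2_command(
        "static/train/design_regression.py",
        "elasticity",
        "clone-nerf-herder-4289",
        {"optim.epochs": 10000},
    )
    print(shlex.join(cmd))
    # To launch it on a GPU machine: import subprocess; subprocess.run(cmd, check=True)
```

Building an argument list instead of a single string avoids shell-quoting issues; shlex.join is only used to print a copy-pasteable version of the command.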

IVP


Design

python3 static/train/design_regression.py "data.dataset_name=airfoil" "inr.run_name=glowing-music-4181" 'optim.epochs=10000'


python3 static/train/design_inr.py "data.dataset_name=elasticity" 'optim.batch_size=64' 'optim.epochs=5000' 'inr_in.w0=10' 'inr_out.w0=15' 'optim.lr_inr=1e-4' 'optim.meta_lr_code=1e-4' 
python3 static/train/design_regression.py "data.dataset_name=elasticity" "inr.run_name=clone-nerf-herder-4289" 'optim.epochs=10000'
python3 static/train/design_regression.py "data.dataset_name=pipe" "inr.run_name=super-plasma-4149" 'optim.epochs=10000' 'model.width=128' 'model.depth=3' 'inr.inner_steps=3' 
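
Each argument in these commands is a dotted key=value override of the training config; the single and double quotes are interchangeable here, since the shell strips them before the script receives the arguments. The overrides look Hydra-style, which is an assumption in this sketch. The code below shows how such overrides map onto a nested configuration; apply_overrides and parse_value are hypothetical helpers for illustration, not the parser CORAL uses, and Hydra's own type inference differs in details (for example, it reads lowercase true as a boolean).

```python
import ast


def parse_value(text):
    """Interpret Python literals (ints, floats such as 1e-4, True/False); keep the rest as strings."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text  # e.g. dataset names and run names like "super-plasma-4149"


def apply_overrides(overrides, config=None):
    """Insert each 'a.b.c=value' override into a nested dict, creating levels as needed."""
    config = {} if config is None else config
    for item in overrides:
        key, sep, raw = item.partition("=")  # split on the first '=' only
        if not sep or not key:
            raise ValueError(f"override {item!r} is not of the form key=value")
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw)
    return config


if __name__ == "__main__":
    # Overrides taken from the pipe example above.
    print(apply_overrides(["data.dataset_name=pipe", "optim.epochs=10000", "model.width=128", "model.depth=3", "inr.inner_steps=3"]))
```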

Dynamics modeling

python3 dynamics_modeling/train.py "data.sub_from=$sub_from" "data.same_grid=$same_grid" "data.dataset_name=$dataset_name" "dynamics.width=$width" "dynamics.depth=$depth" "data.sub_tr=$sub_tr" "data.sub_te=$sub_te" "optim.epochs=$epochs" "data.seq_inter_len=$seq_inter_len" "data.seq_extra_len=$seq_extra_len" "optim.batch_size=$batch_size" "optim.lr=$lr"  "dynamics.teacher_forcing_update=$teacher_forcing_update" "inr.run_name=$run_name"

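Set each shell variable ($dataset_name, $run_name, and so on) before running the command. dataset_name is either 'navier-stokes-dino' or 'shallow-water-dino', and run_name is the run_name of the corresponding INR training run.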
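
The data.seq_inter_len and data.seq_extra_len overrides set two time horizons. The sketch below illustrates one reading of them and should be treated as an assumption, not the repository's definition: the first seq_inter_len steps of a trajectory are used in training and the following seq_extra_len steps are kept for extrapolation. split_horizons is a hypothetical helper; check dynamics_modeling/train.py for how CORAL actually slices its sequences.

```python
from typing import Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_horizons(trajectory: Sequence[T], seq_inter_len: int, seq_extra_len: int) -> Tuple[Sequence[T], Sequence[T]]:
    """Split a trajectory (time on the first axis) into a training and an extrapolation horizon.

    ASSUMPTION: the first seq_inter_len steps form the training horizon and the
    next seq_extra_len steps are used only for extrapolation.
    """
    if seq_inter_len <= 0 or seq_extra_len < 0:
        raise ValueError("seq_inter_len must be positive and seq_extra_len non-negative")
    total = seq_inter_len + seq_extra_len
    if len(trajectory) < total:
        raise ValueError(f"trajectory has {len(trajectory)} steps but {total} are needed")
    return trajectory[:seq_inter_len], trajectory[seq_inter_len:total]


if __name__ == "__main__":
    steps = list(range(40))  # made-up stand-in for 40 time steps of one trajectory
    inter, extra = split_horizons(steps, 20, 20)
    print(len(inter), len(extra))
```

Slicing works the same way on a NumPy array or torch tensor whose first axis is time.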