Closed: yoeripoels closed this issue 2 years ago.
Hi, Thank you for the beautiful paper & code!
I noticed a small issue in how `dt` is constructed in the decoders. The PDE time/grid parameters are loaded from the train dataset in `train.py` as follows: https://github.com/brandstetter-johannes/MP-Neural-PDE-Solvers/blob/510fa0bf94815ab85ca36fa3eb48ac0558bc77ee/experiments/train.py#L167-L170 However, the decoder also uses `pde.dt`: https://github.com/brandstetter-johannes/MP-Neural-PDE-Solvers/blob/510fa0bf94815ab85ca36fa3eb48ac0558bc77ee/experiments/models_gnn.py#L207 Since `pde.dt` is not set in `train.py`, it always holds the default value from the PDE object's constructor rather than the value used to generate the dataset. Because `dt` is constant throughout the dataset, this amounts only to a constant scaling and causes no issues in practice, but it would be nice to have it correct :) A simple fix is to add `pde.dt = train_dataset.dt` to `train.py`.
Cheers, Yoeri
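To illustrate why the missing assignment is only a constant scaling, here is a minimal, self-contained sketch. The `PDE` and `Dataset` classes, the `decode` helper and all numeric values (a hypothetical default `dt = 0.5`, a dataset `dt = 0.25`) are assumptions for illustration, not the repository's code; `decode` only mimics the idea that the decoder scales each predicted difference by a cumulative time step built from `pde.dt`.

```python
from dataclasses import dataclass
from itertools import accumulate


# Hypothetical stand-ins for the repo's PDE and dataset objects; only the
# attributes relevant to this issue are modelled.
@dataclass
class PDE:
    tmin: float = 0.0
    tmax: float = 4.0
    dt: float = 0.5  # hypothetical constructor default


@dataclass
class Dataset:
    tmin: float
    tmax: float
    dt: float  # time step actually used to generate the data


def decode(u_last: float, diffs: list[float], dt: float) -> list[float]:
    """Assumed decoder form: u(t_k) = u_last + (k * dt) * diff_k for k = 1..len(diffs).

    The cumulative sum of dt (dt, 2*dt, 3*dt, ...) scales each predicted difference.
    """
    steps = list(accumulate([dt] * len(diffs)))
    return [u_last + s * d for s, d in zip(steps, diffs)]


pde = PDE()
train_dataset = Dataset(tmin=0.0, tmax=4.0, dt=0.25)

# Only the time bounds are copied, so pde.dt keeps its constructor default.
pde.tmin = train_dataset.tmin
pde.tmax = train_dataset.tmax

diffs = [1.0, 1.0, 1.0]  # hypothetical network outputs
before = decode(0.0, diffs, pde.dt)  # uses the default dt = 0.5

# The proposed fix: take dt from the dataset as well.
pde.dt = train_dataset.dt
after = decode(0.0, diffs, pde.dt)  # uses the dataset dt = 0.25

print(before)  # [0.5, 1.0, 1.5]
print(after)   # [0.25, 0.5, 0.75]
```

Every output differs by the same factor (0.5 / 0.25 = 2), so a trained network can absorb the mismatch into its predicted differences; that is why the bug was harmless in practice even though setting `pde.dt` correctly is cleaner.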
Hi Yoeri, you are completely right! Thank you so much for spotting this. Do you want to open a PR? I would merge it then.
Johannes
Sure, I've made a PR!
Thanks a lot for spotting that!