NVIDIA / modulus

Open-source deep-learning framework for building, training, and fine-tuning deep learning models using state-of-the-art Physics-ML methods
https://developer.nvidia.com/modulus
Apache License 2.0

GraphCast improvements - Part I #510

Closed: mnabian closed this pull request 1 month ago

mnabian commented 1 month ago

Modulus Pull Request

Description

Closes https://github.com/NVIDIA/modulus/issues/506, https://github.com/NVIDIA/modulus/issues/505, https://github.com/NVIDIA/modulus/issues/486, https://github.com/NVIDIA/modulus/issues/508, https://github.com/NVIDIA/modulus/issues/509, https://github.com/NVIDIA/modulus/issues/511, https://github.com/NVIDIA/modulus/issues/516, https://github.com/NVIDIA/modulus/issues/517

mnabian commented 1 month ago

/blossom-ci

stadlmax commented 1 month ago

@mnabian Since you are revisiting GraphCast now, adding a few comments:

  • Can we add an option to use transformer_engine.LayerNorm? In AIFS benchmarks, we got a 1.3x end-to-end improvement just from doing so, since the PyTorch implementation performs rather poorly for the sizes we encounter in these workloads.
  • Can you check whether the current combination of MeshGraphNodeBlock and MeshGraphEdgeBlock actually matches the paper (https://github.com/NVIDIA/modulus/blob/main/modulus/models/graphcast/graph_cast_processor.py#L97-L98)? I created a schematic of the GraphCast architecture for some Arch folks last week, and I think the order of the residuals over the edges does not match the paper here. I might have made a mistake when I tried to use shared primitives here last time. The issue is that in MeshGraphNet, the EdgeBlock already applies the "residual" to the edge features, so the NodeBlock receives edge features that already include the residual connection before message passing, whereas in GraphCast all residual connections are applied only after both the updated edge and node features have been computed (at least according to the paper).
  • What would you think of splitting GraphCastNet into a GraphCastNetERA5 model and a GraphCastNet model? The issue I currently see with GraphCastNet is that it is very specific to the nature of the ERA5 dataset (e.g., when it comes to preparing the input and output to switch between the HxW layout and the typical "serial" graph layout). GraphCastNet could then be a rather data-agnostic model defining the operations on (g2m_graph, mesh_graph, m2g_graph), while GraphCastNetERA5 defines the parts that are specific to the workload, like checkpointing, input/output conversions, etc. In the longer term, I think it could really make sense to make things a bit more modular. In particular, this also includes things like "history" or the actual "prediction" mode, i.e., whether GraphCastNetERA5 predicts y_t = f(x_{t-1}) or y_t = x_{t-1} + f(x_{t-1}). It could make sense for the "backbone" to be agnostic to these things, with a specialized prediction wrapper handling them.

mnabian commented 1 month ago

Thanks @stadlmax, I'll add your comments to my epic and consider them all.
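To make the residual-ordering point raised by @stadlmax concrete, here is a toy sketch with scalar features on a two-node graph. It contrasts the two orderings as described in that comment: the edge residual applied before message passing (the MeshGraphNet-style blocks) versus all residuals applied after both updates (the GraphCast paper, per the comment). edge_mlp and node_mlp are hypothetical linear stand-ins, not Modulus's MeshGraphEdgeBlock or MeshGraphNodeBlock, and sum aggregation is assumed.

```python
from typing import List, Tuple

Edge = Tuple[int, int]  # (sender, receiver)


def edge_mlp(e: float, v_s: float, v_r: float) -> float:
    """Hypothetical stand-in for the edge MLP (any function of e, v_s, v_r would do)."""
    return 0.5 * e + 0.25 * v_s - 0.125 * v_r


def node_mlp(v: float, agg: float) -> float:
    """Hypothetical stand-in for the node MLP."""
    return 0.5 * v + agg


def sum_incoming(edges: List[Edge], efeat: List[float], num_nodes: int) -> List[float]:
    """Sum-aggregate edge features onto their receiver nodes."""
    agg = [0.0] * num_nodes
    for (_, r), e in zip(edges, efeat):
        agg[r] += e
    return agg


def meshgraphnet_order(edges, efeat, nfeat):
    # The edge block applies its residual immediately ...
    e_out = [e + edge_mlp(e, nfeat[s], nfeat[r]) for (s, r), e in zip(edges, efeat)]
    # ... so the node block aggregates edge features that already include it.
    agg = sum_incoming(edges, e_out, len(nfeat))
    v_out = [v + node_mlp(v, a) for v, a in zip(nfeat, agg)]
    return e_out, v_out


def graphcast_paper_order(edges, efeat, nfeat):
    # Compute the edge update e' without a residual ...
    e_upd = [edge_mlp(e, nfeat[s], nfeat[r]) for (s, r), e in zip(edges, efeat)]
    # ... aggregate e' (not e + e') for the node update ...
    agg = sum_incoming(edges, e_upd, len(nfeat))
    v_upd = [node_mlp(v, a) for v, a in zip(nfeat, agg)]
    # ... and apply both residuals only at the end.
    e_out = [e + d for e, d in zip(efeat, e_upd)]
    v_out = [v + d for v, d in zip(nfeat, v_upd)]
    return e_out, v_out


edges = [(0, 1), (1, 0)]
efeat = [1.0, 2.0]
nfeat = [1.0, 3.0]
print(meshgraphnet_order(edges, efeat, nfeat))     # ([1.375, 3.625], [5.125, 5.875])
print(graphcast_paper_order(edges, efeat, nfeat))  # ([1.375, 3.625], [3.125, 4.875])
```

The edge outputs agree, but the node outputs differ: with this linear node_mlp, the MeshGraphNet-style ordering additionally feeds each node the sum of its incoming pre-update edge features (2.0 for node 0, 1.0 for node 1), which is exactly the mismatch described above.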
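The split proposed in the third comment (a data-agnostic backbone plus a dataset-specific wrapper that owns layout conversion and the prediction mode) can be sketched as below. This is a minimal illustration under stated assumptions: PredictionWrapper, grid_to_nodes and nodes_to_grid are hypothetical names, a single channel and row-major flattening are assumed, and plain Python lists stand in for tensors. It is not how GraphCastNet or the proposed GraphCastNetERA5 are implemented.

```python
from dataclasses import dataclass
from typing import Callable, List, Literal

Grid = List[List[float]]  # H x W layout (one channel, for brevity)
Nodes = List[float]       # "serial" graph layout: one entry per grid node
Backbone = Callable[[Nodes], Nodes]  # data-agnostic model f


def grid_to_nodes(x: Grid) -> Nodes:
    """Flatten H x W (row-major) into the serial node layout."""
    return [v for row in x for v in row]


def nodes_to_grid(y: Nodes, height: int, width: int) -> Grid:
    """Inverse of grid_to_nodes."""
    if len(y) != height * width:
        raise ValueError("node count does not match grid shape")
    return [y[r * width:(r + 1) * width] for r in range(height)]


@dataclass
class PredictionWrapper:
    """Hypothetical dataset-specific wrapper around a data-agnostic backbone."""

    backbone: Backbone
    mode: Literal["direct", "residual"] = "residual"

    def __call__(self, x_prev: Grid) -> Grid:
        height, width = len(x_prev), len(x_prev[0])
        nodes = grid_to_nodes(x_prev)  # layout conversion lives in the wrapper ...
        f_x = self.backbone(nodes)     # ... the backbone only sees node features
        if self.mode == "direct":
            y = f_x                                  # y_t = f(x_{t-1})
        elif self.mode == "residual":
            y = [a + b for a, b in zip(nodes, f_x)]  # y_t = x_{t-1} + f(x_{t-1})
        else:
            raise ValueError(f"unknown prediction mode: {self.mode!r}")
        return nodes_to_grid(y, height, width)


def halve(nodes: Nodes) -> Nodes:
    """Toy backbone standing in for the graph model."""
    return [0.5 * v for v in nodes]


x_prev = [[1.0, 2.0], [3.0, 4.0]]
print(PredictionWrapper(halve, mode="direct")(x_prev))    # [[0.5, 1.0], [1.5, 2.0]]
print(PredictionWrapper(halve, mode="residual")(x_prev))  # [[1.5, 3.0], [4.5, 6.0]]
```

Because the mode switch and the HxW-to-serial conversion live only in the wrapper, the same backbone could be reused with a different dataset or prediction mode without touching its graph operations.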

mnabian commented 1 month ago

Note to myself: the API updates break the GraphCast tests. Need to update them all.

mnabian commented 1 month ago

@stadlmax As far as I remember, we were using the fused LayerNorm, and that gave us a nice speedup: https://github.com/NVIDIA/modulus/blob/main/modulus/models/gnn_layers/mesh_graph_mlp.py#L157... Did you also compare transformer_engine.LayerNorm with the fused LayerNorm?

stadlmax commented 1 month ago

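Yes. For AIFS, I found TE > APEX > PyTorch across a range of the typical sizes AIFS used in their RFI benchmark; the one exception in the numbers below is the smallest size (40962 rows), where APEX is slightly faster than TE. The backward kernels in TE in particular are much better for our cases. (Reported numbers are runtimes; lower is better.)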

num_channels = 256
layer_norm_impl     1626240 x 256  327660 x 256  40962 x 256  542080 x 256  814540 x 256
apex                9.75127        2.03821       0.371149     3.32072       4.9402
pytorch             10.752         4.17265       0.957743     3.63721       10.2774
transformer_engine  2.59236        0.580879      0.801795     0.916124      1.33596

num_channels = 384
layer_norm_impl     1626240 x 384  327660 x 384  40962 x 384  542080 x 384  814540 x 384
apex                11.2164        2.3109        0.359366     3.79922       5.64847
pytorch             11.8419        4.33466       0.583828     3.99414       10.6802
transformer_engine  3.98762        0.849599      0.396184     1.38306       2.022

num_channels = 512
layer_norm_impl     1626240 x 512  327660 x 512  40962 x 512  542080 x 512  814540 x 512
apex                12.1739        2.50785       0.37578      4.11927       6.14573
pytorch             12.7752        4.5477        0.615464     4.30874       11.2191
transformer_engine  4.90182        1.04243       0.391352     1.6877        2.4967

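To check the "TE > APEX > PyTorch" ordering shape by shape, here is a small script over the numbers transcribed from the tables above. The runtime units are whatever the benchmark reported (they are not stated in the thread), and the helper names ranking and speedup are illustrative only.

```python
# Runtimes from the tables above (lower is better); tuple positions follow
# the row counts of the (rows x num_channels) inputs, in table order.
ROWS = (1626240, 327660, 40962, 542080, 814540)
RUNTIMES = {
    256: {
        "apex": (9.75127, 2.03821, 0.371149, 3.32072, 4.9402),
        "pytorch": (10.752, 4.17265, 0.957743, 3.63721, 10.2774),
        "transformer_engine": (2.59236, 0.580879, 0.801795, 0.916124, 1.33596),
    },
    384: {
        "apex": (11.2164, 2.3109, 0.359366, 3.79922, 5.64847),
        "pytorch": (11.8419, 4.33466, 0.583828, 3.99414, 10.6802),
        "transformer_engine": (3.98762, 0.849599, 0.396184, 1.38306, 2.022),
    },
    512: {
        "apex": (12.1739, 2.50785, 0.37578, 4.11927, 6.14573),
        "pytorch": (12.7752, 4.5477, 0.615464, 4.30874, 11.2191),
        "transformer_engine": (4.90182, 1.04243, 0.391352, 1.6877, 2.4967),
    },
}


def ranking(channels: int, rows: int) -> list:
    """Implementations ordered fastest to slowest for one input shape."""
    col = ROWS.index(rows)
    table = RUNTIMES[channels]
    return sorted(table, key=lambda impl: table[impl][col])


def speedup(channels: int, impl: str, baseline: str = "pytorch") -> dict:
    """Per-shape speedup of `impl` over `baseline` (baseline time / impl time)."""
    table = RUNTIMES[channels]
    return {r: b / t for r, b, t in zip(ROWS, table[baseline], table[impl])}


for c in RUNTIMES:
    for r in ROWS:
        print(f"{r} x {c}: {' faster than '.join(ranking(c, r))}")
```

Running it shows TE fastest everywhere except the 40962-row shapes, where APEX edges it out, and TE beating PyTorch for every shape (about 4.1x at 1626240 x 256). These are per-kernel runtimes; the 1.3x figure mentioned earlier in the thread is an end-to-end number and is not derived from them.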
mnabian commented 1 month ago

This is a great comparison, thanks! I'll switch to TE then. Is there any reason to still keep the fused LayerNorm from APEX, or should we just remove it?

stadlmax commented 1 month ago

I'd say no, not really. TE should also be well covered in terms of ongoing development, specifically for Blackwell and beyond. I know of a few POCs that try to optimize the LN in TE even further. If we build on the DLFW containers, TE should also come pre-installed.

mnabian commented 1 month ago

@stadlmax added support for TE layernorm.
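A minimal sketch of how a LayerNorm backend could be selected with a fallback, following the TE > APEX > PyTorch ordering from the benchmark in this thread and the point that TE ships with the DLFW containers but may be missing elsewhere. The helper names (PREFERENCE, is_installed, pick_layer_norm_backend, make_layer_norm) are hypothetical and are not the Modulus API; only transformer_engine.pytorch.LayerNorm, apex.normalization.FusedLayerNorm and torch.nn.LayerNorm are existing classes. Availability is checked by module presence only, so an APEX install without its CUDA extensions would still be picked and then fail at construction.

```python
import importlib.util
from typing import Iterable, Optional

# Hypothetical preference order, following the benchmark in this thread.
PREFERENCE = ("transformer_engine", "apex", "torch")


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


def pick_layer_norm_backend(available: Iterable[str], preferred: Optional[str] = None) -> str:
    """Return the requested backend, or the first available one in PREFERENCE order."""
    available = set(available)
    if preferred is not None:
        if preferred not in available:
            raise ValueError(f"LayerNorm backend {preferred!r} is not installed")
        return preferred
    for name in PREFERENCE:
        if name in available:
            return name
    raise RuntimeError("no LayerNorm backend is installed")


def make_layer_norm(hidden_size: int, preferred: Optional[str] = None):
    """Build a LayerNorm over the last dimension using the chosen backend."""
    backend = pick_layer_norm_backend(
        [name for name in PREFERENCE if is_installed(name)], preferred
    )
    if backend == "transformer_engine":
        import transformer_engine.pytorch as te  # expects a CUDA-capable setup

        return te.LayerNorm(hidden_size)
    if backend == "apex":
        from apex.normalization import FusedLayerNorm

        return FusedLayerNorm(hidden_size)
    import torch

    return torch.nn.LayerNorm(hidden_size)
```

Dropping APEX, as discussed above, would amount to removing "apex" from PREFERENCE; callers could still force a specific backend through the preferred argument.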

mnabian commented 1 month ago

Done: updated the GraphCast tests for the API changes.

mnabian commented 1 month ago

/blossom-ci

stadlmax commented 1 month ago

Thanks for addressing the feedback, looks good to me.

mnabian commented 1 month ago

/blossom-ci