mllam / neural-lam

Neural Weather Prediction for Limited Area Modeling
MIT License

Implement Model Sharding #70

Open sadamov opened 3 weeks ago

sadamov commented 3 weeks ago

Description:

We should implement model sharding in neural-lam to allow training with larger batch sizes without exhausting GPU VRAM. This feature will let users scale to larger models and improve training efficiency. High-resolution datasets with many input feature channels quickly exhaust even GPUs with 100 GB of VRAM; currently, the batch size must be reduced to 1-4 on many systems for such datasets.

Proposed implementation:

  1. Add sharding logic to model definition (see here for bipartite_subgraph & here)
  2. Provide configuration options for the sharding strategy via a train_model flag
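To make step 1 concrete, here is a minimal pure-Python sketch of the edge-partitioning idea behind sharding a bipartite (e.g. grid-to-mesh) graph: destination nodes are split into contiguous blocks per device, and each shard keeps only the edges whose destination lands on it, with locally relabeled indices. This is the role a call like torch_geometric's `bipartite_subgraph` would play on real tensors; `shard_bipartite_edges` is a hypothetical helper for illustration, not part of neural-lam.

```python
def shard_bipartite_edges(edge_index, num_dst_nodes, num_shards, shard_id):
    """Return the edges assigned to `shard_id` under a contiguous
    block partition of the destination (mesh) nodes."""
    # Contiguous block partition of destination nodes across shards.
    per_shard = (num_dst_nodes + num_shards - 1) // num_shards
    lo = shard_id * per_shard
    hi = min(lo + per_shard, num_dst_nodes)
    # Keep edges whose destination falls in [lo, hi);
    # relabel destinations to shard-local indices.
    kept = [(src, dst - lo) for src, dst in edge_index if lo <= dst < hi]
    return kept, (lo, hi)

# Toy bipartite graph: 4 grid (src) nodes -> 4 mesh (dst) nodes.
edges = [(0, 0), (1, 0), (1, 2), (2, 3), (3, 1)]
shard0, rng0 = shard_bipartite_edges(edges, num_dst_nodes=4, num_shards=2, shard_id=0)
shard1, rng1 = shard_bipartite_edges(edges, num_dst_nodes=4, num_shards=2, shard_id=1)
```

Each shard then runs message passing only over its local edge subset, so activation memory for the mesh scales down with the number of shards; cross-shard source features would still need to be gathered before aggregation.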

Benefits:

Technical considerations:

This feature will significantly enhance neural-lam's capabilities for large-scale atmospheric modeling.

joeloskarsson commented 3 weeks ago

A few thoughts about sharding (and also about saving VRAM in general):