tumaer / JAXFLUIDS

Differentiable Fluid Dynamics Package

Distributed computation using Jaxfluids #7

Closed: cjblackout closed this issue 1 month ago

cjblackout commented 1 month ago

Hi, I'm trying to run jaxfluids on a PBS cluster with 8 GPUs per node. I want to run the simulation on 2 nodes, but I can't find anything in the documentation about distributed computation with jaxfluids. Is this feature not yet supported?

aaronbuhendwa commented 1 month ago

Hi, to run parallel simulations, you must specify the domain decomposition in the case setup .json file. For example, to run a simulation on 2 devices, add the following key to "domain" in the case setup: "decomposition": { "split_x": 2, "split_y": 1, "split_z": 1 }. This splits the domain into two equally sized blocks in the x-direction, and each GPU processes a single block. We have added some documentation of the case setup and numerical setup under notebooks/basics. In the upcoming weeks, we will complete the docs and add more notebooks documenting JAX-Fluids.
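The decomposition key can also be set programmatically before launching a run. The sketch below is a minimal illustration, assuming the case setup is a JSON object with a top-level "domain" entry as described above; the helper names `add_decomposition`, `num_blocks` and `check_decomposition` are hypothetical and not part of jaxfluids. It also assumes, following "each GPU processes a single block", that the number of blocks should equal the number of devices.

```python
import json


def add_decomposition(case_setup: dict, split_x: int, split_y: int = 1, split_z: int = 1) -> dict:
    """Return a copy of case_setup with domain["decomposition"] set.

    Hypothetical helper: assumes a top-level "domain" entry in the case setup.
    """
    updated = json.loads(json.dumps(case_setup))  # deep copy via a JSON round trip
    updated["domain"]["decomposition"] = {
        "split_x": split_x,
        "split_y": split_y,
        "split_z": split_z,
    }
    return updated


def num_blocks(case_setup: dict) -> int:
    """Number of blocks produced by the decomposition (one block per device)."""
    d = case_setup["domain"]["decomposition"]
    return d["split_x"] * d["split_y"] * d["split_z"]


def check_decomposition(case_setup: dict, n_devices: int) -> None:
    """Raise ValueError if the block count does not match the device count (assumed requirement)."""
    blocks = num_blocks(case_setup)
    if blocks != n_devices:
        raise ValueError(
            f"decomposition yields {blocks} blocks but {n_devices} devices are available"
        )


def update_case_setup_file(in_path: str, out_path: str, split_x: int,
                           split_y: int = 1, split_z: int = 1) -> dict:
    """Load a case setup .json file, add the decomposition, and write it back out."""
    with open(in_path) as f:
        case_setup = json.load(f)
    updated = add_decomposition(case_setup, split_x, split_y, split_z)
    with open(out_path, "w") as f:
        json.dump(updated, f, indent=4)
    return updated
```

For the 2-node setup in this thread (2 × 8 = 16 GPUs), a split such as `split_x=4, split_y=4` gives 16 blocks. Passing `jax.device_count()` as `n_devices` is a natural check, since after distributed initialization it counts devices across all processes, not just the local node.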

cjblackout commented 1 month ago

The multi-GPU workflow works and uses all GPUs, but my issue relates to multi-node, multi-GPU training.

aaronbuhendwa commented 1 month ago

For multi-node jobs, you must initialize the JAX distributed system at the start of your script, before calling any jaxfluids functions: https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html. For more information on using JAX in a multi-node setting, see https://jax.readthedocs.io/en/latest/multi_process.html. We have tested jaxfluids in multi-node settings on TPUs and on SLURM GPU clusters.
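On a PBS cluster, `jax.distributed.initialize` may need explicit arguments (`coordinator_address`, `num_processes`, `process_id`). The sketch below is one way to derive them, under several assumptions that are not from this thread: one process per node, the first host in `PBS_NODEFILE` acting as coordinator, an arbitrary free port, and an Open MPI launch (e.g. `mpirun -npernode 1`) so that each process can read its rank from `OMPI_COMM_WORLD_RANK`. The helper names are hypothetical; adapt the rank lookup to your launcher.

```python
import os


def parse_nodefile(text: str) -> list[str]:
    """Return unique hostnames in first-seen order.

    PBS_NODEFILE typically lists a host once per allocated slot, so hosts repeat.
    """
    hosts: list[str] = []
    for line in text.splitlines():
        host = line.strip()
        if host and host not in hosts:
            hosts.append(host)
    return hosts


def distributed_init_kwargs(hosts: list[str], process_id: int, port: int = 12345) -> dict:
    """Keyword arguments for jax.distributed.initialize, one process per node.

    The first host is used as coordinator; the port is an assumed free port.
    """
    if not 0 <= process_id < len(hosts):
        raise ValueError(f"process_id {process_id} out of range for {len(hosts)} hosts")
    return {
        "coordinator_address": f"{hosts[0]}:{port}",
        "num_processes": len(hosts),
        "process_id": process_id,
    }


def initialize_on_pbs(port: int = 12345) -> None:
    """Initialize JAX's distributed runtime; call before any jaxfluids code.

    Assumes an Open MPI launch with one process per node, so the rank is read
    from OMPI_COMM_WORLD_RANK. Each process then drives all GPUs on its node.
    """
    import jax  # imported here so the pure helpers above work without JAX installed

    with open(os.environ["PBS_NODEFILE"]) as f:
        hosts = parse_nodefile(f.read())
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    jax.distributed.initialize(**distributed_init_kwargs(hosts, rank, port))
```

In the run script, `initialize_on_pbs()` would be called first, followed by the usual jaxfluids setup with a decomposition whose block count matches `jax.device_count()` (16 for 2 nodes with 8 GPUs each). The JAX multi-process docs describe automatic detection of these arguments in some environments, such as Slurm and Cloud TPU, where a bare `jax.distributed.initialize()` may be enough.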

cjblackout commented 1 month ago

Thank you, it worked. Closing the issue.