Closed cjblackout closed 1 month ago
Hi, to run parallel simulations, the user must specify the domain decomposition in the case setup .json file. For example, to run a simulation on 2 devices, add the following key to "domain" in the case setup: "decomposition": { "split_x": 2, "split_y": 1, "split_z": 1 }. This splits the domain into two equally sized blocks in the x-direction, and each GPU processes a single block. We added some documentation of the case setup and numerical setup under notebooks/basics. In the upcoming weeks, we will complete the docs and add more notebooks to document JAX-Fluids.
The multi-GPU workflow works on a single node and utilizes all of its GPUs, but my issue is related to multi-node, multi-GPU runs.
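As a rough illustration of the decomposition key described above, here is a minimal Python sketch that adds it to a case setup dictionary and counts the resulting blocks. The helper names `add_decomposition` and `num_blocks` are hypothetical and not part of jaxfluids; the only thing taken from the comment above is the "decomposition" key under "domain" with its "split_x", "split_y" and "split_z" entries. The sketch assumes the number of blocks should match the number of devices, since each GPU processes a single block.

```python
import json
from math import prod


def add_decomposition(case_setup, split_x, split_y=1, split_z=1):
    """Return a copy of the case setup with a "decomposition" key under "domain".

    Hypothetical helper for illustration; it only edits a plain dictionary.
    """
    updated = dict(case_setup)
    domain = dict(updated.get("domain", {}))
    domain["decomposition"] = {
        "split_x": split_x,
        "split_y": split_y,
        "split_z": split_z,
    }
    updated["domain"] = domain
    return updated


def num_blocks(case_setup):
    """Total number of blocks, i.e. the product of the three split counts."""
    return prod(case_setup["domain"]["decomposition"].values())


# Two blocks along x, matching the 2-device example in the comment above.
case = add_decomposition({"domain": {}}, split_x=2)
print(json.dumps(case["domain"], indent=2))
print("blocks:", num_blocks(case))
```

Under the same assumption, a 2-node job with 8 GPUs per node would need a decomposition whose three split counts multiply to 16.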
For multi-node jobs, you must initialize the JAX distributed system at the start of your script, before calling any jaxfluids functions (see https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html). For more information on using JAX in a multi-node setting, we refer to https://jax.readthedocs.io/en/latest/multi_process.html. We have tested jaxfluids in multi-node settings on TPUs and SLURM GPU clusters.
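For a PBS cluster like the one in the question, the initialization step might look roughly like the sketch below. It is untested on PBS and rests on several assumptions: one process is launched per node, the first host in `$PBS_NODEFILE` acts as coordinator, the port 12345 is free, and the process rank comes from `OMPI_COMM_WORLD_RANK` (set only when launching with Open MPI's `mpirun`). The helpers `distributed_args` and `init_from_pbs` are hypothetical names; only `jax.distributed.initialize` and its `coordinator_address`, `num_processes` and `process_id` arguments come from JAX.

```python
import os


def distributed_args(nodefile_lines, process_id, port=12345):
    """Build keyword arguments for jax.distributed.initialize.

    Assumes one process per node. Hosts may be repeated in the node file,
    so duplicates are dropped while keeping the original order.
    """
    hosts = list(dict.fromkeys(
        line.strip() for line in nodefile_lines if line.strip()
    ))
    if not 0 <= process_id < len(hosts):
        raise ValueError("process_id must be in [0, number of nodes)")
    return {
        "coordinator_address": f"{hosts[0]}:{port}",  # first host coordinates
        "num_processes": len(hosts),
        "process_id": process_id,
    }


def init_from_pbs():
    """Initialize JAX across nodes; call this before any jaxfluids function."""
    import jax  # imported here so the helper above works without JAX

    with open(os.environ["PBS_NODEFILE"]) as f:
        # Assumption: rank provided by an Open MPI launcher.
        rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
        args = distributed_args(f.readlines(), rank)
    jax.distributed.initialize(**args)
    return jax.device_count()  # global device count across all processes


# Example with a made-up node file for two nodes:
print(distributed_args(["node01\n", "node01\n", "node02\n"], process_id=1))
```

If `init_from_pbs()` succeeds on 2 nodes with 8 GPUs each, the returned device count should be 16, and the decomposition in the case setup would then be chosen to match.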
Thank you, it worked. Closing the issue.
Hi, I'm trying to run jaxfluids on a PBS cluster with 8 GPUs per node. I want to run the simulation on 2 nodes, but I can't find anything in the documentation regarding distributed computation for jaxfluids. Is this feature not yet supported?