Hi @yang-song, thanks for sharing the code. I wonder whether the code supports multi-node, multi-GPU training. I see that the code uses `jax.devices()` and `jax.local_devices()`, but how is the connection among the GPUs set up? Thanks again!
It supports multi-GPU training on a single node. It should also work for multi-node training with TPUs. I'm not sure whether JAX currently supports multi-node GPU training, but it may.
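A minimal sketch of how the two device queries differ and how data-parallel communication happens, assuming a recent JAX release. It is not taken from this repository's code. On a single host, `jax.devices()` and `jax.local_devices()` return the same list. With several processes, `jax.devices()` spans every process while `jax.local_devices()` lists only this process's devices. No explicit connection setup is needed within one node: the collective (`jax.lax.psum`) inside `jax.pmap` handles cross-device communication. For multiple GPU nodes, newer JAX versions provide `jax.distributed.initialize`; the coordinator address and process counts shown in the comment are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

# Multi-node only (hypothetical values; must run before any other JAX call):
# jax.distributed.initialize(coordinator_address="host0:1234",
#                            num_processes=2, process_id=0)

# Global vs. local view of the available accelerators.
print("process", jax.process_index(), "of", jax.process_count())
print("global devices:", jax.device_count(),
      "local devices:", jax.local_device_count())

n_local = jax.local_device_count()

# pmap requires the leading axis to equal the number of local devices,
# so each device receives one row (its shard of the batch).
x = jnp.arange(n_local * 4, dtype=jnp.float32).reshape(n_local, 4)

def step(v):
    # psum all-reduces the per-device sums across the named pmap axis,
    # which is where inter-device communication actually happens.
    return jax.lax.psum(v.sum(), axis_name="devices")

# Every device ends up holding the same global total.
total = jax.pmap(step, axis_name="devices")(x)
print(total)
```

Because every device receives the same all-reduced value, a training loop can average gradients this way and keep the model replicas in sync without any manual socket or NCCL configuration.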