yang-song / score_inverse_problems

Official repo for "Solving Inverse Problems in Medical Imaging with Score-Based Generative Models"

multi node multi gpu #1

Closed: JiahaoYao closed this issue 2 years ago

JiahaoYao commented 2 years ago

Hi @yang-song, thanks for sharing the code. I wonder whether the code supports multi-node, multi-GPU training. I see that the code uses jax.devices and jax.local_devices, but how is the connection among the GPUs set up? Thanks again!
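To illustrate the distinction the question raises: in JAX, `jax.devices()` lists every device across all participating processes, while `jax.local_devices()` lists only the devices attached to the current process. The sketch below is illustrative and not code from this repo; the helper name `describe_devices` is hypothetical. In an ordinary single-process, single-node run the two counts coincide.

```python
import jax


def describe_devices():
    """Summarize which devices JAX can see from this process (hypothetical helper)."""
    global_devices = jax.devices()       # devices across every participating process
    local_devices = jax.local_devices()  # only devices attached to this process
    return {
        "process_index": jax.process_index(),  # this process's rank
        "process_count": jax.process_count(),  # total number of processes
        "num_global": len(global_devices),
        "num_local": len(local_devices),
    }


if __name__ == "__main__":
    print(describe_devices())
```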

yang-song commented 2 years ago

It supports multi-GPU training on a single node. It should also work for multi-node training with TPUs. I'm not sure whether JAX supports multi-node GPU training now, but it may.
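For the single-node multi-GPU case, a common JAX pattern is `jax.pmap` combined with a collective such as `jax.lax.pmean`. Each local device runs the same step on its own shard of the batch, and the collective averages gradients across devices, so no manual connection setup is written in user code (on GPUs, XLA typically implements these collectives with NCCL). This is a minimal sketch under stated assumptions: the linear model, `loss_fn`, `train_step`, `LEARNING_RATE`, and the synthetic data are hypothetical and not taken from this repo's training code.

```python
import functools

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for this toy example


def loss_fn(w, x, y):
    # Hypothetical toy model: linear regression with mean squared error.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(w, x, y):
    # Runs once per local device on that device's shard of the data.
    grads = jax.grad(loss_fn)(w, x, y)
    # All-reduce: average the gradients over every device in the pmap.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - LEARNING_RATE * grads


def main(steps=100):
    n_dev = jax.local_device_count()
    key = jax.random.PRNGKey(0)
    true_w = jnp.array([2.0, -1.0])
    # Leading axis = one shard per local device: (n_dev, batch_per_device, features).
    x = jax.random.normal(key, (n_dev, 64, 2))
    y = x @ true_w
    # One identical parameter replica per device.
    w = jnp.zeros((n_dev, 2))
    for _ in range(steps):
        w = train_step(w, x, y)
    return w


if __name__ == "__main__":
    print(main())
```

On a machine with a single device (for example, CPU only) the same code runs with a leading axis of size 1; with several GPUs on one node, `jax.local_device_count()` picks them all up automatically.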

JiahaoYao commented 2 years ago

Thanks @yang-song!