Closed mcw92 closed 6 months ago
A possible solution would be to add an attribute to `get_data_loaders()`, which is set independently for every rank and checked before calling `Barrier()`. I solved it like this for the two examples in `tutorials/surrogate/`:
```python
if not hasattr(get_data_loaders, "barrier_called"):
    MPI.COMM_WORLD.Barrier()
    setattr(get_data_loaders, "barrier_called", True)
```
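To illustrate the function-attribute guard in isolation, here is a minimal, MPI-free sketch: a counter-incrementing stand-in replaces the actual `MPI.COMM_WORLD.Barrier()` call (the guard pattern itself does not depend on MPI), and a simplified `get_data_loaders` stands in for the tutorial's real function:

```python
def one_time_setup():
    """Stand-in for MPI.COMM_WORLD.Barrier(); counts how often it runs."""
    one_time_setup.calls = getattr(one_time_setup, "calls", 0) + 1


def get_data_loaders():
    # The attribute lives on the function object, i.e., once per process
    # (= once per MPI rank), so the setup runs only on the first call.
    if not hasattr(get_data_loaders, "barrier_called"):
        one_time_setup()
        get_data_loaders.barrier_called = True
    return "loaders"


# Even after repeated calls (one per evaluated individual), the
# barrier-like setup has executed exactly once.
for _ in range(5):
    get_data_loaders()

print(one_time_setup.calls)  # → 1
```

Since the attribute is per process, every rank still reaches the barrier exactly once on its first call, which preserves the original protection against concurrent dataset downloads.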
In the `torch_example.py` tutorial, the `get_data_loaders` function is called in the `ind_loss` function, i.e., every time an individual is evaluated. However, the `MPI.COMM_WORLD.barrier` in the `get_data_loaders` function, which is meant to prevent conflicts when multiple ranks potentially download the dataset for the first time in parallel, imposes an explicit synchronization point during the Propulate optimization that we definitely do not want to have. This should be fixed or removed.