Helmholtz-AI-Energy / propulate

Propulate is an asynchronous population-based optimization algorithm and software package for global optimization and hyperparameter search on high-performance computers.
https://doi.org/10.1007/978-3-031-32041-5_6
BSD 3-Clause "New" or "Revised" License

Remove `MPI.COMM_WORLD.barrier()` in `get_data_loaders` function in `torch_example.py` #113

Closed: mcw92 closed this issue 6 months ago.

mcw92 commented 7 months ago

In the `torch_example.py` tutorial, the `get_data_loaders` function is called in the `ind_loss` function, i.e., every time an individual is evaluated. However, `get_data_loaders` calls `MPI.COMM_WORLD.barrier()`, which is meant to prevent conflicts when several ranks try to download the dataset in parallel for the first time. Because a barrier blocks each rank until all ranks have reached it, this imposes an explicit synchronization point on every evaluation during the Propulate optimization, which defeats its asynchronous design and which we definitely do not want. This should be fixed, e.g., by removing the barrier.

vtotiv commented 7 months ago

A possible solution would be to add an attribute to `get_data_loaders()` that is set independently on every rank (each rank is its own process) and checked before calling `Barrier()`, so that the barrier is only executed on the first call.

I solved it like this for the two examples in tutorials/surrogate/:

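    # Only the first call in each process (rank) hits the barrier that guards
    # the initial dataset download; later calls skip it.
    if not hasattr(get_data_loaders, "barrier_called"):
        MPI.COMM_WORLD.Barrier()
        # Function attributes live in this process only, i.e. one flag per rank.
        get_data_loaders.barrier_called = True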
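
The guard above can be exercised without MPI. In this sketch, which is an assumption for illustration, the barrier is passed in through a hypothetical `barrier` parameter standing in for `MPI.COMM_WORLD.Barrier`, and the loader body is a placeholder; on a single rank, repeated evaluations trigger the barrier only once.

```python
from typing import Callable


def get_data_loaders(batch_size: int, barrier: Callable[[], None]) -> str:
    """Same function-attribute guard as in the snippet above.

    `barrier` is a hypothetical parameter (not in the tutorial's signature)
    that stands in for MPI.COMM_WORLD.Barrier so the sketch runs without MPI.
    """
    # The flag is stored on the function object. Each MPI rank is a separate
    # process, so every rank keeps its own independent flag.
    if not getattr(get_data_loaders, "barrier_called", False):
        barrier()
        get_data_loaders.barrier_called = True
    # Placeholder: a real version would build and return torch DataLoaders.
    return f"loaders(bs={batch_size})"


if __name__ == "__main__":
    calls = []
    for _ in range(4):  # four individuals evaluated on this rank
        get_data_loaders(16, barrier=lambda: calls.append("Barrier"))
    print(len(calls))  # 1
```

Because `Barrier()` only returns once every rank in the communicator has called it, this approach still assumes that each rank calls `get_data_loaders` at least once; a rank that never does would leave the others blocked in their first evaluation.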