apax-hub / apax

A flexible and performant framework for training machine learning potentials.
MIT License

Dataparallel Training #257

Closed · M-R-Schaefer closed this 3 months ago

M-R-Schaefer commented 3 months ago

I have added automatic data-parallel training. It can be disabled with a config option; otherwise, training will now use all available devices. The implementation is based on JAX's sharding API. Parallel ensemble training is not supported yet, but it can be added in the future.
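For context, here is a minimal sketch of the general data-parallel pattern the sharding API enables: batches are sharded along their leading axis across a device mesh while parameters are replicated. This is illustrative only, not the apax implementation; the `mesh`, sharding, and `loss` names are made up for the example.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over every available device.
mesh = Mesh(jax.devices(), axis_names=("data",))

# Shard batches along the leading axis; replicate parameters everywhere.
batch_sharding = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())

batch = jax.device_put(jnp.ones((8 * len(jax.devices()), 16)), batch_sharding)
params = jax.device_put(jnp.ones((16,)), replicated)

@jax.jit
def loss(params, batch):
    # Toy objective; XLA inserts the cross-device reduction for the mean.
    return jnp.mean((batch @ params) ** 2)

print(loss(params, batch))  # runs data-parallel across all devices
```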

M-R-Schaefer commented 3 months ago

I have verified that this yields speedups by using multiple CPU devices.
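For anyone wanting to reproduce this without a multi-GPU machine: JAX can split the host CPU into several virtual XLA devices, which is a common way to test sharded code. A sketch of that setup (the device count of 8 is arbitrary):

```python
import os
# Must be set before jax is imported: splits the host CPU into 8 XLA devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # eight CPU devices, usable for sharded training
```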