I have added automatic data-parallel training. It can be disabled with a config option; otherwise, training now uses all available devices. The implementation is based on JAX's sharding API.
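For reference, here is a minimal sketch of what data-parallel training via JAX's sharding API looks like. The names (`train_step`, `loss_fn`, the toy linear model) are illustrative, not the actual code in this PR: the idea is to replicate parameters across a 1-D device mesh, shard the batch along its leading axis, and let `jax.jit` compile a single program that reduces gradients across devices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Parameters are replicated on every device; the batch is sharded
# along its leading axis across the "data" mesh axis.
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P("data"))

# Toy linear model; stands in for the real model and loss.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # The compiler inserts the cross-device gradient reduction
    # automatically, since params are replicated and x/y are sharded.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

params = {
    "w": jax.device_put(jnp.zeros((8, 1)), replicated),
    "b": jax.device_put(jnp.zeros((1,)), replicated),
}
# Batch size must be divisible by the device count.
x = jax.device_put(jnp.ones((32, 8)), batch_sharded)
y = jax.device_put(jnp.ones((32, 1)), batch_sharded)
params = train_step(params, x, y)
```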
Parallel ensemble training is not supported yet but can be added in the future.