keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0
1.27k stars 117 forks

Reimplement the data parallel distribution. #785

Closed qlzh727 closed 1 year ago

qlzh727 commented 1 year ago
  1. Fill in the body of the base Distribution class.
  2. Port the existing data parallel distribution from the JAX backend to the common backend.
  3. Update the variable and trainer code to use the backend functions to distribute variables and data.
  4. Update the unit tests.
  5. Remove the existing JAX-specific distribution.py and its tests.
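
For readers unfamiliar with the pattern, the split between a base Distribution and a data parallel subclass might look roughly like the sketch below. It is an illustration only, not the code in this PR: `Layout`, `get_data_layout`, `get_variable_layout`, and `shard_batch` are hypothetical names, and plain Python lists stand in for backend tensors. The assumption is that data parallelism shards inputs along the batch axis and replicates every variable on each device.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class Layout:
    """Hypothetical layout: one entry per tensor axis.

    A string names the mesh axis that tensor axis is sharded over;
    None means the tensor axis is replicated.
    """
    axes: Tuple[Optional[str], ...]


class Distribution:
    """Hypothetical backend-agnostic base class (assumed interface)."""

    def __init__(self, devices: Sequence[str]):
        self.devices = list(devices)

    def get_data_layout(self, data_shape: Sequence[int]) -> Layout:
        raise NotImplementedError

    def get_variable_layout(self, variable_shape: Sequence[int]) -> Layout:
        raise NotImplementedError


class DataParallel(Distribution):
    """Shard input data on the batch axis, replicate all variables."""

    BATCH_AXIS = "batch"

    def get_data_layout(self, data_shape: Sequence[int]) -> Layout:
        if len(data_shape) == 0:
            raise ValueError("Input data needs a leading batch axis.")
        # Only the first (batch) axis is sharded; the rest are replicated.
        return Layout((self.BATCH_AXIS,) + (None,) * (len(data_shape) - 1))

    def get_variable_layout(self, variable_shape: Sequence[int]) -> Layout:
        # Every device holds a full copy of each variable.
        return Layout((None,) * len(variable_shape))


def shard_batch(batch: List, distribution: Distribution) -> Dict[str, List]:
    """Split a batch into equal contiguous slices, one per device."""
    num_devices = len(distribution.devices)
    if len(batch) % num_devices != 0:
        raise ValueError("Batch size must be divisible by the device count.")
    per_device = len(batch) // num_devices
    return {
        device: batch[i * per_device:(i + 1) * per_device]
        for i, device in enumerate(distribution.devices)
    }


dist = DataParallel(["cpu:0", "cpu:1"])
data_layout = dist.get_data_layout((8, 32))
variable_layout = dist.get_variable_layout((32, 16))
shards = shard_batch([0, 1, 2, 3], dist)
```

In this sketch the variable and trainer code would only ever call the two layout methods, which is what lets the same distribution object be shared across backends.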
qlzh727 commented 1 year ago

FYI, the follow-up change for the layout map is at https://github.com/qlzh727/keras-core/pull/1, and I am working on the model parallel distribution now.

qlzh727 commented 1 year ago

Also FYI, the model parallel distribution is in https://github.com/qlzh727/keras-core/pull/2.