poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License

[Feature Request] Distributed strategies: Tensor Parallel and Pipeline Parallel #254

Open rxng8 opened 2 months ago

rxng8 commented 2 months ago

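Is your feature request related to a problem? Please describe.
Currently, Elegy only provides eager and data-parallel distributed strategies. Data parallelism replicates the full model on every device, so it does not help when a model is too large to fit on one device. It would be useful to also implement tensor-parallel and pipeline-parallel strategies, which split the model itself across devices, so that large models can be trained efficiently with this framework.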

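Describe the solution you'd like
We could use the JAX sharding API (`jax.sharding`, e.g. `Mesh`, `NamedSharding` and `PartitionSpec`) together with `jax.jit` to shard model parameters and activations across devices and perform the computation in a sharded way.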
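As a rough illustration of the sharding-API approach for the tensor-parallel case, here is a minimal sketch in plain JAX (not Elegy's strategy interface, which this does not model). It assumes a Megatron-style split of a two-layer MLP: the first weight matrix is sharded by columns and the second by rows over a hypothetical mesh axis named `"model"`, and XLA inserts the cross-device collectives. The `XLA_FLAGS` setting, which fakes 4 CPU devices so the sharding is visible without accelerators, and all array sizes are assumptions for local experimentation only. Pipeline parallelism is not covered here.

```python
import os

# Assumption: fake 4 host CPU devices for local testing. This must run before
# JAX is imported; drop it on real GPUs/TPUs.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh; its single (hypothetical) axis "model" is used for tensor parallelism.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))
n_dev = devices.size

batch, d_model = 8, 16
d_hidden = 32 * n_dev  # hidden width must divide evenly across the "model" axis

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (batch, d_model))
w1 = jax.random.normal(k2, (d_model, d_hidden)) / jnp.sqrt(d_model)
w2 = jax.random.normal(k3, (d_hidden, d_model)) / jnp.sqrt(d_hidden)

# Input replicated on every device; W1 split by columns, W2 split by rows.
x_s = jax.device_put(x, NamedSharding(mesh, P()))
w1_s = jax.device_put(w1, NamedSharding(mesh, P(None, "model")))
w2_s = jax.device_put(w2, NamedSharding(mesh, P("model", None)))


@jax.jit
def mlp(x, w1, w2):
    # Each device computes its own slice of the hidden activations.
    h = jax.nn.relu(x @ w1)
    h = jax.lax.with_sharding_constraint(h, NamedSharding(mesh, P(None, "model")))
    # Each device produces a partial sum of the output; requesting a replicated
    # result makes XLA insert the cross-device reduction.
    y = h @ w2
    return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P()))


y_sharded = mlp(x_s, w1_s, w2_s)
y_reference = jax.nn.relu(x @ w1) @ w2  # same computation on unsharded arrays

print(w1_s.sharding.spec)                               # layout of W1 over the mesh
print([s.data.shape for s in w1_s.addressable_shards])  # per-device slices of W1
print(np.allclose(np.asarray(y_sharded), np.asarray(y_reference), atol=1e-4))
```

Splitting the first matrix by columns and the second by rows means the hidden activations never need to be gathered, so the forward pass needs only a single cross-device reduction at the end of the block. A tensor-parallel strategy in Elegy could presumably apply such `PartitionSpec`s to the relevant parameters and let `jax.jit` handle the communication.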