Is your feature request related to a problem? Please describe.
Currently, the framework only provides eager and data-parallel distribution strategies. It would be useful to also implement tensor-parallel and pipeline-parallel strategies so that large models, whose parameters may not fit on a single device, can be trained efficiently with this framework.
Describe the solution you'd like
We can use the JAX sharding API (jax.sharding) to perform sharded computation.
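Here is a rough sketch of what this could look like. It uses `jax.sharding` to split a dense layer's weight matrix column-wise across all available devices, which is one form of tensor parallelism. The mesh axis name `"model"`, the `dense` function, and the layer shapes are illustrative assumptions. They are not part of any existing strategy in this framework. On a single-device machine the mesh contains one device, and the code still runs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over every available device.
# The axis name "model" is an illustrative choice, not a framework convention.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))

n = len(devices)
d_in, d_out, batch = 16, 8 * n, 4  # d_out is chosen to be divisible by n

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (batch, d_in))
w = jax.random.normal(key_w, (d_in, d_out))

# Column-wise tensor parallelism: each device holds d_out / n columns of w.
w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "model")))
# The input activations are replicated on every device.
x_replicated = jax.device_put(x, NamedSharding(mesh, P()))

@jax.jit
def dense(x, w):
    # XLA's SPMD partitioner derives the per-device computation from the
    # input shardings and inserts any communication that is required.
    return jnp.dot(x, w)

y = dense(x_replicated, w_sharded)
```

Because `jax.jit` propagates the input shardings, `y` should come out sharded along its columns as well. A common design is to follow a column-sharded layer with a row-sharded one, so that each pair of layers needs only a single all-reduce.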