kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

How does it work? #187

Closed: ayaka14732 closed this issue 2 years ago

ayaka14732 commented 2 years ago

I am trying to train huge models with JAX on multiple TPUs, which is how I found this repo. However, after reading the README I still do not understand how it works (e.g. how JAX is integrated with Ray, or how model parallelism is achieved). Would you mind explaining a bit more? Thanks!

whoislimshady commented 2 years ago

Were you able to figure it out?

ayaka14732 commented 2 years ago

So far I have figured out that xmap is used for the model-parallel (distributed) computation, and Ray is an RPC library used for communication between hosts. The overall approach is similar to Megatron-LM, so I am reading that paper now.
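
For anyone else landing here, below is a minimal, hypothetical sketch of the Megatron-style idea with xmap. This is not code from this repo: the function names, the mesh axis names `shard`/`mp`, and the shapes are all my own assumptions, and it targets an older JAX version where `jax.experimental.maps.xmap` still exists (it has since been removed). The weight matrix of a linear layer is split along its output dimension, so each device computes its own slice of the output with no communication needed.

```python
# Hypothetical sketch of Megatron-style tensor parallelism with xmap
# (not this repo's code; assumes an older JAX where xmap is available).
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import xmap, Mesh

n_devices = len(jax.devices())

def linear(x, w):
    # Inside xmap, each device sees only its shard of w with shape
    # (d_model, d_hidden // n_devices) and computes its slice of the
    # output; the named "shard" axis keeps the slices separate.
    return x @ w

d_model, d_hidden = 8, 16  # assume d_hidden is divisible by n_devices
x = np.ones((4, d_model), dtype=np.float32)
# The weights carry an explicit leading "shard" axis of size n_devices.
w = np.ones((n_devices, d_model, d_hidden // n_devices), dtype=np.float32)

parallel_linear = xmap(
    linear,
    in_axes=([...], ["shard", ...]),  # x is replicated, w is sharded
    out_axes=["shard", ...],          # the output keeps the shard axis
    axis_resources={"shard": "mp"},   # map the logical axis onto the mesh
)

with Mesh(np.array(jax.devices()), ("mp",)):
    y = parallel_linear(x, w)  # shape (n_devices, 4, d_hidden // n_devices)
```

Ray then sits one level up: each TPU host runs a worker process, and the driver coordinates them over RPC. Again a generic sketch, not this repo's actual worker class:

```python
# Generic Ray sketch (hypothetical): one actor per host, each would hold
# a shard of the model and run the xmap'd computation on local devices.
import ray

ray.init()

@ray.remote
class ShardWorker:
    def __init__(self, shard_id):
        self.shard_id = shard_id

    def train_step(self, batch):
        # A real worker would run the sharded forward/backward pass here
        # and return losses/metrics; this stub just echoes what it saw.
        return {"shard": self.shard_id, "seen": len(batch)}

workers = [ShardWorker.remote(i) for i in range(2)]
results = ray.get([w.train_step.remote([1, 2, 3]) for w in workers])
```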

kingoflolz commented 2 years ago

I gave a talk on this topic here: https://youtu.be/ZCMOPkcTu3s, and the slides are here https://docs.google.com/presentation/d/1t8P-WgiBdTGsd3CfsgzYOjDQtuBo4b-7GNBnKfaK5iw/edit?usp=sharing