sanchit-gandhi / whisper-jax

JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
Apache License 2.0

Partitioning Spec #65

Open jeromeku opened 1 year ago

jeromeku commented 1 year ago

Hi @sanchit-gandhi, I was wondering if you could explain the standard partitioning rules. In particular, how are activation parallelism and parameter parallelism achieved through the various `logical_axis_rules` combinations of `activation_dims` and `parameter_dims`?

I've read the t5x partitioning documentation and the section on canonical axis rules but am still confused.
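A minimal sketch of the mechanism being asked about, written in plain Python so it runs without JAX. It assumes the rules are resolved the way Flax's `flax.linen.logical_to_mesh_axes` does it: rules are scanned in order, and a rule is applied only if that array dim is still unassigned and the rule's mesh axis is not already used by another dim of the *same* array. The rule table `RULES` and the helper `logical_to_mesh_axes` below are hypothetical and only loosely modeled on the t5x 2D rule sets; check the t5x source for the exact lists. The point it illustrates is that a single rule table can shard `embed` along `model` for activations but along `data` for parameters, because in a parameter the `model` axis is usually already claimed by `mlp`, `heads` or `vocab`.

```python
from typing import Optional, Sequence, Tuple

# Hypothetical rule table: (logical axis name, mesh axis name or None).
# Order matters: earlier rules win. Loosely modeled on t5x, not copied from it.
RULES: Sequence[Tuple[str, Optional[str]]] = (
    ("batch", "data"),
    ("vocab", "model"),
    ("mlp", "model"),
    ("heads", "model"),
    ("kv", None),
    ("embed", "model"),  # takes effect for activations (model axis still free)
    ("embed", "data"),   # fallback for parameters (model axis already taken)
)


def logical_to_mesh_axes(
    logical_dims: Sequence[str],
    rules: Sequence[Tuple[str, Optional[str]]] = RULES,
) -> Tuple[Optional[str], ...]:
    """Map each logical dim of one array to a mesh axis (None = replicated)."""
    unassigned = object()
    result = [unassigned] * len(logical_dims)
    for logical_name, mesh_axis in rules:
        for pos, dim in enumerate(logical_dims):
            if dim != logical_name or result[pos] is not unassigned:
                continue
            if mesh_axis is not None and mesh_axis in result:
                continue  # a mesh axis may shard at most one dim of an array
            result[pos] = mesh_axis
    # Dims that no rule matched are replicated.
    return tuple(None if r is unassigned else r for r in result)


# Activation [batch, length, embed]: batch -> data, embed -> model.
print(logical_to_mesh_axes(("batch", "length", "embed")))  # ('data', None, 'model')
# Parameter [embed, mlp]: mlp claims model first, so embed falls back to data.
print(logical_to_mesh_axes(("embed", "mlp")))  # ('data', 'model')

# Hypothetical 1D-style table: embed is never sharded.
RULES_1D = tuple(r for r in RULES if r[0] != "embed") + (("embed", None),)
print(logical_to_mesh_axes(("batch", "length", "embed"), RULES_1D))  # ('data', None, None)
```

In the real libraries the resulting tuple would become a `jax.sharding.PartitionSpec` over a device mesh with `data` and `model` axes; this sketch stops at the name resolution step, which is where the activation/parameter distinction comes from.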

sanchit-gandhi commented 1 year ago

Does this provide any additional context? https://discuss.huggingface.co/t/pjit-basics-flax/23344/2