google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0

What does USE_REPEATED_LAYER do? #3

Closed abhinavgoel95 closed 2 years ago

abhinavgoel95 commented 2 years ago

I was wondering if anyone knows the purpose of the USE_REPEATED_LAYER flag in c4.py. Thanks. :)

jysohn23 commented 2 years ago

It basically makes use of Repeat layers (https://github.com/google/praxis/blob/2e46886e5582e39a65a871439ccab29b40dffe93/praxis/layers/repeats.py#L64), which have very nice features such as nn.scan, which reduces your overall XLA graph size and thus compilation time, and nn.remat, which lowers device memory usage by recomputing activations during the backward pass, at the cost of some extra compute time.
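
To make the scan/remat idea concrete, here is a minimal sketch in plain JAX. It is not the praxis `Repeat` implementation: it uses `jax.lax.scan` and `jax.checkpoint` directly (the primitives that Flax's `nn.scan` and `nn.remat` lift), and the `layer` function, `NUM_LAYERS`, and `DIM` are hypothetical stand-ins chosen for illustration.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, DIM = 4, 8  # hypothetical sizes, for illustration only


def layer(x, w):
    # Toy stand-in for one repeated block (not a real Pax/praxis layer).
    return jnp.tanh(x @ w)


def unrolled(x, stacked_w):
    # Python loop: every layer is traced separately, so the traced
    # program grows with the number of layers.
    for i in range(NUM_LAYERS):
        x = layer(x, stacked_w[i])
    return x


def repeated(x, stacked_w):
    # scan traces the body once and iterates over the leading axis of
    # stacked_w; checkpoint marks the body for rematerialization, so its
    # intermediates are recomputed in the backward pass instead of stored.
    def body(carry, w):
        return jax.checkpoint(layer)(carry, w), None

    out, _ = jax.lax.scan(body, x, stacked_w)
    return out


key = jax.random.PRNGKey(0)
stacked_w = jax.random.normal(key, (NUM_LAYERS, DIM, DIM)) * 0.1
x = jnp.ones((2, DIM))

# Both formulations compute the same function.
same = bool(jnp.allclose(unrolled(x, stacked_w), repeated(x, stacked_w), atol=1e-5))

# Count top-level equations in each traced program as a rough proxy
# for graph size.
n_loop = len(jax.make_jaxpr(unrolled)(x, stacked_w).jaxpr.eqns)
n_scan = len(jax.make_jaxpr(repeated)(x, stacked_w).jaxpr.eqns)

# Gradients still flow through the scanned, checkpointed version.
grads = jax.grad(lambda w: repeated(x, w).sum())(stacked_w)
print(same, n_loop, n_scan, grads.shape)
```

The equation count is only a rough proxy for XLA graph size, but it shows the trend: the unrolled version grows with `NUM_LAYERS`, while the scanned version stays constant at the top level because the layer body is traced once.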