abhinavgoel95 closed this 2 years ago.
It basically makes use of Praxis Repeat layers (https://github.com/google/praxis/blob/2e46886e5582e39a65a871439ccab29b40dffe93/praxis/layers/repeats.py#L64). These have very nice features such as `nn.scan`, which shrinks the overall XLA graph and thus reduces compilation time, and `nn.remat`, which lowers device memory usage by recomputing activations during the backward pass instead of storing them, trading extra compute for memory.
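To make the two features concrete, here is a minimal sketch in plain JAX. It is not Praxis's Repeat layer: it uses `jax.lax.scan` and `jax.checkpoint` directly, which are the JAX primitives that Flax's `nn.scan` and `nn.remat` lift to modules. The `layer` function, `NUM_LAYERS`, and `HIDDEN` are hypothetical stand-ins chosen for illustration. The sketch stacks the weights of identical layers along a leading axis, scans over that axis so the layer body is traced only once, and wraps the body in `jax.checkpoint` so its activations are recomputed in the backward pass.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS = 4  # hypothetical depth, for illustration only
HIDDEN = 8      # hypothetical hidden size, for illustration only


def layer(w, b, x):
    """Hypothetical block: dense projection followed by tanh."""
    return jnp.tanh(x @ w + b)


def init_stacked_params(key):
    # Identical layers can store their weights stacked along a leading
    # "layer" axis instead of as NUM_LAYERS separate parameter sets.
    kw, kb = jax.random.split(key)
    w = jax.random.normal(kw, (NUM_LAYERS, HIDDEN, HIDDEN)) * 0.1
    b = jax.random.normal(kb, (NUM_LAYERS, HIDDEN)) * 0.1
    return w, b


def forward_unrolled(params, x):
    # Plain Python loop: each layer is traced separately, so the traced
    # graph grows linearly with NUM_LAYERS.
    w, b = params
    for i in range(NUM_LAYERS):
        x = layer(w[i], b[i], x)
    return x


def forward_scanned(params, x):
    # jax.checkpoint (remat): don't keep the block's intermediate
    # activations; recompute them during the backward pass.
    @jax.checkpoint
    def step(carry, layer_params):
        w_i, b_i = layer_params
        return layer(w_i, b_i, carry), None

    # lax.scan traces `step` once and loops over the leading axis of
    # params, so the graph does not grow with NUM_LAYERS.
    out, _ = jax.lax.scan(step, x, params)
    return out


def loss(forward, params, x):
    return jnp.sum(forward(params, x) ** 2)


def top_level_eqns(forward, params, x):
    # Number of top-level operations in the traced program.
    return len(jax.make_jaxpr(forward)(params, x).jaxpr.eqns)


params = init_stacked_params(jax.random.PRNGKey(0))
x = jnp.ones((2, HIDDEN))

y_loop = forward_unrolled(params, x)
y_scan = forward_scanned(params, x)
g_loop = jax.grad(lambda p: loss(forward_unrolled, p, x))(params)
g_scan = jax.grad(lambda p: loss(forward_scanned, p, x))(params)

print("unrolled eqns:", top_level_eqns(forward_unrolled, params, x))
print("scanned eqns: ", top_level_eqns(forward_scanned, params, x))
```

Both versions compute the same outputs and gradients. The scanned version traces to a single top-level `scan` operation regardless of depth, which is where the compile-time savings come from. Remat changes only what is kept in memory between the forward and backward passes, not the result.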
I was wondering if anyone knew the purpose of the `USE_REPEATED_LAYER` flag in `c4.py`. Thanks. :)