HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU-Training
BSD 2-Clause "Simplified" License

Scan #73

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

Compilation time is down from 600s (depth=32, qrnn_frequency=8) to 70s (depth=4, qrnn_frequency=8, i.e. total_depth = 4 * 8 = 32).

Baseline: https://wandb.ai/homebrewnlp/gpt/runs/2ar16xdkjoz57mdvgipffnng335nik6t
New: https://wandb.ai/homebrewnlp/gpt/runs/4hqwsrq7o9ic8x2bhp3wgu7ipdhk1h3y

Once #7 is finished and we run QRNN at every step, compile time will drop further, since we will no longer have to manually unroll 8 steps.
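
For readers unfamiliar with the pattern, here is a minimal sketch of scanning over depth while still unrolling a fixed group of blocks by hand inside each scanned step. It is not the Olmax implementation: `block`, `unrolled_forward`, `scanned_forward`, `FEATURES` and the toy residual layer are all hypothetical stand-ins, and only `DEPTH = 4` and `INNER = 8` mirror the numbers above. The idea is that `jax.lax.scan` traces its body once, so the traced program contains 8 blocks instead of 32.

```python
import jax
import jax.numpy as jnp
from jax import lax

DEPTH = 4      # number of scanned steps (assumed to mirror depth=4)
INNER = 8      # blocks unrolled by hand per step (assumed to mirror qrnn_frequency=8)
FEATURES = 16  # hypothetical feature size for this toy example


def block(x, w):
    # Hypothetical residual block standing in for a real layer.
    return x + jnp.tanh(x @ w)


def unrolled_forward(params, x):
    # Plain Python loops: the block is traced DEPTH * INNER = 32 times.
    for d in range(DEPTH):
        for i in range(INNER):
            x = block(x, params[d, i])
    return x


def scanned_forward(params, x):
    def step(carry, group):
        # The INNER blocks are still unrolled manually, but this body is
        # traced only once and reused for every scanned step.
        for i in range(INNER):
            carry = block(carry, group[i])
        return carry, None

    # scan iterates over the leading (DEPTH) axis of params.
    out, _ = lax.scan(step, x, params)
    return out


key = jax.random.PRNGKey(0)
params = 0.1 * jax.random.normal(key, (DEPTH, INNER, FEATURES, FEATURES))
x = jnp.ones((2, FEATURES))

y_unrolled = jax.jit(unrolled_forward)(params, x)
y_scanned = jax.jit(scanned_forward)(params, x)

# Size of the top-level traced program for each variant.
n_unrolled = len(jax.make_jaxpr(unrolled_forward)(params, x).jaxpr.eqns)
n_scanned = len(jax.make_jaxpr(scanned_forward)(params, x).jaxpr.eqns)
```

Both variants compute the same function; the scanned one requires the per-step parameters to be stacked along a leading axis with identical shapes. If every step of the scan could use the same body without a hand-unrolled group, the inner Python loop would disappear and the traced body would shrink again.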

ClashLuke commented 2 years ago

Stable