apple / axlearn

An Extensible Deep Learning Library
Apache License 2.0

Transformer extend_step supports multi-step generation. #831

Closed: ds-hwang closed this 1 week ago

ds-hwang commented 1 week ago

Streaming encoders and streaming synthesizers require extend_step to process multiple steps per call.

This is likely a common requirement for multimodal streaming encoders and synthesizers. It is also a prerequisite for speculative decoding and for a funnel transformer decoder.
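To make "multi-step extend_step" concrete, here is a minimal JAX sketch. It is not axlearn's implementation: the function name, argument order and `[B, S, N, H]` cache layout are hypothetical, and it assumes a single scalar `time_step` shared by the whole batch. One call writes T new key/value steps into the cache and then runs attention with a mask that is causal within the chunk. Under those assumptions, one T-step call gives the same output as T single-step calls.

```python
import jax
import jax.numpy as jnp


def extend_step(cache_k, cache_v, time_step, q, k, v):
    """Hypothetical multi-step decoding step over a KV cache (not axlearn's API).

    Assumed shapes: cache_k, cache_v: [B, S, N, H]; q, k, v: [B, T, N, H].
    `time_step` is a scalar: the number of positions already filled in the cache.
    """
    # Write the T new keys/values into cache slots [time_step, time_step + T).
    start = (0, time_step, 0, 0)
    cache_k = jax.lax.dynamic_update_slice(cache_k, k, start)
    cache_v = jax.lax.dynamic_update_slice(cache_v, v, start)

    # Scaled dot-product logits of the T queries against all S slots: [B, N, T, S].
    logits = jnp.einsum("btnh,bsnh->bnts", q, cache_k) / jnp.sqrt(q.shape[-1])

    # Query t is at absolute position time_step + t, so it may attend to slots
    # <= time_step + t. This is causal within the chunk and hides empty slots.
    num_steps, seq_len = q.shape[1], cache_k.shape[1]
    query_pos = time_step + jnp.arange(num_steps)[:, None]  # [T, 1]
    key_pos = jnp.arange(seq_len)[None, :]  # [1, S]
    logits = jnp.where(key_pos <= query_pos, logits, -1e9)  # mask broadcasts over B, N
    probs = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bnts,bsnh->btnh", probs, cache_v)  # [B, T, N, H]
    return cache_k, cache_v, out


# Usage: one 3-step call should match three 1-step calls.
B, S, N, H, T = 2, 8, 2, 4, 3
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H))
k = jax.random.normal(kk, (B, T, N, H))
v = jax.random.normal(kv, (B, T, N, H))
empty = jnp.zeros((B, S, N, H))

_, _, out_multi = extend_step(empty, empty, 0, q, k, v)

ck, cv, outs = empty, empty, []
for t in range(T):
    ck, cv, o = extend_step(ck, cv, t, q[:, t:t + 1], k[:, t:t + 1], v[:, t:t + 1])
    outs.append(o)
out_single = jnp.concatenate(outs, axis=1)
```

Processing the chunk in one call lets the projections and attention run as a few large matrix multiplies instead of T small ones. This is what makes the feature useful for streaming and speculative decoding.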

Performance benchmark

I benchmarked this on a TPUv4 notebook instance. For single-step calls, this change does not affect performance much: some configurations are a little faster and others a little slower (e.g. 3.29 ms vs. 3.40 ms at length 4096).
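The PR does not show its benchmark harness. Below is a minimal sketch of how a step-size sweep could be timed in JAX. `time_jitted` and `toy_qkv` are hypothetical: the toy function is a plain projection, not axlearn's QKVLinear. The sweep assumes the last benchmark argument is the number of steps per call. The two points the sketch relies on are documented JAX behavior: the first call of a jitted function includes compilation, and dispatch is asynchronous, so results must be awaited with `jax.block_until_ready`.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, iters=20):
    """Hypothetical helper: mean wall-clock ms per call of jax.jit(fn)."""
    jitted = jax.jit(fn)
    # The first call compiles; run it once so compile time is excluded.
    jax.block_until_ready(jitted(*args))
    start = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches asynchronously, so wait for each result before moving on.
        jax.block_until_ready(jitted(*args))
    return (time.perf_counter() - start) / iters * 1e3


def toy_qkv(x, w):
    """Toy stand-in for a fused QKV projection: [B, T, D] x [D, 3D] -> [B, T, 3D]."""
    return jnp.einsum("btd,de->bte", x, w)


# Sweep the number of steps per call (T), assumed analogous to the benchmark's last argument.
D = 256
w = jnp.ones((D, 3 * D), dtype=jnp.float32)
results = {}
for steps in (1, 8, 64):
    x = jnp.ones((2, steps, D), dtype=jnp.float32)
    results[steps] = time_jitted(toy_qkv, x, w)
    print(f"steps={steps}: {results[steps]:.3f} ms")
```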

As-is (before this PR)

---------------------------------------------------------------------------------------
Benchmark                                             Time             CPU   Iterations
---------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1        1.22 ms        0.444 ms         1497
QkvLinearExtendStepBenchmark/2048/16/4096/1        3.29 ms        0.494 ms          927
QkvLinearExtendStepBenchmark/2048/16/32768/1       23.6 ms         1.07 ms          158
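QkvLinearExtendStepBenchmark/2048/16/4096/8              N/A             N/A          N/A
QkvLinearExtendStepBenchmark/2048/16/4096/64             N/A             N/A          N/A
QkvLinearExtendStepBenchmark/2048/16/4096/512            N/A             N/A          N/A

Note: the last three rows are multi-step benchmarks. They are N/A before this PR, which is what adds multi-step support.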

This PR

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         1.70 ms        0.513 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/1         3.40 ms        0.519 ms         1174
QkvLinearExtendStepBenchmark/2048/16/32768/1        20.1 ms        0.930 ms          404
QkvLinearExtendStepBenchmark/2048/16/4096/8         3.68 ms        0.524 ms         1139
QkvLinearExtendStepBenchmark/2048/16/4096/64        3.74 ms        0.532 ms         1125
QkvLinearExtendStepBenchmark/2048/16/4096/512       2530 ms         80.4 ms            1
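Dividing each wall-clock time in the "This PR" table by the step count gives the amortized cost per step. This assumes the last benchmark argument is the number of steps per extend_step call. The arithmetic below uses only the numbers from the table (length 4096). Per-step cost falls sharply from 1 to 64 steps. The 512-step row is the outlier: its per-step cost (about 4.94 ms) is higher than a single-step call.

```python
# Wall-clock times (ms) from the "This PR" table at length 4096, keyed by step count.
this_pr_ms = {1: 3.40, 8: 3.68, 64: 3.74, 512: 2530.0}

# Amortized cost of a single step when `steps` steps are processed in one call.
per_step_ms = {steps: t / steps for steps, t in this_pr_ms.items()}
for steps, t in per_step_ms.items():
    print(f"{steps:>4} steps/call: {t:.4f} ms per step")
```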

Removing the odd moveaxis hack gives a further speedup, especially when the step size is large: with 512 steps, 36.5 ms instead of 2530 ms.

This PR without the moveaxis hack

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
QkvLinearExtendStepBenchmark/2048/16/1024/1         1.52 ms        0.542 ms         1082
QkvLinearExtendStepBenchmark/2048/16/4096/1         3.18 ms        0.547 ms         1096
QkvLinearExtendStepBenchmark/2048/16/32768/1        19.6 ms        0.824 ms          430
QkvLinearExtendStepBenchmark/2048/16/4096/8         3.34 ms        0.542 ms         1139
QkvLinearExtendStepBenchmark/2048/16/4096/64        3.48 ms        0.553 ms         1091
QkvLinearExtendStepBenchmark/2048/16/4096/512       36.5 ms         1.71 ms           71
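The PR does not show what the moveaxis hack does, so the following is only one plausible reading, offered as a hypothesis. A common scatter-free KV-cache update keeps the sequence axis last and adds a one-hot "broadcast" of the new values. This requires `jnp.moveaxis` to bring the step axis of the new values to the back. With T steps, this becomes a contraction with a `[T, S]` one-hot matrix whose cost grows with T × S. A `jax.lax.dynamic_update_slice` write, by contrast, touches only the T updated rows. Both functions below are hypothetical illustrations and produce the same cache contents. Whether this accounts for the 2530 ms result is unverified.

```python
import jax
import jax.numpy as jnp


def update_onehot_moveaxis(cache, new, time_step):
    """Hypothetical one-hot cache update; cache: [B, N, H, S], new: [B, T, N, H]."""
    num_steps, seq_len = new.shape[1], cache.shape[-1]
    # Move the step axis of `new` last so it lines up with the cache's seq axis.
    new = jnp.moveaxis(new, 1, -1)  # [B, N, H, T]
    # One-hot map from new step t to cache slot time_step + t: [T, S].
    onehot = jax.nn.one_hot(time_step + jnp.arange(num_steps), seq_len, dtype=new.dtype)
    # Contract over T; work grows with T * S. Assumes the target slots are still zero.
    return cache + jnp.einsum("bnht,ts->bnhs", new, onehot)


def update_slice(cache, new, time_step):
    """Hypothetical slice-based cache update; cache: [B, S, N, H], new: [B, T, N, H]."""
    # Writes only the T updated rows, independent of S.
    return jax.lax.dynamic_update_slice(cache, new, (0, time_step, 0, 0))


B, S, N, H, T, time_step = 2, 16, 2, 4, 5, 3
new = jax.random.normal(jax.random.PRNGKey(1), (B, T, N, H))
a = update_onehot_moveaxis(jnp.zeros((B, N, H, S)), new, time_step)
b = update_slice(jnp.zeros((B, S, N, H)), new, time_step)
# Same values once the one-hot result is moved back to the [B, S, N, H] layout.
same = jnp.allclose(jnp.moveaxis(a, -1, 1), b)
```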
ds-hwang commented 1 week ago

PTAL? from 894

ds-hwang commented 1 week ago

Hi, there are a few follow-up PRs. Could you review this one?