apple / axlearn

An Extensible Deep Learning Library
Apache License 2.0

Transformer extend_step supports multi steps generation (2/2). #836

Closed ds-hwang closed 1 week ago

ds-hwang commented 1 week ago

In MultiheadAttention.extend_step, logit_bias was hardcoded to have a length of 1. This PR modifies it to support multi-step inputs. The change also brings extend_step closer to forward, reducing overall code complexity.
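To illustrate the idea, here is a simplified NumPy sketch (not the actual axlearn implementation; all names and shapes are hypothetical) of a cached-attention extend_step that accepts a block of `num_steps` new tokens at once. The key point is that the causal mask and any `logit_bias` are shaped `[..., num_steps, max_len]` rather than assuming a single decoding step:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def extend_step(cached_key, cached_value, time_step, query, key, value, logit_bias=None):
    """Hypothetical multi-step decode for cached multi-head attention.

    query/key/value: [batch, num_steps, num_heads, head_dim]; num_steps may be > 1.
    cached_key/cached_value: [batch, max_len, num_heads, head_dim], updated in place.
    logit_bias: optional, broadcastable to [batch, num_heads, num_steps, max_len].
    Returns (context [batch, num_steps, num_heads, head_dim], new time_step).
    """
    batch, num_steps, num_heads, head_dim = query.shape
    max_len = cached_key.shape[1]
    # Write the new steps into the cache starting at `time_step`.
    cached_key[:, time_step:time_step + num_steps] = key
    cached_value[:, time_step:time_step + num_steps] = value
    # Attention logits over the full cache: [batch, num_heads, num_steps, max_len].
    logits = np.einsum("bsnh,btnh->bnst", query, cached_key) / np.sqrt(head_dim)
    # Causal mask: new step i may attend to cache positions <= time_step + i.
    positions = np.arange(max_len)[None, :]
    allowed = positions <= (time_step + np.arange(num_steps))[:, None]
    logits = np.where(allowed[None, None], logits, -1e9)
    if logit_bias is not None:
        # Bias spans all num_steps query positions, not a hardcoded length of 1.
        logits = logits + logit_bias
    probs = softmax(logits, axis=-1)
    context = np.einsum("bnst,btnh->bsnh", probs, cached_value)
    return context, time_step + num_steps
```

With this shape convention, a single call with `num_steps=2` produces the same context as two consecutive single-step calls, which is why aligning extend_step with forward removes the special-casing.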

ds-hwang commented 1 week ago

Could you take a look? From 901

ds-hwang commented 1 week ago

Thank you for the review!