Closed timmoon10 closed 1 month ago
/te-ci pytorch
Could you comment on how the change in your last commit helped with the unit test failures? Changing from a list comprehension to a for loop should not change the behavior, right?
/te-ci pytorch
Description
This PR modifies the operation-based API (https://github.com/NVIDIA/TransformerEngine/pull/707) to support some simple branching behavior: operations can now accept extra tensor inputs and generate extra tensor outputs. This enables fusions like GEMMs with beta=1.

Support for multiple inputs will also be necessary for cross-attention (and SSMs?). Note that we are not planning to support more complicated structures since that will take us down the road of general graph compilers.
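For readers unfamiliar with the beta=1 case, here is a minimal sketch of what the fusion computes. It is plain Python rather than the operation-based API from this PR. The helper `gemm` and the tensors `x`, `w`, and `residual` are hypothetical names chosen for illustration. It assumes the usual GEMM convention `out = alpha * (a @ b) + beta * c`, so with beta=1 the extra tensor input is accumulated inside the GEMM instead of in a separate add.

```python
def gemm(a, b, c, alpha=1.0, beta=1.0):
    """Return alpha * (a @ b) + beta * c for row-major nested lists."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [
            alpha * sum(a[i][k] * b[k][j] for k in range(inner)) + beta * c[i][j]
            for j in range(cols)
        ]
        for i in range(rows)
    ]


# Hypothetical data: x is the main input, residual is the extra tensor input.
x = [[1.0, 2.0], [3.0, 4.0]]
w = [[5.0, 6.0], [7.0, 8.0]]
residual = [[1.0, 1.0], [1.0, 1.0]]

# Unfused: a plain matmul (beta=0) followed by a separate elementwise add.
zeros = [[0.0, 0.0], [0.0, 0.0]]
unfused = gemm(x, w, zeros, beta=0.0)
unfused = [
    [u + r for u, r in zip(u_row, r_row)]
    for u_row, r_row in zip(unfused, residual)
]

# Fused: the extra input is accumulated by the GEMM itself (beta=1).
fused = gemm(x, w, residual, beta=1.0)
```

Both paths give the same result. The fused form needs the second tensor to be available as an input to the GEMM operation, which is the kind of extra input this PR allows. In PyTorch, `torch.addmm` exposes the same `alpha` and `beta` parameters.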
Type of change
Changes
beta=1
Checklist: