NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Branching operations #1027

Closed timmoon10 closed 1 month ago

timmoon10 commented 2 months ago

Description

This PR modifies the operation-based API (https://github.com/NVIDIA/TransformerEngine/pull/707) to support some simple branching behavior: operations can now accept extra tensor inputs and generate extra tensor outputs. This enables fusions like GEMMs with beta=1:

```python
model = te.Sequential(
    MakeExtraOutput(),
    Linear(...),
    AddInPlace(),
)
y, linear_in = model(x, linear_out)  # GEMM with beta=1 into linear_out
...
loss.backward()  # dgrad GEMM with beta=1 into linear_in.grad
```

Support for multiple inputs will also be necessary for cross-attention (and SSMs?). Note that we do not plan to support more complicated graph structures, since that would take us down the road of building a general graph compiler.
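For reference, the "GEMM with beta=1" semantics the fusion targets are just accumulation into a preexisting output buffer, `C = beta * C + alpha * (A @ B)`. A minimal sketch of the math in NumPy (the variable names are illustrative, not part of the TransformerEngine API):

```python
import numpy as np

# beta=1 GEMM: the matmul result is accumulated into the existing
# output buffer rather than overwriting it.
A = np.arange(6, dtype=np.float64).reshape(2, 3)
B = np.ones((3, 2))
C = np.full((2, 2), 10.0)   # preexisting buffer, like linear_out above

C_new = 1.0 * C + A @ B     # beta=1, alpha=1
# C_new == [[13., 13.], [22., 22.]]
```

Fusing this into the GEMM kernel avoids materializing the matmul result and then running a separate elementwise add, which is what the `AddInPlace` operation above enables.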


timmoon10 commented 2 months ago

/te-ci pytorch

timmoon10 commented 2 months ago

/te-ci pytorch

timmoon10 commented 2 months ago

/te-ci pytorch

timmoon10 commented 2 months ago

/te-ci pytorch

ptrendx commented 1 month ago

Could you comment on how the change from your last commit fixed the unit test failures? The change from a list comprehension to a for loop should not change the behavior, right?
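For context on the question: a list comprehension and its explicit for-loop expansion build the same list, so in the simple case the rewrite is indeed behavior-preserving. A minimal illustration (not the PR's actual code):

```python
# A list comprehension and its for-loop expansion produce equal lists.
xs = [1, 2, 3]

comp = [x * 2 for x in xs]

loop = []
for x in xs:
    loop.append(x * 2)

assert comp == loop  # both are [2, 4, 6]
```

Differences can only arise from side effects around the rewrite, e.g. evaluation order relative to other statements, exception timing, or leaking the loop variable into the enclosing scope.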

timmoon10 commented 1 month ago

/te-ci pytorch