NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[Common/PyTorch] Grouped GEMM via multi-stream cuBLAS #853

Closed. yaox12 closed this pull request 3 months ago.

yaox12 commented 4 months ago

Description

This PR adds grouped GEMM for FP32/BF16/FP16/FP8 via multi-stream cuBLAS. It is intended for MoE training.
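For readers unfamiliar with the pattern, the sketch below illustrates the multi-stream idea at the PyTorch level: each GEMM in the group is launched on its own CUDA stream so that small per-expert matmuls can overlap on the GPU. This is only a conceptual illustration, not the code in this PR (which drives cuBLAS directly from the common C++/CUDA layer); the helper name `grouped_gemm_multistream` and the round-robin stream assignment are assumptions made for the example.

```python
import torch

def grouped_gemm_multistream(a_list, b_list, num_streams=4):
    """Conceptual sketch: run each GEMM of a group on its own CUDA stream
    so small per-expert matmuls can overlap. The real implementation calls
    cuBLAS from C++/CUDA; this only shows the scheduling pattern."""
    streams = [torch.cuda.Stream() for _ in range(num_streams)]
    outputs = [None] * len(a_list)
    current = torch.cuda.current_stream()
    for i, (a, b) in enumerate(zip(a_list, b_list)):
        s = streams[i % num_streams]
        s.wait_stream(current)        # wait until inputs are ready on the default stream
        with torch.cuda.stream(s):
            outputs[i] = a @ b        # one GEMM per group member, on its own stream
    for s in streams:
        current.wait_stream(s)        # rejoin the default stream before consumers run
    return outputs
```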


phu0ngng commented 4 months ago

/te-ci pytorch

yaox12 commented 4 months ago

Can you trigger the CI again?

yaox12 commented 4 months ago

Hi @phu0ngng, I just finished implementing the GroupedLinear layer and would like to include it in this PR. Could you please review it? Sorry for not having all the code ready for review at once.
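For context, a grouped linear layer applies a separate weight matrix to each contiguous chunk of input rows, as an MoE layer would apply one expert weight per group of routed tokens. The sketch below is a naive pure-PyTorch reference written for illustration only; it is not the TransformerEngine `GroupedLinear` API, and the class name, constructor arguments, and `m_splits` argument are assumptions made for this example.

```python
import torch
import torch.nn as nn

class NaiveGroupedLinear(nn.Module):
    """Illustrative reference (not the TE API): one weight matrix per group,
    applied to the corresponding chunk of input rows."""

    def __init__(self, num_gemms, in_features, out_features):
        super().__init__()
        self.weights = nn.Parameter(
            torch.empty(num_gemms, out_features, in_features))
        nn.init.normal_(self.weights, std=0.02)

    def forward(self, inp, m_splits):
        # m_splits[i] = number of rows (e.g. routed tokens) belonging to group i
        chunks = torch.split(inp, m_splits, dim=0)
        outs = [chunk @ w.t() for chunk, w in zip(chunks, self.weights)]
        return torch.cat(outs, dim=0)
```

The fused/multi-stream version in this PR targets the same computation but avoids launching each chunk's GEMM serially on a single stream.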

phu0ngng commented 3 months ago

/te-ci pytorch

phu0ngng commented 3 months ago

LGTM!

phu0ngng commented 3 months ago

/te-ci pytorch

phu0ngng commented 3 months ago

/te-ci pytorch