NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[MoE][Common/PyTorch] Add permutation #936

Open StudyingShao opened 2 weeks ago

StudyingShao commented 2 weeks ago

Description

Permutation operators for the fp32/bf16/fp16/fp8 data types. Currently exposed as PyTorch ops only.

Additional descriptions: https://github.com/fanshiqing/moe_grouped_gemm/tree/dev
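For reviewers unfamiliar with the operation, here is a minimal PyTorch sketch of what a token permute/unpermute pair computes for MoE routing: tokens are grouped so that each expert's tokens are contiguous before the expert GEMMs, then scattered back afterwards. The function names and the top-1 routing assumption are illustrative only and do not reflect this PR's actual API or fused kernels.

```python
import torch

def permute_tokens(tokens: torch.Tensor, expert_indices: torch.Tensor):
    """Group tokens by their routed expert so each expert's rows are contiguous.

    tokens:         [num_tokens, hidden] activations
    expert_indices: [num_tokens] expert id chosen by the router (top-1 case)
    Returns the permuted tokens and the row map needed to undo the permutation.
    """
    sorted_indices = torch.sort(expert_indices, stable=True).indices
    permuted = tokens.index_select(0, sorted_indices)
    return permuted, sorted_indices

def unpermute_tokens(permuted: torch.Tensor, sorted_indices: torch.Tensor):
    """Scatter rows back to their original token order after expert computation."""
    out = torch.empty_like(permuted)
    out.index_copy_(0, sorted_indices, permuted)
    return out
```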

Type of change

Changes

Please list the changes introduced in this PR:

Checklist:

StudyingShao commented 1 week ago

Hi @phu0ngng @cyanguwa, this PR adds the permutation fusion operators needed by MoE. Please ignore the unit test file tests/pytorch/test_permutation.py for now and help review the other changes. Thanks. I will start refactoring the unit test file in parallel.

cc @QiZhangNV

phu0ngng commented 4 days ago

/te-ci pytorch

phu0ngng commented 2 days ago

Hi @StudyingShao, thanks for putting this work into TE. I have a couple of suggestions after a first glance at your code.

  1. Please sign off all of your commits (DCO failed).
  2. Please rewrite the unit tests with pytest and enable skipping when FP8 is unavailable (see https://github.com/NVIDIA/TransformerEngine/blob/7326af9d8d7f7a9d2a4d24b0193d5bb51541a80d/tests/pytorch/test_numerics.py#L495); a sketch of the skip pattern follows below.
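For item 2, a minimal sketch of the requested skip pattern, assuming the FP8GlobalStateManager.is_fp8_available() helper that the existing tests in tests/pytorch/test_numerics.py use; the test name and body here are placeholders, not the PR's actual tests.

```python
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# Query FP8 support once at module import, as the existing TE tests do.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_permutation_fp8(dtype):
    # Placeholder body: the real test would build inputs of the given dtype,
    # run the fused permute/unpermute ops, and compare against a reference.
    ...
```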