NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Remove implicit padding and unpadding in `GroupedLinear` #984

Closed: yaox12 closed this 2 months ago

yaox12 commented 3 months ago

Description

This PR removes the implicit padding and unpadding in `GroupedLinear` to:

  1. Match the other layers (`Linear`/`LayerNormLinear`/`LayerNormMLP`), which throw an error when dimensions are not supported for FP8 GEMM instead of padding implicitly.
  2. Avoid redundant padding/unpadding between consecutive `GroupedLinear` layers. Users can perform the padding and unpadding manually before and after the whole module if needed (see the sketch after this list).
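
As a reference, here is a minimal sketch of doing that padding and unpadding manually in plain PyTorch, assuming each group's row count must be rounded up to a multiple of 16 for FP8 GEMM; the helper names, the alignment value, and the `m_splits` convention are illustrative assumptions, not TransformerEngine's documented API.

```python
import torch
import torch.nn.functional as F

def pad_groups(inp: torch.Tensor, m_splits, align: int = 16):
    # Hypothetical helper: round each group's row count up to a multiple of
    # `align` (assumed 16 for FP8 GEMM) by zero-padding rows at the end of
    # every group. Returns the padded tensor and the padded split sizes.
    padded_chunks, padded_splits = [], []
    for chunk, m in zip(torch.split(inp, m_splits, dim=0), m_splits):
        pad_rows = (align - m % align) % align
        if pad_rows:
            chunk = F.pad(chunk, (0, 0, 0, pad_rows))  # pad rows only
        padded_chunks.append(chunk)
        padded_splits.append(m + pad_rows)
    return torch.cat(padded_chunks, dim=0), padded_splits

def unpad_groups(out: torch.Tensor, padded_splits, orig_splits):
    # Hypothetical helper: drop the padded rows from each group so the output
    # matches the original split sizes again.
    chunks = torch.split(out, padded_splits, dim=0)
    return torch.cat([c[:m] for c, m in zip(chunks, orig_splits)], dim=0)

# Usage sketch (grouped_linear stands in for a GroupedLinear-like module):
# x_pad, splits_pad = pad_groups(x, m_splits)
# y_pad = grouped_linear(x_pad, splits_pad)
# y = unpad_groups(y_pad, splits_pad, m_splits)
```

With helpers along these lines, a stack of `GroupedLinear` layers can be padded once at the module boundary instead of around every individual layer, which is the redundancy this PR avoids.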


yaox12 commented 3 months ago

@phu0ngng Could you please take a look? This should be a minor fix. Thank you very much.

phu0ngng commented 3 months ago

/te-ci pytorch