Closed · timmoon10 closed this 1 month ago
/te-ci pytorch
Looking more closely, I was a little too pessimistic in my investigation of ONNX exporting. ONNX export does appear to work correctly if the FP8 scales are initialized outside of the export process, and it can correctly convert FP8 scale buffers to Constant operations. The issue I saw in https://github.com/NVIDIA/TransformerEngine/pull/820 arises because we copy an FP8 scale with `Tensor.copy_`, which is translated into an Expand operation (I think to handle array broadcasting). This Expand op is trivial, but ONNX isn't smart enough to remove it. In any case, the simplest fix is to replace the `Tensor.copy_` with a `Tensor.fill_` during ONNX export (see https://github.com/NVIDIA/TransformerEngine/pull/820/commits/4fdd63cb5568b60760d0e8a737af2f360f1d204b).
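To illustrate the fix, here is a minimal sketch of swapping `Tensor.copy_` for `Tensor.fill_` when exporting. The function and flag names are assumptions for illustration, not Transformer Engine's actual internals:

```python
import torch

def update_scale(dst: torch.Tensor, src: torch.Tensor, exporting: bool) -> None:
    """Hypothetical helper: write a new FP8 scale into a buffer.

    During ONNX export, fill_ with a scalar avoids the broadcast that
    copy_ introduces, so no Expand node appears in the traced graph.
    """
    if exporting:
        # fill_ takes a scalar value, so tracing sees a constant fill
        dst.fill_(src.item())
    else:
        # Normal training path: in-place copy of the scale tensor
        dst.copy_(src)

scale_buf = torch.zeros(1)
update_scale(scale_buf, torch.tensor([2.0]), exporting=True)
```

Outside of export, `copy_` is kept since the scale is a live tensor that is updated during training.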
Description
ONNX export currently assumes that FP8 scales can be represented with constant operations, which requires that scales are initialized during the export process. However, we expect that the scales are initialized and updated during training. This PR uses slice operations to access the correct FP8 scales.
These changes are also included in https://github.com/NVIDIA/TransformerEngine/pull/820.
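As a rough sketch of the slice-based approach described above (buffer layout and names here are assumptions, not Transformer Engine's actual API): indexing the shared scale buffer with a length-1 slice keeps the lookup as a graph operation in the exported ONNX model, rather than requiring the scale value to be known at export time.

```python
import torch

# Hypothetical shared buffer holding one FP8 scale per tensor,
# initialized and updated during training.
scale_buffer = torch.ones(4)

def get_scale(buf: torch.Tensor, idx: int) -> torch.Tensor:
    # A length-1 slice (rather than a scalar index) is traced as a
    # Slice op in the ONNX graph, so the exported model reads the
    # current buffer contents instead of a baked-in Constant.
    return buf[idx : idx + 1]

s = get_scale(scale_buffer, 2)
```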