NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[Common] Remove CheckTensor if the workspace is empty in cast_transpose_fused #931

Closed phu0ngng closed 2 weeks ago

phu0ngng commented 2 weeks ago

Description

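In JAX, we call dbias_cast_transpose and dact_dbias_cast_transpose with empty tensors to query the workspace size before calling the actual functions. CheckTensor should therefore be skipped in cast_transpose_fused when the workspace is empty, since the tensors passed in that first call are intentionally unallocated.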
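
For readers unfamiliar with the two-call pattern, here is a minimal pure-Python sketch of the idea. It is not the actual C++/CUDA code in the library: `Tensor`, `check_tensor`, and this Python `dbias_cast_transpose` are hypothetical stand-ins, and the workspace shape chosen here is an assumption made only for illustration. The first call passes unallocated tensors and only fills in the workspace shape; validation runs only on the second, real call.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Tensor:
    """Hypothetical tensor descriptor: a shape plus an optional flat buffer."""
    shape: Tuple[int, ...] = ()
    data: Optional[List[float]] = None


def check_tensor(t: Tensor, name: str) -> None:
    """Hypothetical validation: the tensor must be allocated and match its shape."""
    if t.data is None:
        raise ValueError(f"{name} must be allocated")
    expected = 1
    for dim in t.shape:
        expected *= dim
    if len(t.data) != expected:
        raise ValueError(f"{name} has {len(t.data)} elements, expected {expected}")


def dbias_cast_transpose(inp: Tensor, transposed: Tensor, dbias: Tensor,
                         workspace: Tensor) -> None:
    rows, cols = inp.shape
    if workspace.data is None:
        # Query call: the tensors are intentionally empty, so skip validation
        # and only report the workspace shape (assumed here: one slot per column).
        workspace.shape = (cols,)
        return

    # Real call: every tensor must now be allocated.
    check_tensor(inp, "input")
    check_tensor(transposed, "transposed output")
    check_tensor(dbias, "dbias")
    check_tensor(workspace, "workspace")

    # Use the workspace as the accumulator for the per-column sums.
    for c in range(cols):
        workspace.data[c] = 0.0
    for r in range(rows):
        for c in range(cols):
            value = inp.data[r * cols + c]
            transposed.data[c * rows + r] = value
            workspace.data[c] += value
    dbias.data[:] = workspace.data[:cols]


# 1) Query the workspace size with empty tensors.
workspace = Tensor()
dbias_cast_transpose(Tensor(shape=(2, 3)), Tensor(), Tensor(), workspace)

# 2) Allocate the workspace, then make the real call.
workspace.data = [0.0] * workspace.shape[0]
inp = Tensor((2, 3), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
transposed = Tensor((3, 2), [0.0] * 6)
dbias = Tensor((3,), [0.0] * 3)
dbias_cast_transpose(inp, transposed, dbias, workspace)
```

If `check_tensor` ran unconditionally at the top of the function, step 1 would raise before the workspace shape could be reported, which is the situation this PR addresses.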

phu0ngng commented 2 weeks ago

Hi, I added another small change that fixes the following test failure:

FAILED tests/jax/test_single_gpu_mnist.py::TestMNIST::test_te_bf16 - ValueError: The repository for mnist contains custom code which must be executed to correctly ...

This failure is unrelated to the CheckTensor issue above, but CI won't come back clean without the fix, so I decided to push it to the same PR.

phu0ngng commented 2 weeks ago

/te-ci jax