NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Release GIL in PyTorch extensions #938

Closed timmoon10 closed 2 weeks ago

timmoon10 commented 2 weeks ago

Description

This PR releases the GIL within PyTorch extensions to enable some multithreaded workflows (see https://github.com/NVIDIA/TransformerEngine/issues/868). As far as I'm aware, the only place we call back into Python from within C++ is when initializing Userbuffers (i.e. when using NCCL to bootstrap UB comms).

Closes https://github.com/NVIDIA/TransformerEngine/issues/868.
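For context, the kind of multithreaded workflow this unblocks can be sketched in plain Python. This is a toy stand-in, not TE code: `time.sleep` releases the GIL the way a GIL-releasing extension call would, so calls issued from several Python threads can overlap instead of serializing.

```python
import threading
import time


def extension_call(duration: float = 0.5) -> None:
    # Stand-in for a C++ extension entry point. time.sleep releases the
    # GIL (as the patched extensions now do around their C++ work), so
    # other Python threads can run while this call is in flight.
    time.sleep(duration)


def run_threads(n: int = 4) -> float:
    """Run n concurrent extension_call()s and return the wall-clock time."""
    threads = [threading.Thread(target=extension_call) for _ in range(n)]
    start = time.perf_counter()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return time.perf_counter() - start


if __name__ == "__main__":
    elapsed = run_threads(4)
    # With the GIL released inside the call, four overlapping 0.5 s calls
    # finish in roughly 0.5 s total; if each call held the GIL for its
    # full duration, they would serialize to roughly 2 s.
    print(f"elapsed: {elapsed:.2f}s")
```

If the extension instead held the GIL for the duration of each call, the threads above would make no progress concurrently, which is the behavior reported in issue #868.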


timmoon10 commented 2 weeks ago

/te-ci pytorch

denera commented 2 weeks ago

Is there a way to test this?

We can manually validate that comm+GEMM overlap still works with this PR by running examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py. Everything else should be safe, since no other code path calls back into Python from C++.