coreweave / ml-containers

MIT License

feature request: TransformerEngine in torch-extras #62

Closed: ad8e closed this issue 4 months ago

ad8e commented 5 months ago

I'd like to try out TransformerEngine (TE). I profiled our Tensor Parallel (TP) layers, and the communication runs sequentially with the computation instead of overlapping with it. That makes them very slow as the number of TP shards increases. Megatron supports overlapping communication with computation, but it calls out to TransformerEngine to make it happen. I couldn't find any other way to get the overlap short of using NVIDIA primitives.

I wasn't able to install TE in my own image because the GitHub Actions run times out. The dev build did install with this command:

RUN export MAX_JOBS=4 && export NVTE_WITH_USERBUFFERS=1 && export NVTE_FRAMEWORK=pytorch && pip3 install -v git+https://github.com/NVIDIA/TransformerEngine.git@main

The value of 4 jobs was a total guess, and I'm not sure NVTE_FRAMEWORK=pytorch does anything.
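The install command above sets three environment variables before calling pip. As a minimal sketch, assuming you want to script the same install outside a Dockerfile (this is not how ml-containers or TransformerEngine build it), the Python below keeps the variables and the git ref explicit and can render them back into the one-line shell form. The helper names te_build_env, te_install_cmd and as_shell are hypothetical. The comments on what each variable does are hedged: the issue author was unsure whether NVTE_FRAMEWORK has any effect.

```python
import os
import shlex

# Repository URL from the issue; "main" is the dev branch the command installed.
TE_GIT_URL = "git+https://github.com/NVIDIA/TransformerEngine.git"


def te_build_env(max_jobs=4, base_env=None):
    """Return a copy of the environment plus the variables from the issue's RUN line."""
    env = dict(os.environ if base_env is None else base_env)
    # Cap on parallel compile jobs, commonly read by PyTorch-extension builds (4 was a guess).
    env["MAX_JOBS"] = str(max_jobs)
    # Assumed to enable userbuffers support, which TE uses for comm/GEMM overlap.
    env["NVTE_WITH_USERBUFFERS"] = "1"
    # Intended to restrict the build to PyTorch; its effect was not confirmed in the issue.
    env["NVTE_FRAMEWORK"] = "pytorch"
    return env


def te_install_cmd(ref="main"):
    """Return the pip invocation as an argument list, suitable for subprocess.run."""
    return ["pip3", "install", "-v", f"{TE_GIT_URL}@{ref}"]


def as_shell(keys, env, cmd):
    """Render the exports and the pip call as one shell line, e.g. for a Dockerfile RUN step."""
    exports = " && ".join(f"export {k}={shlex.quote(env[k])}" for k in keys)
    return f"{exports} && {shlex.join(cmd)}"


if __name__ == "__main__":
    keys = ["MAX_JOBS", "NVTE_WITH_USERBUFFERS", "NVTE_FRAMEWORK"]
    # Only prints the command; it does not start the (long) build.
    print(as_shell(keys, te_build_env(), te_install_cmd()))
```

To actually run the build, one could call subprocess.run(te_install_cmd(), env=te_build_env(), check=True). Raising max_jobs can shorten the build if the runner has enough cores and memory for the extra parallel compiles.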

There's a catch. TE 1.4, the latest stable release (published 5 days ago), forces FlashAttention <= 2.4.2. The version check was recently bumped to 2.5.6 (https://github.com/NVIDIA/TransformerEngine/commit/965803c9dcf889f7e9820921cb811c2bf9b91dbc), but that change probably won't appear in a stable TE release for a while.

So if you add stable TE, FlashAttention will be version-restricted. On the other hand, the nightly torch-extras image had a broken PyTorch both times I tried it (2 out of 2). So it's up to you which images should get TE, and which version of TE to add.
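The FlashAttention constraint described above can be checked before deciding which TE version an image gets. The sketch below is a hypothetical helper, not part of TransformerEngine or ml-containers. It encodes only the two upper bounds quoted in this issue (2.4.2 for TE 1.4, 2.5.6 on main after the linked commit), assumes both are inclusive, ignores any lower bound TE may also enforce, and uses a deliberately simple numeric parser instead of a full version-specifier library. It also assumes the distribution name is flash-attn.

```python
import re
from importlib import metadata

# Upper bounds on flash-attn quoted in the issue. Treating them as inclusive is
# an assumption; TE's real check may differ and may also set a lower bound.
FA_MAX_FOR_TE = {
    "1.4": "2.4.2",   # TE 1.4 stable
    "main": "2.5.6",  # TE main after commit 965803c
}


def parse_version(version):
    """Parse the leading numeric part of a version ('2.5.6.post1' -> (2, 5, 6))."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def fa_compatible(fa_version, te_ref):
    """Return True if fa_version is within the flash-attn upper bound for te_ref."""
    have = parse_version(fa_version)
    limit = parse_version(FA_MAX_FOR_TE[te_ref])
    width = max(len(have), len(limit))
    # Pad with zeros so that '2.4' compares like '2.4.0'.
    have += (0,) * (width - len(have))
    limit += (0,) * (width - len(limit))
    return have <= limit


def installed_fa_version():
    """Return the installed flash-attn version string, or None if it is not installed."""
    try:
        return metadata.version("flash-attn")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    fa = installed_fa_version()
    if fa is None:
        print("flash-attn is not installed")
    else:
        for ref in FA_MAX_FOR_TE:
            print(f"TE {ref}: flash-attn {fa} allowed = {fa_compatible(fa, ref)}")
```

For example, fa_compatible("2.5.6", "1.4") is False while fa_compatible("2.5.6", "main") is True. That is the trade-off above: stable TE would hold the image back to an older FlashAttention.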

The image I normally use is torch-extras, NCCL, CUDA 12.2.2 (latest), Ubuntu 22.04.

ad8e commented 4 months ago

We're not exploring TE anymore.

wbrown commented 4 months ago

Thank you. We'll add TE in the near future, as other customers may want it.