A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
Type of change
[ ] Documentation change (change only to the documentation, either a fix or new content)
[x] Bug fix (non-breaking change which fixes an issue)
[ ] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Changes
Do not attempt to import a framework-specific submodule if its framework is not available
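A minimal sketch of the guard, assuming availability is checked with the standard library's `importlib.util.find_spec`. The PR does not show its exact check, and the helper names `framework_available` and `import_available_submodules` are hypothetical, not Transformer Engine APIs.

```python
import importlib
import importlib.util
import types
from typing import Dict


def framework_available(package: str) -> bool:
    """Return True if a top-level `package` can be located, without importing it."""
    return importlib.util.find_spec(package) is not None


def import_available_submodules(submodules: Dict[str, str]) -> Dict[str, types.ModuleType]:
    """Import each submodule only if the framework it depends on is installed.

    `submodules` maps a top-level framework package (e.g. "jax" or "paddle")
    to the name of the submodule that requires it. Submodules whose framework
    is missing are skipped, so nothing that would raise ImportError on first
    use is ever registered in sys.modules.
    """
    loaded: Dict[str, types.ModuleType] = {}
    for framework, submodule in submodules.items():
        if framework_available(framework):
            loaded[framework] = importlib.import_module(submodule)
    return loaded


if __name__ == "__main__":
    # "json" stands in for an installed framework; the second entry is a
    # hypothetical framework that is not installed and is therefore skipped.
    loaded = import_available_submodules({
        "json": "json.decoder",
        "hypothetical_missing_framework": "hypothetical_missing_framework.ops",
    })
    print(sorted(loaded))
```

Note that `find_spec` only checks that a package can be located; a framework that is installed but broken would still raise when `import_module` runs.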
Description
https://github.com/NVIDIA/TransformerEngine/pull/839 breaks Lightning Thunder's JIT infrastructure (see https://github.com/NVIDIA/TransformerEngine/pull/839#issuecomment-2108824717). Thunder calls `inspect.stack`, which indiscriminately interacts with all of Transformer Engine's submodules, including the lazily loaded submodules for JAX and PaddlePaddle that are intended to raise import errors.
This is a quick fix that only attempts to import submodules if the corresponding DL framework is available. The original bug may reemerge if a Thunder user incidentally has JAX or PaddlePaddle in their Python environment.
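The failure mode can be reproduced in isolation with a minimal sketch, assuming the lazily loaded submodules behave like modules registered through the standard library's `importlib.util.LazyLoader` (not confirmed to be the mechanism PR #839 uses). The loader class, module name, and helper functions are hypothetical. `scan_modules_like_inspect` mimics the part of CPython's `inspect.getmodule`, which `inspect.stack` can reach, that calls `hasattr(module, '__file__')` on every entry of `sys.modules`. Because `hasattr` only suppresses `AttributeError`, the `ImportError` raised by the deferred module body escapes to the caller.

```python
import importlib.abc
import importlib.machinery
import importlib.util
import sys
import types


class MissingFrameworkLoader(importlib.abc.Loader):
    """Hypothetical loader standing in for a submodule whose framework is absent."""

    def create_module(self, spec):
        return None  # fall back to the default module object

    def exec_module(self, module):
        # The module body fails the way an unsatisfied framework import would.
        raise ImportError("framework not installed")


def register_lazy_failing_module(name: str) -> types.ModuleType:
    """Put `name` in sys.modules as a lazy module that fails on first attribute access."""
    loader = importlib.util.LazyLoader(MissingFrameworkLoader())
    spec = importlib.machinery.ModuleSpec(name, loader)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)  # LazyLoader defers the real exec_module call
    return module


def scan_modules_like_inspect() -> None:
    """Touch `__file__` on every module in sys.modules, as inspect.getmodule does."""
    for module in list(sys.modules.values()):
        if isinstance(module, types.ModuleType):
            # hasattr() only suppresses AttributeError; an ImportError raised
            # by a deferred module body propagates to the caller.
            hasattr(module, "__file__")


if __name__ == "__main__":
    register_lazy_failing_module("hypothetical_framework_submodule")
    try:
        scan_modules_like_inspect()
    except ImportError as exc:
        print(f"scan failed: {exc}")
    finally:
        sys.modules.pop("hypothetical_framework_submodule", None)
```

Under this assumption, guarding the import on framework availability keeps such placeholders out of `sys.modules` entirely, so the scan has nothing to trip over.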
Checklist: