Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

TE: fix the placement of bwd_sync symbol in trace #591

Closed: kshitij12345 closed this 2 weeks ago

kshitij12345 commented 2 weeks ago

I find that the te_sync_fp8_meta_bwd seems to appear after the return in the bwd trace

As observed by @kiya00, te_sync_fp8_meta_bwd appears after the return in the backward trace and hence has no effect. This PR fixes this by moving it before the return (see the before/after sketch below).
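For illustration only, here is a simplified, hypothetical view of the backward trace before and after this change; apart from te_sync_fp8_meta_bwd, all names are placeholders and the real trace generated by thunder is much longer:

```python
# Hypothetical, simplified backward trace (placeholder names except te_sync_fp8_meta_bwd).
#
# Before this PR -- the sync symbol lands after the return, so it is dead code:
#
#   def backward_fn(saved_for_backward, cotangents):
#       ...
#       grads = ...
#       return grads
#       te_sync_fp8_meta_bwd()   # unreachable, never executed
#
# After this PR -- the sync symbol is emitted before the return and actually runs:
#
#   def backward_fn(saved_for_backward, cotangents):
#       ...
#       grads = ...
#       te_sync_fp8_meta_bwd()   # FP8 meta (amax/scale) sync now executes
#       return grads
```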

The reason this worked even when the symbol was placed incorrectly: we called FP8GlobalStateManager.is_first_fp8_module() in the __init__ of TELinear, which happens during the compilation phase and not during execution (where the fp8_autocast context manager is active). This flag is set and reset only under the fp8_autocast manager (see the reference below). So we never actually consumed the token and instead relied on TE to introduce this sync automatically.

https://github.com/Lightning-AI/lightning-thunder/blob/69e80f0a094376576a39306f62b9c510138e41fa/thunder/executors/transformer_engineex.py#L185
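To make the token behavior concrete, here is a minimal sketch (not part of this PR) of how is_first_fp8_module() behaves inside versus outside fp8_autocast. It assumes the FP8GlobalStateManager API as shipped in TE v1.6-v1.8 and requires FP8-capable hardware to enter fp8_autocast with enabled=True:

```python
# Minimal sketch, not from this PR. Assumes TransformerEngine v1.6-v1.8's
# FP8GlobalStateManager API; fp8_autocast(enabled=True) needs FP8-capable hardware.
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# Outside fp8_autocast (this is where TELinear.__init__ runs during compilation):
# the "first FP8 module" token has not been set, so nothing is consumed here.
print(FP8GlobalStateManager.is_first_fp8_module())  # False

with te.fp8_autocast(enabled=True):
    # Entering the context sets IS_FIRST_FP8_MODULE = True.
    print(FP8GlobalStateManager.is_first_fp8_module())  # True  (token consumed)
    print(FP8GlobalStateManager.is_first_fp8_module())  # False (already consumed)
```

Because the call in __init__ happens outside any fp8_autocast region, it returned False, which is why the misplaced sync symbol went unnoticed.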

Also, to reiterate why we want to control these syncs ourselves: if we rely on TE to sync automatically, reordering operations can lead to failures. See https://github.com/Lightning-AI/lightning-thunder/pull/80 for more context.

Reference to where fp8_autocast sets the IS_FIRST_FP8_MODULE flag on entering the context manager: https://github.com/NVIDIA/TransformerEngine/blob/b5a7c9f95e995236afae301366ec433c73b52690/transformer_engine/pytorch/fp8.py#L383-L411


Tested on TE v1.6 (stable), v1.7, and v1.8 (main).