athitten opened this issue 4 months ago
@parthmannan we could use some help here. Can you help us identify why we're causing slowdowns?
Yep, I just need a couple of days to wrap up an urgent task I have been occupied with, and I can take a look at this early next week.
convo w/ Abhishree:
last_traces log attached here! There are some permutes---do we really need these? Is this an impact of our current (lack of) layout algorithm?
I have yet to run this myself, but the permutes seem okay. The permute seems to be happening before a BMM layer, and one of the tensors would need to be permuted to perform the op. However, I am not sure if this could just be a view-like operation which doesn't need a GPU kernel. Will know more once I can run. The trace is super helpful to have, though.
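For reference, in eager PyTorch at least, permute by itself is just a stride change (a view) and doesn't launch a GPU kernel; a copy only happens if something downstream forces contiguity. A standalone illustration, unrelated to the NeMo code:

import torch

a = torch.randn(8, 64, 32, device="cuda")
b = torch.randn(8, 128, 32, device="cuda")

# permute returns a view: same storage, only the strides change, no kernel launch
bt = b.permute(0, 2, 1)
assert bt.data_ptr() == b.data_ptr()

# bmm accepts the strided view directly; an explicit .contiguous() here would copy
out = torch.bmm(a, bt)  # (8, 64, 32) @ (8, 32, 128) -> (8, 64, 128)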
Looks like it is not (exactly) Thunder that's causing the slowdown. Removing thunder.jit once the above patch is applied results in no noticeable slowdown.
The slowdown comes from the patch which disables the APEX GroupNorm and replaces it with torch GroupNorm. So, although Thunder isn't directly causing the slowdown, the patch has been applied to make the model Thunder-compatible (from what I understand).
The APEX kernel is far more optimized, and the nvFuser-generated kernel's runtime is about 2-2.5x slower than APEX's.
To retain a similar speedup in Thunder, we need to either a) match the APEX GroupNorm performance using nvFuser, b) pattern-match torch.nn.GroupNorm to APEX GroupNorm, or c) register APEX GroupNorm so that Thunder can trace through it (probably easiest).
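For a rough sense of that gap, a minimal timing sketch along these lines could be used. It assumes APEX is built with the contrib GroupNorm extension (the apex.contrib.group_norm module path and a constructor mirroring torch.nn.GroupNorm are assumptions here), and the shape is illustrative rather than the actual SD activation size:

import torch
import thunder
from apex.contrib.group_norm import GroupNorm as ApexGroupNorm  # assumed APEX contrib build

N, C, H, W = 8, 320, 64, 64  # illustrative SD-like activation shape
x = torch.randn(N, C, H, W, device="cuda", dtype=torch.float16)

torch_gn = torch.nn.GroupNorm(32, C).cuda().half()
apex_gn = ApexGroupNorm(32, C).cuda().half()  # assumed to take (num_groups, num_channels) like torch's
thunder_gn = thunder.jit(torch_gn)  # same torch module, but executed through Thunder/nvFuser

def time_fn(fn, iters=100):
    # CUDA-event timing with a short warmup
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):
        fn()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("torch GroupNorm:          ", time_fn(lambda: torch_gn(x)), "ms")
print("APEX GroupNorm:           ", time_fn(lambda: apex_gn(x)), "ms")
print("thunder/nvFuser GroupNorm:", time_fn(lambda: thunder_gn(x)), "ms")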
Looks like it is not (exactly) Thunder ...
Thanks so much for your analysis, Parth! This is super helpful.
The slowdown comes from the patch which disables the APEX GroupNorm and replaces it with torch GroupNorm.
Ahh, I imagine this was put in because thunder broke on the GroupNorm operator? @athitten can you confirm?
b) pattern-match torch.nn.GroupNorm to APEX GroupNorm, or c) register APEX GroupNorm so that Thunder can trace through it (probably easiest)
For this option, is it single-op -> single-op pattern matching? I ask because pattern matching a single operator happens to be a subcase of pattern matching that could be implemented much more easily / could be done quickly with an ad hoc solution. But I agree (c) is the best route forward, just curious.
triage review —
For this option, is it single-op -> single-op pattern matching?
Yes, this is matching the torch.nn.GroupNorm op to the APEX GroupNorm function, and it can be an option for us. Ideally, we do both b) and c) so that we can trace through any code already using APEX GroupNorm and also automatically replace torch GroupNorm with APEX GroupNorm.
Thanks, Parth.
Editing title and the original issue comment to reflect triage discussion: we need to implement GroupNorm in the APEX executor.
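For concreteness, here is a rough sketch of what that could look like with thunder's OperatorExecutor extension API. The executor name, the assumption that thunder.torch exposes a group_norm symbol to hook, and the exact registration signatures are illustrative guesses, not the final implementation:

import torch
import thunder
import thunder.torch as ltorch
from thunder.extend import OperatorExecutor, register_executor

# hypothetical standalone executor; in practice this would likely extend thunder's existing apex executor
apex_gn_ex = OperatorExecutor("apex_groupnorm", version="0.1")
register_executor(apex_gn_ex)

def _apex_group_norm_impl(x, num_groups, weight, bias, eps):
    # stand-in body: the real implementation would call the APEX contrib GroupNorm kernel here
    return torch.nn.functional.group_norm(x, num_groups, weight, bias, eps)

def _apex_group_norm_meta(x, num_groups, weight, bias, eps):
    # the output has the same shape, dtype, and device as the input
    return thunder.TensorProxy(like=x)

apex_group_norm = apex_gn_ex.register_operator(
    "apex_group_norm", meta=_apex_group_norm_meta, fn=_apex_group_norm_impl
)

def _checker(x, num_groups, weight=None, bias=None, eps=1e-5):
    # a real checker would verify device, dtype, and shape support for the APEX kernel
    return True

def _transform(x, num_groups, weight=None, bias=None, eps=1e-5):
    return apex_group_norm(x, num_groups, weight, bias, eps)

# map group_norm calls seen in the trace onto the APEX op
apex_gn_ex.register_implementation(ltorch.group_norm, checker=_checker, execution_transform=_transform)

A jitted model would then opt in via thunder.jit's executors argument, assuming that is still how executors are selected.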
Ahh, I imagine this was put in because thunder broke on the GroupNorm operator? @athitten can you confirm?
Yes, torch.nn.GroupNorm was put in because Thunder broke on the GroupNorm operator. Using the torch GroupNorm operator in both cases and comparing runs with and without thunder.jit, I see thunder.jit being slower by 30 ms per iteration. Is this slowdown okay, @tfogal @parthmannan?
@athitten Can you post the logs where you see the 30 ms slowdown? Is this for a single iteration? In my tests on H100, the entire iteration was 50 ms and the slowdown was on the order of ~1-1.5 ms.
Hi @parthmannan, here are the logs with and without thunder.jit for FP32:
autoencoder_with_thunder_fp32.log
autoencoder_wo_thunder_fp32.log
I am running on a single A100. The difference is around 20 ms for FP32. However, with BF16, thunder.jit is slower by around 37 ms. Here are the logs for BF16:
autoencoder_with_thunder_jit_bf16.log
autoencoder_wo_thunder_jit_bf16.log
@athitten Yea, I think you are right. I was able to test on A6000 and Thunder is slower even without the APEX GroupNorm. I need to dig further for that. Can you tell me the patch you used to generate the thunder trace? Where in the NeMo code do we insert the Thunder trace calls? That'll be helpful.
I think I kind of know where the problem is, but we need a better understanding of the NeMo code to solve this. Thunder is running the Encoder in FP32 even when precision is set to BF16. Even without Thunder, the model gets the input in FP32, but it converts it to BF16 before running the GEMMs; with Thunder, the GEMMs are running everything in FP32. Where in NeMo is the precision set, and where is AMP being taken care of? It seems like it doesn't kick in when using Thunder.
Thunder is running the Encoder in FP32 even when precision is set to BF16. ... It seems like it doesn't kick in when using Thunder.
Thunder should recognize PyTorch's autocast context and transform the initial trace accordingly. Here's an example:
In [1]: import thunder
In [2]: import torch
In [3]: a = torch.randn((3, 3), device="cuda")
In [4]: @thunder.jit
   ...: def func(a, b): return a @ b
In [5]: with torch.autocast("cuda", dtype=torch.bfloat16):
   ...:     out = func(a, a)
   ...:
In [6]: out.dtype
Out[6]: torch.bfloat16
In [7]: thunder.last_traces(func)[-1]
Out[7]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[3, 3]"
  # b: "cuda:0 f32[3, 3]"
  [t0, t1] = nvFusion0(a, b)
    # t0 = prims.convert_element_type(a, dtypes.bfloat16)  # t0: "cuda:0 bf16[3, 3]"
    # t1 = prims.convert_element_type(b, dtypes.bfloat16)  # t1: "cuda:0 bf16[3, 3]"
  del a, b
  t2 = torch.matmul(t0, t1)  # t2: "cuda:0 bf16[3, 3]"
    # t2 = ltorch.matmul(t0, t1)  # t2: "cuda:0 bf16[3, 3]"
      # t2 = prims.matmul(t0, t1)  # t2: "cuda:0 bf16[3, 3]"
  del t0, t1
  return t2
It's best to print the initial trace, check that this code path is taken (https://github.com/Lightning-AI/lightning-thunder/blob/4cc7b64ecbb9b9a28081c09ea00cf61093e57d9b/thunder/__init__.py#L554), and check that there are rules for all ops that are expected to be downcast (https://github.com/Lightning-AI/lightning-thunder/blob/4cc7b64ecbb9b9a28081c09ea00cf61093e57d9b/thunder/core/transforms.py#L3685-L3719).
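Concretely, something along these lines, with func being the thunder.jit'ed callable as in the example above (if I recall correctly, last_traces is ordered from the initial trace to the final executed one):

traces = thunder.last_traces(func)
print(traces[0])   # initial trace: the bfloat16 convert_element_type casts should already appear here if autocast was picked up
print(traces[-1])  # final executed trace, as shown above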
🐛 Bug
Applying thunder.jit to the AutoEncoder stage in NeMo's Stable Diffusion is slower than eager mode: eager takes 0.198s per train step, whereas thunder.jit takes 0.242s per step.
To Reproduce
Steps to reproduce the behavior:
1. Apply the attached git diff (encoder.patch) to NeMo.
2. Run NeMo using the command below:
The last trace of the encode step can be found in the attached log: last_trace_SD_encoder.log
Solution
The attached encoder.patch also rewrites GroupNorm to torch.nn.GroupNorm, which ends up moving the operator from using the APEX GroupNorm to using eager's GroupNorm. This is what causes the performance drop in Stable Diffusion.
We should be mapping GroupNorm via either mechanism to APEX's GroupNorm.
cc: @tfogal