Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0
1.15k stars 77 forks

Implement GroupNorm to invoke APEX GroupNorm for NeMo Stable Diffusion AutoEncoder performance #468

Open athitten opened 4 months ago

athitten commented 4 months ago

🐛 Bug

Applying thunder.jit to the AutoEncoder stage in NeMo's Stable Diffusion is slower than eager mode: eager takes 0.198s per train step, whereas thunder.jit takes 0.242s per step.
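To put the two timings in relative terms, here is a quick calculation that uses only the per-step numbers reported above. The variable names are illustrative.

```python
# Per-step timings reported in this issue, in seconds.
eager_s = 0.198
thunder_s = 0.242

slowdown = thunder_s / eager_s            # ratio of thunder.jit step time to eager step time
extra_ms = (thunder_s - eager_s) * 1000   # added latency per train step, in milliseconds

print(f"thunder.jit is {slowdown:.2f}x eager, +{extra_ms:.0f} ms per step")
```

That works out to roughly a 22% longer step, or about 44 ms of extra time per train step.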

To Reproduce

Steps to reproduce the behavior:

  1. Apply the attached git diff, encoder.patch, to NeMo.

  2. Run NeMo using the command below:

    python examples/multimodal/text_to_image/stable_diffusion/sd_train.py trainer.precision=16 trainer.num_nodes=1 trainer.devices=1 ++exp_manager.max_time_per_run=00:00:03:00 trainer.max_steps=20 model.micro_batch_size=1 model.global_batch_size=1 model.data.synthetic_data=True exp_manager.exp_dir=/workspace/TestData/multimodal/stable_diffusion_train model.inductor=False model.cond_stage_config._target_=nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder ++model.cond_stage_config.version=openai/clip-vit-large-patch14 ++model.cond_stage_config.max_length=77 ~model.cond_stage_config.restore_from_path ~model.cond_stage_config.freeze ~model.cond_stage_config.layer model.unet_config.from_pretrained=null model.first_stage_config.from_pretrained=null model.unet_config.use_flash_attention=False model.unet_config.attention_resolutions=\[1\] model.unet_config.channel_mult=\[1\]

    The last trace of the encode step can be found in the attached log, last_trace_SD_encoder.log.

Solution

The attached encoder.patch also rewrites GroupNorm to torch.nn.GroupNorm, which moves the operator from APEX's GroupNorm to eager PyTorch's GroupNorm. This is what causes the performance drop in Stable Diffusion.

We should be mapping GroupNorm via either mechanism to APEX's GroupNorm.

cc: @tfogal

tfogal commented 4 months ago

@parthmannan we could use some help here. Can you help us identify why we're causing slowdowns?

parthmannan commented 4 months ago

Yep, I just need a couple days to wrap an urgent task I have been occupied with and can take a look at this early next week.

tfogal commented 4 months ago

convo w/ Abhishree:

parthmannan commented 4 months ago

There are some permutes---do we really need these? Is this an impact of our current (lack of) layout algorithm?

I have yet to run this myself, but the permutes seem okay. The permute seems to be happening before a BMM layer, and one of the tensors would need to be permuted to perform the op. However, I am not sure whether this could just be a view-like operation that doesn't need a GPU kernel. Will know more once I can run it. The trace is super helpful to have though.

parthmannan commented 4 months ago

Looks like it is not (exactly) Thunder that's causing the slowdown. With the above patch applied, removing thunder.jit makes no noticeable difference. The slowdown comes from the patch, which disables the APEX GroupNorm and replaces it with torch GroupNorm. So, although Thunder isn't directly causing the slowdown, the patch was applied to make the model Thunder compatible (from what I understand). The APEX kernel is far more optimized, and the nvFuser-generated kernel is about 2-2.5x slower than APEX's. To retain similar performance in Thunder, we need to either a) match the APEX GroupNorm performance using nvFuser, b) pattern match torch.nn.GroupNorm to APEX GroupNorm, or c) register APEX GroupNorm so that Thunder can trace through it (probably easiest).

tfogal commented 4 months ago

Looks like it is not (exactly) Thunder ...

Thanks so much for your analysis, Parth! This is super helpful.

The slowdown comes from the patch which disables the APEX GroupNorm and replaces it with torch GroupNorm.

Ahh, I imagine this was put in because thunder broke on the GroupNorm operator? @athitten can you confirm?

or b) Pattern match torch.nn.GroupNorm to APEX GroupNorm. or c) Register APEX GroupNorm so that Thunder can trace through it. (probably easiest)

For this option, is it single-op -> single-op pattern matching? I ask because pattern matching a single operator happens to be a subcase of pattern matching that could be implemented much more easily / could be done quickly with an ad hoc solution. But I agree (c) is the best route forward, just curious.
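To illustrate why the single-op case is the easy subcase, here is a toy sketch of a one-to-one rewrite over a list of operations. It does not use Thunder's trace classes or executor API: `Op`, `SINGLE_OP_REWRITES`, `rewrite_single_ops`, and the name `apex_group_norm` are all hypothetical placeholders invented for this example.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Op:
    """One entry in a toy trace: an operator name plus its argument names."""
    name: str
    args: tuple


# Hypothetical mapping; the right-hand name is a placeholder, not a real symbol.
SINGLE_OP_REWRITES = {"torch.nn.functional.group_norm": "apex_group_norm"}


def rewrite_single_ops(trace, rewrites=SINGLE_OP_REWRITES):
    """Return a new trace with each matching op renamed; arguments pass through unchanged."""
    return [replace(op, name=rewrites.get(op.name, op.name)) for op in trace]


trace = [
    Op("torch.nn.functional.conv2d", ("x", "w")),
    Op("torch.nn.functional.group_norm", ("t0", "num_groups", "weight", "bias")),
    Op("torch.nn.functional.silu", ("t1",)),
]
rewritten = rewrite_single_ops(trace)
print([op.name for op in rewritten])
```

Because each match covers exactly one operation, the rewrite is a dictionary lookup per op and needs no graph matching. A real version would presumably also need a check for inputs the replacement kernel does not support.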

mruberry commented 4 months ago

triage review —

parthmannan commented 4 months ago

For this option, is it single-op -> single-op pattern matching?

Yes, this is matching the torch.nn.GroupNorm op to the APEX GroupNorm function, and it can be an option for us. Ideally, we do both b) and c), so that we can understand any code already using APEX GroupNorm and also replace torch GroupNorm with APEX GroupNorm automatically.

tfogal commented 4 months ago

Thanks, Parth.

Editing title and the original issue comment to reflect triage discussion: we need to implement GroupNorm in the APEX executor.

athitten commented 4 months ago

Ahh, I imagine this was put in because thunder broke on the GroupNorm operator? @athitten can you confirm?

Yes, torch.nn.GroupNorm was put in because thunder broke on the GroupNorm operator. Using the torch GroupNorm operator in both cases and comparing runs with and without thunder.jit, I see thunder.jit being slower by 30 ms per iteration. Is this slowdown okay, @tfogal @parthmannan?

parthmannan commented 4 months ago

@athitten Can you post the logs where you see the 30 ms slowdown? Is this for a single iteration? In my tests on H100, the entire iteration was 50 ms and the slowdown was on the order of ~1-1.5 ms.

athitten commented 4 months ago

Hi @parthmannan, here are the logs with and without thunder.jit for FP32: autoencoder_with_thunder_fp32.log and autoencoder_wo_thunder_fp32.log. I am running on a single A100. The difference is around 20 ms for FP32.

However, with BF16, thunder.jit is slower by around 37 ms. Here are the logs for BF16: autoencoder_with_thunder_jit_bf16.log and autoencoder_wo_thunder_jit_bf16.log.

parthmannan commented 4 months ago

@athitten Yea, I think you are right. I was able to test on A6000 and Thunder is slower even without the APEX GroupNorm. I need to dig further for that. Can you tell me the patch you used to generate the thunder trace? Where in the NeMo code do we insert the Thunder trace calls? That'll be helpful.

parthmannan commented 4 months ago

I think I kind of know where the problem is, but we need a better understanding of the NeMo code to solve this. Thunder is running the Encoder in FP32 even when precision is set to BF16. Even without Thunder, the model gets its input in FP32, but it converts it to BF16 before running GEMMs, whereas with Thunder the GEMMs run entirely in FP32. Where in NeMo is the precision set, and where is AMP being taken care of? It seems like it doesn't kick in when using Thunder.
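For reference, the eager behavior described here (FP32 inputs, BF16 GEMMs under AMP) can be checked in isolation with plain PyTorch. This is a minimal sketch, not NeMo code. It uses the CPU autocast backend so that it runs without a GPU, which is an assumption made for convenience, since the runs in this issue are on CUDA.

```python
import torch

a = torch.randn(3, 3)  # FP32 input, as the model receives it

# Outside autocast, the matmul stays in FP32.
plain = a @ a

# Inside autocast, eager PyTorch casts the matmul inputs and runs it in BF16.
with torch.autocast("cpu", dtype=torch.bfloat16):
    casted = a @ a

print(plain.dtype, casted.dtype)
```

If the second dtype stays float32 in a given setup, the autocast context is not being applied to that call, which is the kind of symptom described above.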

IvanYashchuk commented 4 months ago

I think I kind of know where the problem is, but we need a better understanding of the NeMo code to solve this. Thunder is running the Encoder in FP32 even when precision is set to BF16. Even without Thunder, the model gets its input in FP32, but it converts it to BF16 before running GEMMs, whereas with Thunder the GEMMs run entirely in FP32. Where in NeMo is the precision set, and where is AMP being taken care of? It seems like it doesn't kick in when using Thunder.

Thunder should recognize PyTorch's autocast context and transform the initial trace accordingly. Here's an example:

In [1]: import thunder

In [2]: import torch

In [3]: a = torch.randn((3, 3), device="cuda")

In [4]: @thunder.jit
   ...: def func(a, b): return a @ b

In [5]: with torch.autocast("cuda", dtype=torch.bfloat16):
   ...:     out = func(a, a)
   ...: 

In [6]: out.dtype
Out[6]: torch.bfloat16

In [7]: thunder.last_traces(func)[-1]
Out[7]: 
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[3, 3]"
  # b: "cuda:0 f32[3, 3]"
  [t0, t1] = nvFusion0(a, b)
    # t0 = prims.convert_element_type(a, dtypes.bfloat16)  # t0: "cuda:0 bf16[3, 3]"
    # t1 = prims.convert_element_type(b, dtypes.bfloat16)  # t1: "cuda:0 bf16[3, 3]"
  del a, b
  t2 = torch.matmul(t0, t1)  # t2: "cuda:0 bf16[3, 3]"
    # t2 = ltorch.matmul(t0, t1)  # t2: "cuda:0 bf16[3, 3]"
      # t2 = prims.matmul(t0, t1)  # t2: "cuda:0 bf16[3, 3]"
  del t0, t1
  return t2

It's best to print the initial trace, check whether this code path is taken (https://github.com/Lightning-AI/lightning-thunder/blob/4cc7b64ecbb9b9a28081c09ea00cf61093e57d9b/thunder/__init__.py#L554), and verify that there are rules for all ops that are expected to be downcast (https://github.com/Lightning-AI/lightning-thunder/blob/4cc7b64ecbb9b9a28081c09ea00cf61093e57d9b/thunder/core/transforms.py#L3685-L3719).