pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] Encountered bug when using Torch-TensorRT - SD1.5 compilation causes incorrect output #3003

Open cehongwang opened 2 months ago

cehongwang commented 2 months ago

Bug Description

Stable Diffusion 1.5 no longer compiles correctly with torch_tensorrt.dynamo.compile. In early June (when I opened the refitter-support branch), SD 1.5 compiled correctly, but after I ran git pull on main and updated the repo, the output of the compiled model is incorrect.

To Reproduce

Steps to reproduce the behavior:

1. The script is:


import torch
import torch_tensorrt as torch_trt
from diffusers import DiffusionPipeline

with torch.no_grad():
    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda:0"

    # Instantiate Stable Diffusion Pipeline with FP16 weights
    pipe = DiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16
    )

    backend = "torch_tensorrt"

    model = pipe.unet
    model.half()
    model.to(device)

    inputs = torch.load("/opt/torch_tensorrt/refitting/sample_input.pt")
    exp_program = torch.export.export(model, tuple(inputs))

    enabled_precisions = {torch.float16}
    debug = True
    workspace_size = 20 << 30
    min_block_size = 1
    use_python_runtime = False
    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        debug=debug,
        min_block_size=min_block_size,
        make_refitable=True,
    ) 

    diff = (trt_gm(*inputs)['sample'] - model(*inputs)[0])
    print(diff.std())
    print(diff.abs().mean())
    print(diff.max())

The result I got (std, mean absolute value, and max of the difference, in that order) is:

tensor(0.2996, device='cuda:0', dtype=torch.float16)
tensor(0.2394, device='cuda:0', dtype=torch.float16)
tensor(1.1650, device='cuda:0', dtype=torch.float16)
  2. The sample input file is attached: sample_input.zip

  3. If I change everything to float32, the result becomes:

    tensor(0.0828, device='cuda:0')
    tensor(0.0745, device='cuda:0')
    tensor(0.4428, device='cuda:0')

Expected behavior

The output of the compiled TensorRT module (trt_gm) should match the output of the original PyTorch model.
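
One way to quantify "the same" is sketched below. This is a hypothetical helper, not part of Torch-TensorRT: the name compare_outputs and the default tolerances (rtol=1e-2, atol=1e-2) are assumptions chosen for illustration and would need tuning for FP16. It reports the maximum absolute error rather than the signed maximum that diff.max() prints in the script, and adds a cosine similarity, which is less sensitive to the overall scale of the output.

```python
import torch


def compare_outputs(actual, expected, rtol=1e-2, atol=1e-2):
    """Summarize how far `actual` is from `expected` (illustrative helper).

    The tolerances are assumed defaults, not values recommended by any library.
    """
    # Compare in float32 so the metrics themselves do not lose precision.
    a = actual.detach().float().flatten()
    e = expected.detach().float().flatten()
    diff = a - e
    return {
        "max_abs_err": diff.abs().max().item(),
        "mean_abs_err": diff.abs().mean().item(),
        "cosine_sim": torch.nn.functional.cosine_similarity(a, e, dim=0).item(),
        "allclose": torch.allclose(a, e, rtol=rtol, atol=atol),
    }


# Tiny CPU example with a known perturbation.
expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected + torch.tensor([0.5, -1.0, 0.0])
report = compare_outputs(actual, expected)
print(report)
```

With the reproduction script, the call would look like compare_outputs(trt_gm(*inputs)["sample"], model(*inputs)[0]).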

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

cehongwang commented 2 months ago

https://github.com/pytorch/TensorRT/commit/1ca262f9fb1c4d95eb6478245364b1fb09f15e3f#diff-9cbf43cfb62cdf682eef77f6fd5cabc488bb5a91ff09bfee5d593827542b2476

After this change was merged, the result became incorrect.

cehongwang commented 2 months ago

Please check the group_norm and native_group_norm implementations.
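
A possible starting point for that check is a small reference that the converter output could be compared against. The sketch below assumes nothing about the Torch-TensorRT converters themselves: group_norm_reference is a hypothetical helper that computes group normalization in plain PyTorch, and the shapes are made up. It only confirms that the reference agrees with torch.nn.functional.group_norm and torch.ops.aten.native_group_norm on CPU; a compiled single-layer module would still have to be compared against it separately.

```python
import torch
import torch.nn.functional as F


def group_norm_reference(x, num_groups, weight=None, bias=None, eps=1e-5):
    """Plain-PyTorch group norm computed in float32 (illustrative helper)."""
    n, c = x.shape[0], x.shape[1]
    if c % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    # Collapse each group's channels and spatial positions into one axis.
    grouped = x.float().reshape(n, num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)
    var = grouped.var(dim=-1, unbiased=False, keepdim=True)  # biased variance
    out = ((grouped - mean) / torch.sqrt(var + eps)).reshape(x.shape)
    # Per-channel affine parameters broadcast over batch and spatial dims.
    affine_shape = (1, c) + (1,) * (x.dim() - 2)
    if weight is not None:
        out = out * weight.float().reshape(affine_shape)
    if bias is not None:
        out = out + bias.float().reshape(affine_shape)
    return out.to(x.dtype)


torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4)  # made-up (N, C, H, W) input on CPU
weight, bias = torch.randn(8), torch.randn(8)

ref = group_norm_reference(x, 4, weight, bias)
functional = F.group_norm(x, 4, weight, bias, eps=1e-5)
# aten schema: (input, weight, bias, N, C, HxW, group, eps) -> (out, mean, rstd)
native_out, _, _ = torch.ops.aten.native_group_norm(
    x, weight, bias, 2, 8, 16, 4, 1e-5
)

print((ref - functional).abs().max().item())
print((ref - native_out).abs().max().item())
```

If a module containing only a GroupNorm layer is compiled with the same settings as in the script and its output diverges from this reference, that would point at the converter rather than at the rest of the UNet.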