chengzeyi / stable-fast

Best inference performance optimization framework for HuggingFace Diffusers on NVIDIA GPUs.
MIT License

Segmentation fault instead of CUDA OOM for SVD with higher batch size #100

Status: Open (opened by yondonfu 6 months ago)

yondonfu commented 6 months ago

I noticed that running SVD with a batch size large enough to exceed the available GPU VRAM causes a segmentation fault after the first warm-up call (the one that triggers compilation). If I rerun the same test without stable-fast compilation, I get a CUDA OOM error instead. I expected stable-fast to raise the same CUDA OOM error when there is not enough GPU VRAM.
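A segmentation fault kills the whole Python process, so unlike `torch.cuda.OutOfMemoryError` it cannot be caught with `try`/`except`. One possible workaround (not part of stable-fast) is to run the risky call in a child process and classify how that process exited. The sketch below assumes a POSIX system, where a child killed by signal N reports a return code of -N; `classify_run` and its `"ok"`/`"oom"`/`"segfault"`/`"error"` labels are hypothetical names chosen for illustration.

```python
import signal
import subprocess
import sys


def classify_run(code: str, timeout: float = 600.0) -> str:
    """Run `code` in a fresh Python interpreter and report how it exited.

    Hypothetical helper. Returns "ok", "oom", "segfault", or "error".
    Raises subprocess.TimeoutExpired if the child runs longer than `timeout`.
    Assumes POSIX semantics: death by signal N gives returncode == -N.
    """
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    if proc.returncode == 0:
        return "ok"
    if proc.returncode == -signal.SIGSEGV:
        # The child crashed in native code; there is no Python traceback.
        return "segfault"
    if "OutOfMemoryError" in proc.stderr:
        # An ordinary Python exception such as torch.cuda.OutOfMemoryError.
        return "oom"
    return "error"
```

Based on the logs in this issue, running the test script body through `classify_run` would be expected to return `"segfault"` with stable-fast enabled and `"oom"` without it, letting a parent process treat both as "batch too large".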

Here is the output when using stable-fast to compile the model:

Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00, 10.25it/s]
/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/sfast/jit/overrides.py:21: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return func(*args, **kwargs)
/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/sfast/jit/overrides.py:21: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return func(*args, **kwargs)
/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/sfast/jit/overrides.py:21: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  return func(*args, **kwargs)
/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/sfast/utils/flat_tensors.py:275: TracerWarning: torch.Tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  return super().__new__(cls, x, *args, **kwargs)
/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/sfast/jit/overrides.py:21: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  return func(*args, **kwargs)
  0%|          | 0/25 [00:00<?, ?it/s]
/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/sfast/jit/overrides.py:21: TracerWarning: Converting a tensor to a Python list might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return func(*args, **kwargs)
/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/sfast/jit/overrides.py:18: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  args = (args[0], args[1].item(), *args[2:])
100%|██████████| 25/25 [00:52<00:00,  2.11s/it]
  0%|          | 0/25 [00:00<?, ?it/s]
Segmentation fault (core dumped)

Here is the output when not using stable-fast to compile the model:

Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00, 10.17it/s]
100%|██████████| 25/25 [00:46<00:00,  1.87s/it]
  0%|          | 0/25 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/user/ai-worker/jobs/containers/svd-xt-film/test_svd.py", line 56, in <module>
    frames = pipeline(
             ^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py", line 502, in __call__
    noise_pred = self.unet(
                 ^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/diffusers/models/unet_spatio_temporal_condition.py", line 434, in forward
    sample, res_samples = downsample_block(
                          ^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/diffusers/models/unet_3d_blocks.py", line 2173, in forward
    hidden_states = attn(
                    ^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/diffusers/models/transformer_temporal.py", line 351, in forward
    hidden_states = block(
                    ^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/diffusers/models/attention.py", line 393, in forward
    ff_output = self.ff(norm_hidden_states, scale=lora_scale)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/diffusers/models/attention.py", line 665, in forward
    hidden_states = module(hidden_states, scale)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/diffusers/models/activations.py", line 103, in forward
    return hidden_states * self.gelu(gate)
                           ^^^^^^^^^^^^^^^
  File "/home/user/mambaforge/envs/svd-xt-film/lib/python3.11/site-packages/diffusers/models/activations.py", line 96, in gelu
    return F.gelu(gate)
           ^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 4.39 GiB. GPU 0 has a total capacty of 23.65 GiB of which 880.06 MiB is free. Including non-PyTorch memory, this process has 22.78 GiB memory in use. Of the allocated memory 19.81 GiB is allocated by PyTorch, and 2.50 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
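The figures in this message suggest that allocator tuning alone is unlikely to help: the request (4.39 GiB) is larger than the free memory (880.06 MiB) plus the memory PyTorch has reserved but not allocated (2.50 GiB), so even a perfectly defragmented cache, which is what `max_split_size_mb` aims for, could not satisfy it. The sketch below extracts those numbers with regular expressions; the patterns are an assumption matched against this particular message and may not fit the wording of other PyTorch versions.

```python
import re

# Conversion factors to GiB for the units that appear in the message.
_UNITS = {"MiB": 1 / 1024, "GiB": 1.0}


def _gib(pattern: str, message: str) -> float:
    """Find `pattern` in `message` and convert its (value, unit) groups to GiB."""
    match = re.search(pattern, message)
    if match is None:
        raise ValueError(f"pattern not found: {pattern!r}")
    value, unit = match.groups()
    return float(value) * _UNITS[unit]


def oom_summary(message: str) -> dict:
    """Pull the key figures (in GiB) out of a CUDA OOM message like the one above."""
    requested = _gib(r"Tried to allocate ([\d.]+) (MiB|GiB)", message)
    free = _gib(r"of which ([\d.]+) (MiB|GiB) is free", message)
    reserved_unallocated = _gib(
        r"([\d.]+) (MiB|GiB) is reserved by PyTorch but unallocated", message
    )
    return {
        "requested": requested,
        "free": free,
        "reserved_unallocated": reserved_unallocated,
        # Upper bound on what could be handed to the request with zero
        # fragmentation; a positive shortfall means it cannot fit at all.
        "shortfall": requested - (free + reserved_unallocated),
    }
```

For the message above, the shortfall comes to roughly 1.03 GiB (4.39 - 0.86 - 2.50), so a smaller batch or a lower-memory configuration is needed rather than allocator settings.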

Here is the test script:

from diffusers import StableVideoDiffusionPipeline
import torch
import time
from PIL import Image
from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

def compile_model(model):
    config = CompilationConfig.Default()

    # xformers and Triton are suggested for achieving best performance.
    # It might be slow for Triton to generate, compile and fine-tune kernels.
    try:
        import xformers

        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skip")
    # NOTE:
    # When GPU VRAM is insufficient or the architecture is too old, Triton might be slow.
    # Disable Triton if you encounter this problem.
    try:
        import triton

        config.enable_triton = True
    except ImportError:
        print("Triton not installed, skip")

    model = compile(model, config)
    return model

repo_id = "stabilityai/stable-video-diffusion-img2vid-xt"
cache_dir = "./cache"
pipeline = StableVideoDiffusionPipeline.from_pretrained(
    repo_id, cache_dir=cache_dir, variant="fp16", torch_dtype=torch.float16
)
pipeline.to("cuda")

# Comment out the following line to disable stable-fast
pipeline = compile_model(pipeline)

image = ["input/1.png"]

generator = torch.manual_seed(42)

# Warm-up call (batch size = 1) that triggers stable-fast compilation
frames = pipeline(
    [Image.open(i).convert("RGB") for i in image],
    decode_chunk_size=8,
    generator=generator,
).frames

# batch size = 4 for the actual call
image *= 4

begin = time.time()
frames = pipeline(
    [Image.open(i).convert("RGB") for i in image],
    decode_chunk_size=8,
    generator=generator,
).frames
end = time.time()

run_time = end - begin
print(f"run time: {run_time:.3f}s")

peak_mem_allocated = torch.cuda.max_memory_allocated()
peak_mem_reserved = torch.cuda.max_memory_reserved()
print(f"peak GPU memory allocated: {peak_mem_allocated / 1024**3:.3f}GiB")
print(f"peak GPU memory reserved: {peak_mem_reserved / 1024**3:.3f}GiB")

The test was run on an RTX 4090.
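Until the crash itself is fixed, one defensive option is to check free VRAM before the large call and raise an ordinary Python exception instead of letting the compiled pipeline crash. This is a sketch, not a stable-fast feature: `check_headroom` is a hypothetical helper, and `required_bytes` has to be estimated by the caller (for example, from peak-memory figures like those the script prints). It relies on `torch.cuda.mem_get_info()`, which returns `(free, total)` in bytes; the `mem_get_info` parameter exists only so the logic can be exercised without a GPU.

```python
from typing import Callable, Optional, Tuple


def check_headroom(
    required_bytes: int,
    mem_get_info: Optional[Callable[[], Tuple[int, int]]] = None,
) -> int:
    """Raise MemoryError before a large GPU call if free VRAM is too low.

    Hypothetical helper. Returns the free byte count when the check passes.
    Driver-reported free memory excludes memory PyTorch has cached but not
    allocated, so this check is conservative.
    """
    if mem_get_info is None:
        import torch  # imported lazily so the logic can be tested without a GPU

        mem_get_info = torch.cuda.mem_get_info
    free_bytes, total_bytes = mem_get_info()
    if free_bytes < required_bytes:
        raise MemoryError(
            f"need {required_bytes / 1024**3:.2f} GiB, but only "
            f"{free_bytes / 1024**3:.2f} GiB of {total_bytes / 1024**3:.2f} GiB is free"
        )
    return free_bytes
```

In the test script, such a check would sit right before the batch-size-4 `pipeline(...)` call; any threshold passed to it is an illustrative estimate, not a measured value.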

chengzeyi commented 5 months ago

@yondonfu This might be a bug in torch.jit. When an exception occurs, it tries to dump the source code location of the failing operation. Unfortunately, the source code of that operation is missing because we have modified it dynamically.
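The underlying problem can be illustrated in plain Python, as an analogy rather than torch.jit's actual code path: code generated at runtime has no source file behind it, so any tool that tries to look up its source, as `inspect.getsource` does below, fails. `make_dynamic_fn` and `source_available` are hypothetical names for this illustration.

```python
import inspect


def make_dynamic_fn():
    """Build a function from a string at runtime, similar in spirit to code
    that is rewritten or generated dynamically."""
    namespace = {}
    code = compile("def gelu_like(x):\n    return x * 0.5\n", "<generated>", "exec")
    exec(code, namespace)
    return namespace["gelu_like"]


def source_available(fn) -> bool:
    """Return True if Python can recover the source text of `fn`."""
    try:
        inspect.getsource(fn)
        return True
    except OSError:
        # Raised when the "file" ("<generated>" here) has no readable source.
        return False
```

The generated function runs normally, but `source_available(make_dynamic_fn())` is `False`: its execution works, while any error-reporting path that needs the source text has nothing to show.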