chengzeyi / stable-fast

Best inference performance optimization framework for HuggingFace Diffusers on NVIDIA GPUs.
MIT License
1.16k stars 71 forks source link

Exception when using deterministic generation with `enable_cuda_graph = True` #58

Closed lackofdream closed 10 months ago

lackofdream commented 10 months ago

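When doing deterministic generation by passing a `torch.Generator` to the pipeline, stable-fast raises an `AssertionError`. The assertion fires in `tree_copy_` while the keyword arguments of `vae.decode` (which include `generator`) are copied into the CUDA graph's static inputs. The new `torch.Generator` passed on the second call does not compare equal to the one seen when the graph was captured, so `assert dest == src` fails: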

Traceback (most recent call last):
  File "/data/dbgsfast/dbgsfast/__init__.py", line 43, in <module>
    model(**kwarg_inputs, generator=torch.Generator(device="cuda").manual_seed(41))
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 958, in __call__
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/sfast/cuda/graphs.py", line 32, in dynamic_graphed_callable
    return cached_callable(*args, **kwargs)
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/sfast/cuda/graphs.py", line 143, in functionalized
    return _graphed_module(*user_args, **user_kwarg_args)
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/sfast/cuda/graphs.py", line 130, in forward
    outputs = self._forward(*inputs, **kwarg_inputs)
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/sfast/cuda/graphs.py", line 136, in _forward
    tree_copy_(static_kwarg_inputs, kwarg_inputs)
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/sfast/utils/copy.py", line 20, in tree_copy_
    tree_copy_(dest[k], src[k])
  File "/data/dbgsfast/.venv/lib/python3.10/site-packages/sfast/utils/copy.py", line 22, in tree_copy_
    assert dest == src

Code to reproduce:

import torch
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
from sfast.compilers.stable_diffusion_pipeline_compiler import (
    compile,
    CompilationConfig,
)

def load_model():
    """Build the Stable Diffusion v1.5 pipeline used by this repro.

    The pipeline is loaded in float16 and gets an Euler-ancestral scheduler
    derived from its original scheduler config. Its safety checker is
    disabled, and it is moved to the CUDA device.

    Returns:
        The configured ``StableDiffusionPipeline``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )

    # Reuse the pipeline's own scheduler settings, swapping only the algorithm.
    base_scheduler_config = pipe.scheduler.config
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        base_scheduler_config
    )

    pipe.safety_checker = None

    cuda = torch.device("cuda")
    pipe.to(cuda)
    return pipe

model = load_model()

# Compile with stable-fast's default settings, plus CUDA graph capture.
config = CompilationConfig.Default()
config.enable_cuda_graph = True
model = compile(model, config)

kwarg_inputs = {
    "prompt": "(masterpiece:1,2), best quality, masterpiece, best detail face, a beautiful girl",
    "num_inference_steps": 30,
    "num_images_per_prompt": 1,
    "height": 512,
    "width": 512,
}

# Each call receives a freshly seeded CUDA generator. On the affected
# stable-fast version, the second call (seed 41) raised the AssertionError
# shown in the traceback above.
for seed in (42, 41):
    model(**kwarg_inputs, generator=torch.Generator(device="cuda").manual_seed(seed))
chengzeyi commented 10 months ago

@lackofdream Just fixed in the latest `main`.