chengzeyi / stable-fast

Best inference performance optimization framework for HuggingFace Diffusers on NVIDIA GPUs.
MIT License
1.05k stars · 59 forks

Support for DeepCache #110

Open HoiM opened 5 months ago

HoiM commented 5 months ago

DeepCache is an optimization technique that reduces the computation in the UNet denoising loop. I tested it with Stable Video Diffusion, and DeepCache alone is faster than stable-fast alone. It would be great if the DeepCache-simplified UNet loop could be further optimized by stable-fast.
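Conceptually (a toy sketch with hypothetical stand-in functions, not DeepCache's real code), DeepCache runs the expensive deep UNet blocks only every few denoising steps and reuses their cached output in between, so most steps execute only a cheap shallow branch:

```python
import torch

def deep_blocks(x):                 # hypothetical stand-in for the costly down/mid/up blocks
    return x * 2.0

def shallow_blocks(x, deep_feats):  # hypothetical stand-in for the cheap outermost branch
    return x + deep_feats

cache_interval, cached = 3, None
x = torch.randn(1, 4, 64, 64)
for step in range(25):
    if step % cache_interval == 0:
        cached = deep_blocks(x)      # full pass: refresh the deep-feature cache
    x = shallow_blocks(x, cached)    # other steps: reuse the cached deep features
```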

So far I have run the following experiments:

1. Set up DeepCache first, then set up stable-fast (a sketch of this setup follows the log below). This throws an error:

```
Loading pipeline components...: 100%|██████████| 5/5 [00:01<00:00,  3.45it/s]
/opt/conda/lib/python3.8/site-packages/sfast/jit/overrides.py:21: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return func(*args, **kwargs)
/opt/conda/lib/python3.8/site-packages/sfast/jit/overrides.py:21: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return func(*args, **kwargs)
/opt/conda/lib/python3.8/site-packages/sfast/jit/overrides.py:21: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  return func(*args, **kwargs)
  0%|          | 0/25 [00:00<?, ?it/s]/opt/conda/lib/python3.8/site-packages/sfast/jit/overrides.py:21: TracerWarning: Converting a tensor to a Python list might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return func(*args, **kwargs)
/opt/conda/lib/python3.8/site-packages/sfast/jit/overrides.py:21: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return func(*args, **kwargs)
/home/yuhaiming04/.local/lib/python3.8/site-packages/DeepCache/extension/deepcache.py:41: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  self.cur_timestep = list(self.pipe.scheduler.timesteps).index(args[1].item())
  0%|          | 0/25 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "run_sf.py", line 75, in <module>
    main()
  File "run_sf.py", line 34, in main
    output_frames = model(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py", line 499, in __call__
    noise_pred = self.unet(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/sfast/jit/trace_helper.py", line 51, in wrapper
    traced_m, call_helper = trace_with_kwargs(
  File "/opt/conda/lib/python3.8/site-packages/sfast/jit/trace_helper.py", line 25, in trace_with_kwargs
    traced_module = better_trace(TraceablePosArgOnlyModuleWrapper(func),
  File "/opt/conda/lib/python3.8/site-packages/sfast/jit/utils.py", line 32, in better_trace
    script_module = torch.jit.trace(func, *args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 798, in trace
    return trace_module(
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 1065, in trace_module
    module._c._create_method_from_trace(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/sfast/jit/trace_helper.py", line 154, in forward
    outputs = self.module(*orig_args, **orig_kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/sfast/jit/trace_helper.py", line 89, in forward
    return self.func(*args, **kwargs)
  File "/home/yuhaiming04/.local/lib/python3.8/site-packages/DeepCache/extension/deepcache.py", line 42, in wrapped_forward
    result = self.function_dict['unet_forward'](*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/unet_spatio_temporal_condition.py", line 409, in forward
    emb = self.time_embedding(t_emb)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/embeddings.py", line 226, in forward
    sample = self.linear_1(sample)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/lora.py", line 430, in forward
    out = super().forward(hidden_states)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.8/site-packages/sfast/jit/overrides.py", line 21, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
Columns 1 to 10 0.0186  0.0142  0.0114  0.0081  0.0079  0.0087  0.0102  0.0079  0.0027  0.0082
-0.0037 -0.0025 -0.0031 -0.0030  0.0040 -0.0062 -0.0000 -0.0022  0.0049  0.0006
...
```

(The error goes on to print many more tensor values; I have omitted the rest.)
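For reference, this is roughly what my method-1 setup looks like. It is a sketch using the DeepCacheSDHelper API and an SD v1.5 pipeline for brevity; my actual failing run used Stable Video Diffusion and DeepCache's extension module, as the traceback shows, and stable-fast's import path may differ across versions:

```python
import torch
from diffusers import StableDiffusionPipeline
from DeepCache import DeepCacheSDHelper
# In older stable-fast releases this module may be named
# sfast.compilers.stable_diffusion_pipeline_compiler instead.
from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Method 1: enable DeepCache first...
helper = DeepCacheSDHelper(pipe=pipe)
helper.set_params(cache_interval=3, cache_branch_id=0)
helper.enable()

# ...then compile with stable-fast. Tracing the DeepCache-wrapped UNet is
# what fails with the "Cannot insert a Tensor that requires grad" error above.
pipe = compile(pipe, CompilationConfig.Default())
```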

2. Set up stable-fast first, then set up DeepCache (the two steps in the sketch above, in the opposite order). This ran successfully, but the speed is the same as using stable-fast alone. However, I think method 1 is the correct order: first modify the computational graph (DeepCache), then trace and optimize it (stable-fast).

Appreciate your help!

chengzeyi commented 4 months ago

@HoiM I guess you should build a DeepCache pipeline manually. Try compiling the model first, then using DeepCacheSDHelper to enable DeepCache:

https://github.com/horseee/DeepCache/blob/master/main.py
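Something like this (an untested sketch; `pipe` is an already-constructed diffusers pipeline, and the parameter values are just examples):

```python
from DeepCache import DeepCacheSDHelper
from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

# Compile first, so stable-fast traces the plain, unmodified UNet...
pipe = compile(pipe, CompilationConfig.Default())

# ...then enable DeepCache on top of the compiled pipeline.
helper = DeepCacheSDHelper(pipe=pipe)
helper.set_params(cache_interval=3, cache_branch_id=0)
helper.enable()
```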

harrydrippin commented 4 months ago

I think Method 1, which @HoiM described, is the correct order too. I tested this on my setup with an SD v1.5 model, configuring DeepCache with branch_id=0 and interval=5 as a sanity check: settings this aggressive cause a dramatic loss of image quality whenever DeepCache is properly applied. But the speed was the same as using stable-fast alone, and image quality did not degrade.
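Concretely, the sanity-check settings I used (shown via DeepCacheSDHelper, with parameter names assumed from the DeepCache README; `pipe` is the compiled SD v1.5 pipeline):

```python
from DeepCache import DeepCacheSDHelper

# Aggressive settings: cache the shallowest branch and reuse the cached deep
# features for 4 out of every 5 steps, which should visibly degrade image
# quality whenever DeepCache is actually active.
helper = DeepCacheSDHelper(pipe=pipe)
helper.set_params(cache_interval=5, cache_branch_id=0)
helper.enable()
```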

This means DeepCache was NOT properly applied to the pipeline with Method 2, because the graph was compiled before DeepCache could modify it. It would be great if we could discuss a workaround for this. (Even better if @chengzeyi has one!)

Thank you for sharing this amazing work :)