siliconflow / onediff

OneDiff: An out-of-the-box acceleration library for diffusion models.
https://github.com/siliconflow/onediff/wiki
Apache License 2.0

[Bug] compile model with nexfort backend on sm75 device #1060

Closed zhangvia closed 3 months ago

zhangvia commented 3 months ago

Your current environment information

PyTorch version: 2.3.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OneFlow version: path: ['/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/oneflow'], version: 0.9.1.dev20240724+cu118, git_commit: f230775, cmake_build_type: Release, rdma: True, mlir: True, enterprise: False
Nexfort version: 0.1.dev258
OneDiff version: 1.2.0.dev1
OneDiffX version: 1.2.0.dev1

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.30.1
Libc version: glibc-2.31

Python version: 3.10.0 (default, Mar  3 2022, 09:58:08) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.4.0-197-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 2080 Ti

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          32
On-line CPU(s) list:             0-31
Thread(s) per core:              2
Core(s) per socket:              8
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           79
Model name:                      Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz
Stepping:                        1
CPU MHz:                         2325.914
CPU max MHz:                     3000.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        4200.02
Virtualization:                  VT-x
L1d cache:                       512 KiB
L1i cache:                       512 KiB
L2 cache:                        4 MiB
L3 cache:                        40 MiB
NUMA node0 CPU(s):               0-7,16-23
NUMA node1 CPU(s):               8-15,24-31
Vulnerability Itlb multihit:     KVM: Mitigation: Split huge pages
Vulnerability L1tf:              Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:               Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Full generic retpoline, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Mitigation; Clear CPU buffers; SMT vulnerable
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb invpcid_single intel_pt ssbd ibrs ibpb stibp kaiser tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdseed adx smap xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] diffusers==0.25.1
[pip3] numpy==1.26.4
[pip3] onnx==1.15.0
[pip3] onnxruntime-gpu==1.17.1
[pip3] onnxsim==0.4.36
[pip3] open-clip-torch==2.24.0
[pip3] pytorch-lightning==2.2.1
[pip3] torch==2.3.0+cu118
[pip3] torchaudio==2.3.0+cu118
[pip3] torchmetrics==1.3.2
[pip3] torchvision==0.18.0+cu118
[pip3] transformers==4.41.2
[pip3] triton==2.3.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] open-clip-torch           2.24.0                   pypi_0    pypi
[conda] pytorch-lightning         2.2.1                    pypi_0    pypi
[conda] torch                     2.3.0+cu118              pypi_0    pypi
[conda] torchaudio                2.3.0+cu118              pypi_0    pypi
[conda] torchmetrics              1.3.2                    pypi_0    pypi
[conda] torchvision               0.18.0+cu118             pypi_0    pypi
[conda] triton                    2.3.0                    pypi_0    pypi

🐛 Describe the bug

File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py", line 1442, in __call__
    down_block_res_samples, mid_block_res_sample = self.controlnet(
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/media/74nvme/research/onediff/src/onediff/infer_compiler/backends/nexfort/deployable_module.py", line 27, in forward
    return self._deployable_module_model(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/diffusers/pipelines/controlnet/multicontrolnet.py", line 32, in forward
    def forward(
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 917, in forward
    return compiled_fn(full_args)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 89, in g
    return f(*args)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 106, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 152, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 906, in __call__
    return self.get_current_callable()(inputs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 784, in run
    return model(new_inputs)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_root/6h/c6hzlvjgjp5flb7mbkmhx6hygmcbqsk34uthznwrlzacmr5imt7c.py", line 6131, in call
    buf47 = torch.ops.nexfort_cuda.cuda_timestep_embedding.default(buf46, 2, 160, 0, 160, torch.float32, device(type='cuda', index=0), False, -9.210340371976184, 160.0, True)
  File "/media/74nvme/software/miniconda3/envs/stable-fast/lib/python3.10/site-packages/torch/_ops.py", line 594, in __call__
    return self_._op(*args, **kwargs)
RuntimeError: Expected err == cudaSuccess to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Reproduce

import torch
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel
from onediff.infer_compiler import oneflow_compile, compile
from diffusers.utils import load_image
from os.path import join as opj
import cv2
import numpy as np
from PIL import Image

model_path = "/media/74nvme/research/models/"
model_path_base = "/media/74nvme/research/models"
canny_controlnet_path = opj(model_path_base, "controlnet/control_v11p_sd15_canny")

# cn_path = [canny_controlnet_path, depth_controlnet_path, hand_controlnet_path, tile_controlnet_path]
cn_path = [canny_controlnet_path]
controlnets = []

for path in cn_path:
    controlnets.append(ControlNetModel.from_pretrained(path, torch_dtype=torch.float16))

image = Image.open("/media/74nvme/research/test_img/805565039-0.png").resize((512, 768))
sd_mask_image = Image.open("/media/74nvme/research/test_img/805565039-0-mask.jpg").resize((512, 768))
controlnet_conditioning_image = [Image.fromarray(cv2.Canny(np.array(image), 100, 200))]
ip_adapter_image = None

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    model_path, controlnet=controlnets, torch_dtype=torch.float16
).to("cuda")
# pipe.load_ip_adapter(opj(model_path_base, "controlnet/ip_adapter/"), subfolder="./", weight_name="ip-adapter_sd15.bin")
# pipe.set_ip_adapter_scale(0.9)
pipe.unet = compile(pipe.unet, backend="nexfort")
# pipe.text_encoder = oneflow_compile(pipe.text_encoder)
# pipe.controlnet = oneflow_compile(pipe.controlnet)
pipe.controlnet = compile(pipe.controlnet, backend="nexfort")

# pipe.set_ip_adapter_scale(0)
pipe(
    prompt="a girl",
    image=image.resize((1024, 744)),
    mask_image=sd_mask_image.resize((1024, 744)),
    control_image=[img.resize((1024, 744)) for img in controlnet_conditioning_image],
    ip_adapter_image=ip_adapter_image.resize((1024, 744)) if ip_adapter_image is not None else None,
    output_type="pil",
    controlnet_conditioning_scale=[0.5],
)
# image.save("./res2_nocompile.jpg")

The code above runs fine on an RTX 4090. Does nexfort not support sm75?

strint commented 3 months ago

Please have a try:

zhangvia commented 3 months ago

Thanks, it worked.

strint commented 3 months ago

@lixiang007666

lixiang007666 commented 3 months ago

As you found, the issue appears on torch 2.3.0 but does not occur on the nightly build. However, this bug has been fixed in the latest nexfort release, so nexfort can now be used on torch 2.3.0 as well.

Additionally, the NEXFORT_FUSE_TIMESTEP_EMBEDDING optimization is enabled by default. If the fused timestep_embedding optimization fails on your device, you can try turning this switch off.
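For reference, a minimal sketch of turning the switch off via the environment variable named above. The value `"0"` used to disable it is an assumption; check the nexfort documentation for the exact accepted values. The variable must be set before nexfort compiles anything:

```python
import os

# Assumption: setting NEXFORT_FUSE_TIMESTEP_EMBEDDING to "0" disables the
# fused timestep_embedding CUDA kernel (the exact accepted value may differ;
# consult the nexfort docs). This must run before any compile() call with
# backend="nexfort" so the setting takes effect.
os.environ["NEXFORT_FUSE_TIMESTEP_EMBEDDING"] = "0"
```

Alternatively, the same setting can be applied from the shell when launching the script, e.g. `NEXFORT_FUSE_TIMESTEP_EMBEDDING=0 python reproduce.py`.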

@zhangvia @strint

zhangvia commented 3 months ago

I tried 0.1.dev260 and still hit the error. After setting the environment variable, the error was gone.