siliconflow / onediff

OneDiff: An out-of-the-box acceleration library for diffusion models.
https://github.com/siliconflow/onediff/wiki
Apache License 2.0
1.4k stars, 85 forks (source link)

Compiled VAE encoder returns NaN #971

Status: Open. Opened by jamarju 1 week ago.

jamarju commented 1 week ago

Describe the bug

A clear and concise description of what the bug is.

When the madebyollin/sdxl-vae-fp16-fix VAE is compiled with OneDiff, the latent produced by its encoder is all NaN for some images. The uncompiled VAE encodes the same image without NaNs. See the reproducer below.

Your environment

OS

Linux

OneDiff git commit id

>>> print(onediff.__version__)
1.1.0.dev202405210127

OneFlow version info

path: ['/mnt/localnvme/home/javiermartin/miniconda3/envs/fc-ai-sdxl-inpainting-server/lib/python3.10/site-packages/oneflow']
version: 0.9.1.dev20240515+cu121
git_commit: ec7b682
cmake_build_type: Release
rdma: True
mlir: True
enterprise: False

How To Reproduce

import base64
import io

from diffusers import AutoencoderKL
from onediff.infer_compiler import oneflow_compile
from PIL import Image
import numpy as np
import torch

img_b64 = """
UklGRgYNAABXRUJQVlA4WAoAAAAQAAAAPwUA/wIAQUxQSE8AAAABDzD/ERFCUSNJEQ7w75TfOOCGiuj/
BBT+53/+53/+53/+53/+53/+53/+53/+53/+5//rZEuSyn/8x3+HBf/zP//zP//zP//zP//zf98CAFZQ
OCCQDAAAMAgBnQEqQAUAAz5tNphJJCKioSA06JCADYlpbvFW1d3KAGSyVZ3ivxVQcXKK6LmZWDDxWZmL
0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8
VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZm
L0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO
8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZ
mL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvT
O8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xW
ZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYv
TO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7x
WZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmY
vTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7
xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZm
YvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M
7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZ
mYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9
M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvF
ZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi
9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8Vl8HMRvSeOKuVoO
RHOSeV90oOt3AYwT2R5jxgnsjzHjBPZHmPGCeyPMeME9keY8YJ7I8x4wT2R5jxgnsjzHjBPZHmPGCfMi
9M7xWZmL0zvFZlwBbhhUzi2DNCdoUBspEXCX6CF/XPlRs3yTC76uO0K6o2b5Jhd9XHaFdUbN8kwu+rjt
CuqNm+SYXfVx2hXVGzfJMLvq47QrqjZvkmF31cdoV1Rs34U7xWZmL0zvFZmYvKCjI4avJYBP02Zr2bof
o+KEDoYJ7I8x4wT2R5jxgnsjzHjBPZHmPGCeyPMeME9keY8YJ7I8x4wT2R5jxgnsjzHjBPZHmPGCfMi9
M7xWZmL0zvFZlwBl2GRTLvh9Wwzhh0LgDhY4BEm4k1cz5Jd+17RePzktXM+SXfte0Xj85LVzPkl37XsB
c4zMvad4rMzF6Z0g3Rm29ZWXrqfXyTC76uO0K6o2b5Jhd9XHaFdUbN8kwu+rjtCuqNm+SYXfVx2hXVGz
fJMLvq47QrqjZvkmF31cdoV1Rs3yTC76uPVPxWZmL0zvFZmYvScTs/KcrCzk1tquQdq/uc7Yra79r2i8
fnJauZ8ku/a9ovH5yWrmfJLv2vaLx+clpAJl7TvFZmYvTO8VhlYHNXl9oV1Rs3yTC76uO0K6o2b5Jhd9
XHaFdUbN8kwu+rjtCuqNm+SYXfVx2hXVGzfJMLvq47QrqjZvkmF31cdoV1Rs34U7xWZmL0zvFZmYvTO8
VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZm
L0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO
8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZ
mL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvT
O8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xW
ZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYv
TO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7x
WZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7xWZmL0zvFZmY
vTO8VmZi9M7xWZmL0zvFZmYvTO8VmZi9M7woAAD+/+5ygAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAH7
+/IkVrupmCQZP7skl8J+IXsdGJ8ZMZnWRIlkhO4mPD3XeN79WL2JA0ShB3ObLReiskfzVXKCG3z68UY4
RvJInE6DYimyGafY+NxSckR8odVM2/Gxo578/DXIdoeQ7Q8h2h5DtDyHaHkO0PIdoeQ7Q8h2h5DtDyHa
HkO0PIdoeQ7Q8h2h5DtDyHaHkO0PIdoeQ7Q8h2h5DtDyHaHkO0PIdoeQ6MjFrfyI2YiwdUJEEp0oqEIE
LR4jM0Uy4OyS71WW/br3e/E2ym02gnn1TIyHp9+na3+5zmREU4EO1FT3quXGVJkOrrKzPwD7c6JrFUDv
cSPv+qWvENfrUKaRdztzkTvzBcaeXLZs4ujP2+XaOR/xCr1jXzWX42xZ+FmR+DE9b/V6cUafM/2Lf3tm
cV4D8HCzbuAAAAAAAAAaqPDJHBFNSOgapE6jHrN3AvF1hR9q7yS/l1V3ZG0f9e1HQQlTPqXA/bL90lB/
1yziK6sN8qNxwk0u/pLsvIlOFCAqxq/Ssch4eRfUYzPp3FJGmv+l624bbgKszv8pFJDyYH5FwieqWvw9
FhDv5DqTH9S87fhIdvaUJp0f7tdezt4OwTYezt4OwTYezt4OwTYezt4OwTYezt4OwTYezt4OwTYezt4O
wTYezt4OwTYezt4OwTYezt4OwTYezt4OwTYezt4OwTYezt4Oo52DNK2p5NOHw97HqhAUpu2B4CR86L/0
apjrDy5rSQ7AoaGMy1GSW1TSU6DDrzhMnmclrnUde2fJ7c1NEFiG+OMKa0VEXi7haVoAhB1XJA7cbP2i
ycYIkSGCIL6Mo5IzYk3d+2Qnul45XtKCsrCSoZ9bLH80meVb+/keINi2r0gv3t0r7o3CQxR9sqD3tVe9
dPklXu9dPklXu9dPklXu9dPklXu9dPklXu9dPklXu9dPklXu9dPklXu9dPklXu9dPklXu9dPklXu9dPk
lXopuevvKDbrigv6DdRP2yWfJup0Z6nt/ieDCd9mhwR7HvUldQiBwtOfB87RreLvecsUEg6pviyLUu0R
0t+BXDdzhSm1h0rYUJJO2EZaActxA+Ip/0piTo8XkWOW4gfEU/6UxJ0eLyLHLcQPgigaX062kXMJD93K
Erxf5uD2aWbntTNJBjNQCfp7FxJkkQkWCW3QISo6IUk1R0DSnatm/Mm9Q7X6MAgQ+8N7eu7B3oJm2tra
2tra2tra2tra2tra2tra2tra2tra2tra2tra2tra2tra2tra2tra2tra2tra2tra2QgAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
"""

# Load the madebyollin/sdxl-vae-fp16-fix VAE in float16 and move it to CUDA.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
vae.to("cuda")

# Decode the embedded test image. PIL's Image.open is lazy, so force the pixel
# data to be read while the in-memory buffer is still open; otherwise decoding
# may be attempted later against a closed BytesIO.
with io.BytesIO(base64.b64decode(img_b64)) as f:
    image = Image.open(f)
    image.load()

# Flatten any transparency onto a white background. Image.alpha_composite
# requires both inputs to be RGBA, so normalise the source mode first
# (a no-op copy when the image is already RGBA).
wht = Image.new("RGBA", image.size, (255, 255, 255, 255))
image = Image.alpha_composite(wht, image.convert("RGBA")).convert("RGB")

with torch.no_grad():
    # HWC uint8 in [0, 255] -> NCHW float scaled to [-1, 1].
    vae_in = torch.from_numpy(np.array(image).transpose(2, 0, 1) / 255.0).unsqueeze(0) * 2 - 1
    vae_in = vae_in.to(device=vae.device, dtype=vae.dtype)

    # Baseline: encode with the eager (uncompiled) modules.
    vae_out = vae.encode(vae_in).latent_dist.sample()
    # `not` on the reduced bool tensor is an explicit logical negation; the
    # previous `~` only worked because the operand happened to be a tensor.
    assert not torch.isnan(vae_out).any(), "NaN in latent"

    # Replace encoder/decoder with OneFlow-compiled versions and encode the
    # same input again.
    vae.decoder = oneflow_compile(vae.decoder)
    vae.encoder = oneflow_compile(vae.encoder)

    vae_out = vae.encode(vae_in).latent_dist.sample()
    assert not torch.isnan(vae_out).any(), "NaN in latent (after compile)"

The complete error message

AssertionError: NaN in latent (after compile)

Additional context

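Expected behavior: the script runs to completion without either assertion failing, i.e. the compiled VAE produces a NaN-free latent, just like the uncompiled one.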