intel / intel-extension-for-pytorch

A Python package for extending the official PyTorch that can easily obtain performance on Intel platform
Apache License 2.0

Corruption with Wuerstchen and Stable Cascade models #529

Open Disty0 opened 9 months ago

Disty0 commented 9 months ago

Describe the bug

Wuerstchen and the Wuerstchen-based Stable Cascade models generate corrupted images.

This might be related to my earlier corruption issue (https://github.com/intel/intel-extension-for-pytorch/issues/519), but this one happens at any resolution and also on GPU Max.

Example of the corruption:

Wuerstchen: https://huggingface.co/warp-ai/wuerstchen

[Image: corrupted Wuerstchen output]

from functools import wraps
import torch
import intel_extension_for_pytorch
from diffusers import AutoPipelineForText2Image

# XPU doesn't support bicubic interpolation, so bicubic, antialiased, or explicit
# align_corners calls are run on the CPU in float32 and moved back afterwards.
original_interpolate = torch.nn.functional.interpolate

@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
    if antialias or align_corners is not None or mode == 'bicubic':
        return_device = tensor.device
        return_dtype = tensor.dtype
        return original_interpolate(
            tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
            align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias,
        ).to(return_device, dtype=return_dtype)
    else:
        return original_interpolate(
            tensor, size=size, scale_factor=scale_factor, mode=mode,
            align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias,
        )

torch.nn.functional.interpolate = interpolate

device = "xpu"
dtype = torch.bfloat16

pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=dtype
).to(device)

caption = "Anthropomorphic cat dressed as a fire fighter"

output = pipeline(
    prompt=caption,
    height=1024,
    width=1024,
    prior_guidance_scale=4.0,
    decoder_guidance_scale=0.0,
).images
output[0].save("wuerstchen.jpg")

Stable Cascade: https://huggingface.co/stabilityai/stable-cascade

[Image: corrupted Stable Cascade output]

import torch
import intel_extension_for_pytorch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# Using diffusers from git+https://github.com/kashif/diffusers.git@wuerstchen-v3

device = "xpu"
num_images_per_prompt = 1

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to(device)

prompt = "Anthropomorphic cat dressed as a pilot"
negative_prompt = ""

prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=num_images_per_prompt,
    num_inference_steps=20
)
decoder_output = decoder(
    image_embeddings=prior_output.image_embeddings.half(),
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10
).images

decoder_output[0].save("stable-cascade.jpg")

Versions

Collecting environment information...
PyTorch version: 2.1.0a0+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.10+xpu
IPEX commit: a12f9f650
Build type: Release

OS: Arch Linux (x86_64)
GCC version: (GCC) 13.2.1 20230801
Clang version: 16.0.6
IGC version: 2024.0.0 (2024.0.0.20231017)
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.11.7 (main, Jan 29 2024, 16:03:57) [GCC 13.2.1 20230801] (64-bit runtime)
Python platform: Linux-6.7.4-arch1-1-x86_64-with-glibc2.39
Is XPU available: True
DPCPP runtime version: 2024.0
MKL version: 2024.0
GPU models and configuration: 
[0] _DeviceProperties(name='Intel(R) Arc(TM) A770 Graphics', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=0, total_memory=15473MB, max_compute_units=512, gpu_eu_count=512)
Intel OpenCL ICD version: N/A
Level Zero version: N/A

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             16
On-line CPU(s) list:                0-15
Vendor ID:                          AuthenticAMD
Model name:                         AMD Ryzen 7 5800X3D 8-Core Processor
CPU family:                         25
Model:                              33
Thread(s) per core:                 2
Core(s) per socket:                 8
Socket(s):                          1
Stepping:                           2
Frequency boost:                    enabled
CPU(s) scaling MHz:                 81%
CPU max MHz:                        4548.8281
CPU min MHz:                        2200.0000
BogoMIPS:                           6803.97
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
Virtualization:                     AMD-V
L1d cache:                          256 KiB (8 instances)
L1i cache:                          256 KiB (8 instances)
L2 cache:                           4 MiB (8 instances)
L3 cache:                           96 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] dctorch==0.1.2
[pip3] intel-extension-for-pytorch==2.1.10+xpu
[pip3] numpy==1.26.2
[pip3] open-clip-torch==2.24.0
[pip3] pytorch-lightning==1.9.4
[pip3] torch==2.1.0a0+cxx11.abi
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==1.3.0.post0
[pip3] torchsde==0.2.6
[pip3] torchvision==0.16.0a0+cxx11.abi
[conda] N/A
Disty0 commented 9 months ago

I did some more experimenting.

The prior model works fine with IPEX 2.0, but the decoder model fails. Both models fail with IPEX 2.1.

IPEX 2.0: (Prior Output Preview / Final Decoder Output)

[Image: IPEX 2.0 prior output preview and final decoder output]

IPEX 2.1: (Prior Output Preview / Final Decoder Output)

[Image: IPEX 2.1 prior output preview and final decoder output]

Here is the .ipynb file I tested. The results are the same on both Arc and GPU Max.

The previewer modules are taken from here: https://huggingface.co/spaces/multimodalart/stable-cascade/tree/main/previewer

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a186e03-0171-41b5-8d50-13cba3333e41",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pip install --force-reinstall torch==2.1.0a0 torchvision==0.16.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e254e5fc-8857-4528-89be-69813e177b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pip install --force-reinstall tensorboard==2.14.1 tensorflow==2.14.0 intel-extension-for-tensorflow[xpu]==2.14.0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e073708-9e78-4149-a441-7c46bf10b67a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#pip install git+https://github.com/kashif/diffusers.git@wuerstchen-v3 accelerate transformers typing_extensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1873010a-0a0c-4c1c-bc08-02c87f3f39d1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import intel_extension_for_pytorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af18ccef-8328-4383-839f-10df6d2d73e0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline\n",
    "prior = StableCascadePriorPipeline.from_pretrained(\"stabilityai/stable-cascade-prior\", torch_dtype=torch.bfloat16).to(\"xpu\")\n",
    "decoder = StableCascadeDecoderPipeline.from_pretrained(\"stabilityai/stable-cascade\",  torch_dtype=torch.bfloat16).to(\"xpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38fd9e08-3ad3-4efe-b271-c70865517373",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "from diffusers.utils import numpy_to_pil\n",
    "from previewer import Previewer\n",
    "previewer = Previewer()\n",
    "previewer_state_dict = torch.load(\"previewer_v1_100k.pt\", map_location=torch.device('cpu'))[\"state_dict\"]\n",
    "previewer.load_state_dict(previewer_state_dict)\n",
    "previewer = previewer.eval().requires_grad_(False).to(\"xpu\", dtype=torch.bfloat16)\n",
    "def callback_prior(i, t, latents):\n",
    "    output = previewer(latents)\n",
    "    output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())\n",
    "    clear_output()\n",
    "    display(output[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce30cca-6dd0-48bd-8dba-4ef30c94c30f",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_images_per_prompt = 1\n",
    "callback_steps = 1\n",
    "prompt = \"Anthropomorphic cat dressed as a pilot\"\n",
    "negative_prompt = \"\"\n",
    "\n",
    "torch.xpu.empty_cache()\n",
    "prior_output = prior(\n",
    "    prompt=prompt,\n",
    "    height=1024,\n",
    "    width=1024,\n",
    "    negative_prompt=negative_prompt,\n",
    "    guidance_scale=4.0,\n",
    "    num_images_per_prompt=num_images_per_prompt,\n",
    "    num_inference_steps=20,\n",
    "    callback=callback_prior,\n",
    "    callback_steps=callback_steps,\n",
    ")\n",
    "torch.xpu.empty_cache()\n",
    "decoder_output = decoder(\n",
    "    image_embeddings=prior_output.image_embeddings,\n",
    "    prompt=prompt,\n",
    "    negative_prompt=negative_prompt,\n",
    "    guidance_scale=0.0,\n",
    "    output_type=\"pil\",\n",
    "    num_inference_steps=10\n",
    ").images\n",
    "torch.xpu.empty_cache()\n",
    "\n",
    "display(decoder_output[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
alexsin368 commented 9 months ago

@Disty0 I will try reproducing your issue on Arc.

qiacheng commented 8 months ago

I'm seeing the same corruption on an A770.

Disty0 commented 7 months ago

This one still happens with IPEX 2.1.20+xpu.

alexsin368 commented 7 months ago

@Disty0 Let's focus on Wuerstchen first, since I have it running on my setup with an Arc A770 and the issue should be similar for Stable Cascade. I had to use "warp-diffusion/wuerstchen" as the model repo instead of the warp-ai one, but I get the image corruption on both IPEX v2.1.10+xpu and v2.0.120+xpu.

Can you show me the image you get without corruption on v2.0.120+xpu? We should also note the resolution of the images. It seems the outputs are 1024x1024.

Disty0 commented 7 months ago

Can you show me the image you get without corruption on v2.0.120+xpu?

The final image is corrupted on all of them. v2.0.120+xpu generates the intermediate prior-stage latents without corruption, and the corruption appears at the decoder stage. v2.1.10+xpu is corrupted at all stages.

Latent previewer for the prior stage: https://huggingface.co/spaces/multimodalart/stable-cascade/tree/main/previewer

Also, this CPU fallback patch for torch.nn.functional.interpolate is needed for Wuerstchen, since XPU doesn't support bicubic interpolation. Stable Cascade does not need it.

from functools import wraps
import torch

original_interpolate = torch.nn.functional.interpolate

@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
    if antialias or align_corners is not None or mode == 'bicubic':
        return_device = tensor.device
        return_dtype = tensor.dtype
        return original_interpolate(
            tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
            align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias,
        ).to(return_device, dtype=return_dtype)
    else:
        return original_interpolate(
            tensor, size=size, scale_factor=scale_factor, mode=mode,
            align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias,
        )

torch.nn.functional.interpolate = interpolate
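If the fallback should only apply while a pipeline runs, a scoped variant avoids leaving torch.nn.functional.interpolate replaced for the rest of the process. This is a minimal sketch, not part of IPEX or diffusers: cpu_interpolate_fallback is a hypothetical name, and it reuses the same fallback condition as the patch above.

```python
import contextlib
import functools

import torch
import torch.nn.functional as F


@contextlib.contextmanager
def cpu_interpolate_fallback():
    """Hypothetical helper: temporarily route some interpolate calls via CPU float32.

    Uses the same condition as the global patch: antialias, an explicit
    align_corners, or bicubic mode triggers the CPU fallback.
    """
    original = F.interpolate

    @functools.wraps(original)
    def patched(input, size=None, scale_factor=None, mode="nearest",
                align_corners=None, recompute_scale_factor=None, antialias=False):
        kwargs = dict(size=size, scale_factor=scale_factor, mode=mode,
                      align_corners=align_corners,
                      recompute_scale_factor=recompute_scale_factor,
                      antialias=antialias)
        if antialias or align_corners is not None or mode == "bicubic":
            # Compute on CPU in float32, then restore the caller's device and dtype.
            out = original(input.to("cpu", dtype=torch.float32), **kwargs)
            return out.to(input.device, dtype=input.dtype)
        return original(input, **kwargs)

    F.interpolate = patched
    try:
        yield
    finally:
        # Always restore the original function, even if the pipeline raises.
        F.interpolate = original
```

Usage would look like "with cpu_interpolate_fallback(): images = pipeline(...).images". Like the global patch, it only affects code that looks up interpolate on torch.nn.functional at call time, not code that imported the function by name beforehand.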

We should also note the resolution of the images. It seems the outputs are 1024x1024.

This happens at any resolution, so I just used the default one for the report. The same issue occurs with 768x768, 768x1280, 1024x1024, 1280x1280, 1024x1536, etc.

alexsin368 commented 7 months ago

I also tried this on CPU and got a clean HD image. I will focus debugging on the prior-stage latent outputs and compare the CPU and GPU values.
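One way to do a CPU vs GPU comparison at the module level is to run deep copies of the same submodule on each device and measure the largest deviation from a float32 CPU reference. This is a minimal sketch: compare_on_devices is a hypothetical helper, it assumes floating-point inputs and a module that returns a single tensor, and passing device="xpu" assumes intel_extension_for_pytorch has been imported.

```python
import copy

import torch


def compare_on_devices(module, inputs, device="cpu", ref_device="cpu",
                       dtype=torch.bfloat16):
    """Hypothetical helper: compare a module on `device` against a CPU fp32 reference.

    Returns (max_abs_diff, has_nan). max_abs_diff is NaN if the test run
    produced NaN, so check has_nan first.
    """
    # Separate copies so neither run mutates the caller's module or the other copy.
    ref_module = copy.deepcopy(module).to(ref_device, dtype=torch.float32).eval()
    test_module = copy.deepcopy(module).to(device, dtype=dtype).eval()
    with torch.no_grad():
        ref_out = ref_module(*[x.to(ref_device, dtype=torch.float32) for x in inputs])
        test_out = test_module(*[x.to(device, dtype=dtype) for x in inputs])
    # Bring the test output back to the reference device and precision.
    test_out = test_out.to(ref_device, dtype=torch.float32)
    has_nan = bool(torch.isnan(test_out).any())
    max_diff = (ref_out - test_out).abs().max().item()
    return max_diff, has_nan
```

A bfloat16 run always differs slightly from float32, so it helps to also run compare_on_devices(block, [x], device="cpu") as a baseline; a NaN, or a deviation far above that CPU bfloat16 baseline, is what points at a faulty kernel. Deep-copying doubles memory, so this suits individual blocks rather than a whole pipeline.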

alexsin368 commented 6 months ago

I'm working with the team to find a simpler reproducer to identify which ops are causing the corruption. I also tried this on CPU and got an HD image from it, so the problem is specific to the GPU.

alexsin368 commented 6 months ago

I was able to narrow down the issue further. As the latents run through the denoising loop, they become NaN at random. Once they contain NaN, every subsequent denoising iteration produces NaN latents. I confirmed that this occurs only on the GPU, not on the CPU.

Here's the denoising loop for Wuerstchen, for reference: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py#L375

I will continue digging into which operation is causing the NaN.
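The step at which the latents first turn into NaN can be recorded with the same callback(i, t, latents) hook that the notebook above passes to the prior pipeline. This is a minimal sketch: NaNWatcher is a hypothetical helper, and it assumes the legacy callback / callback_steps arguments of the diffusers branch used in this thread (newer diffusers releases expose callback_on_step_end instead).

```python
import torch


class NaNWatcher:
    """Hypothetical callback for the legacy diffusers callback(i, t, latents) hook.

    Records the max |latent| for every finite step and the first step whose
    latents contain NaN or Inf.
    """

    def __init__(self, raise_on_nan=False):
        self.raise_on_nan = raise_on_nan
        self.first_bad_step = None
        self.history = []  # list of (step, max |latent|) for finite steps

    def __call__(self, i, t, latents):
        if bool(torch.isfinite(latents).all()):
            self.history.append((i, latents.abs().max().item()))
        elif self.first_bad_step is None:
            self.first_bad_step = i
            if self.raise_on_nan:
                # Stop the pipeline early instead of finishing with NaN latents.
                raise FloatingPointError(f"non-finite latents at step {i} (t={t})")
```

For example, watcher = NaNWatcher() followed by prior(..., callback=watcher, callback_steps=1). Running the same seed on CPU and XPU and comparing watcher.history shows whether the latents drift before the NaN appears or jump to NaN in a single step.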

alexsin368 commented 6 months ago

I further narrowed the NaN down to the self._up_decode function of WuerstchenDiffNeXt, inside the if branch that handles ResBlockStageB. Note that this branch calls torch.nn.functional.interpolate, which was patched as shown at the top of this ticket because XPU does not support bicubic interpolation.

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py#L191C17-L205C39
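To narrow a NaN down to a specific module, one approach is to register a forward hook on every submodule and record which one first emits NaN. This is a minimal sketch: find_first_nan_module is a hypothetical helper, not a diffusers or IPEX API.

```python
import torch
import torch.nn as nn


def find_first_nan_module(model, *inputs):
    """Hypothetical helper: name of the first submodule whose output contains NaN.

    Forward hooks fire when a module's forward returns, so children are
    recorded before their parents. Returns None if no output contains NaN.
    """
    offenders = []
    handles = []

    def make_hook(name):
        def hook(module, args, output):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outputs:
                if torch.is_tensor(out) and out.is_floating_point() and bool(torch.isnan(out).any()):
                    offenders.append(name)
                    break
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Remove the hooks so later runs are unaffected.
        for handle in handles:
            handle.remove()
    return offenders[0] if offenders else None
```

Hooks only see nn.Module outputs. A NaN produced by a functional call inside a module's own forward, such as the torch.nn.functional.interpolate call in _up_decode, is reported at the first child module that consumes it, so the op itself still has to be checked in isolation. Because the NaNs appear at random, several runs may be needed.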

simonlui commented 2 months ago

I know there hasn't been activity on this for a while, and I'm not sure whether it got sidelined, but the corruption still occurs in the most recent IPEX release, v2.1.40+xpu.

Disty0 commented 2 months ago

This still happens on IPEX 2.3.

Disty0 commented 3 weeks ago

Stable Cascade works fine on PyTorch 2.5 XPU from PyTorch test branch.

Disty0 commented 2 weeks ago

Stable Cascade works fine on PyTorch 2.5 XPU from PyTorch test branch.

But I'm still getting random NaNs that don't happen on the CPU or on other vendors' GPUs.

srinarayan-srikanthan commented 2 weeks ago

The issue has been isolated to a specific operator, and we have zeroed in on the cause. We will have a fix soon.