Disty0 opened this issue 9 months ago (status: Open)
Did some more experimenting.
The prior model works fine with IPEX 2.0, but the decoder model fails. Both models fail with IPEX 2.1.
IPEX 2.0: (Prior Output Preview / Final Decoder Output)
IPEX 2.1: (Prior Output Preview / Final Decoder Output)
Here is the ipynb file I tested; same results on both Arc and GPU Max.
Preview modules are taken from here: https://huggingface.co/spaces/multimodalart/stable-cascade/tree/main/previewer
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0a186e03-0171-41b5-8d50-13cba3333e41",
"metadata": {},
"outputs": [],
"source": [
"#pip install --force-reinstall torch==2.1.0a0 torchvision==0.16.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e254e5fc-8857-4528-89be-69813e177b36",
"metadata": {},
"outputs": [],
"source": [
"#pip install --force-reinstall tensorboard==2.14.1 tensorflow==2.14.0 intel-extension-for-tensorflow[xpu]==2.14.0.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e073708-9e78-4149-a441-7c46bf10b67a",
"metadata": {},
"outputs": [],
"source": [
"#pip install git+https://github.com/kashif/diffusers.git@wuerstchen-v3 accelerate transformers typing_extensions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1873010a-0a0c-4c1c-bc08-02c87f3f39d1",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import torch\n",
"import intel_extension_for_pytorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af18ccef-8328-4383-839f-10df6d2d73e0",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline\n",
"prior = StableCascadePriorPipeline.from_pretrained(\"stabilityai/stable-cascade-prior\", torch_dtype=torch.bfloat16).to(\"xpu\")\n",
"decoder = StableCascadeDecoderPipeline.from_pretrained(\"stabilityai/stable-cascade\", torch_dtype=torch.bfloat16).to(\"xpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38fd9e08-3ad3-4efe-b271-c70865517373",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"from diffusers.utils import numpy_to_pil\n",
"from previewer import Previewer\n",
"previewer = Previewer()\n",
"previewer_state_dict = torch.load(\"previewer_v1_100k.pt\", map_location=torch.device('cpu'))[\"state_dict\"]\n",
"previewer.load_state_dict(previewer_state_dict)\n",
"previewer = previewer.eval().requires_grad_(False).to(\"xpu\", dtype=torch.bfloat16)\n",
"def callback_prior(i, t, latents):\n",
" output = previewer(latents)\n",
" output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())\n",
" clear_output()\n",
" display(output[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ce30cca-6dd0-48bd-8dba-4ef30c94c30f",
"metadata": {},
"outputs": [],
"source": [
"num_images_per_prompt = 1\n",
"callback_steps = 1\n",
"prompt = \"Anthropomorphic cat dressed as a pilot\"\n",
"negative_prompt = \"\"\n",
"\n",
"torch.xpu.empty_cache()\n",
"prior_output = prior(\n",
" prompt=prompt,\n",
" height=1024,\n",
" width=1024,\n",
" negative_prompt=negative_prompt,\n",
" guidance_scale=4.0,\n",
" num_images_per_prompt=num_images_per_prompt,\n",
" num_inference_steps=20,\n",
" callback=callback_prior,\n",
" callback_steps=callback_steps,\n",
")\n",
"torch.xpu.empty_cache()\n",
"decoder_output = decoder(\n",
" image_embeddings=prior_output.image_embeddings,\n",
" prompt=prompt,\n",
" negative_prompt=negative_prompt,\n",
" guidance_scale=0.0,\n",
" output_type=\"pil\",\n",
" num_inference_steps=10\n",
").images\n",
"torch.xpu.empty_cache()\n",
"\n",
"display(decoder_output[0])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@Disty0 I will try reproducing your issue on Arc
Seeing the same corruption on an A770.
This one still happens with IPEX 2.1.20+xpu.
@Disty0 let's focus on Wuerstchen first, since I have it set up on an Arc A770 and the issue should be similar for Stable Cascade. I did have to use "warp-diffusion/wuerstchen" as the model card instead of the warp-ai one, but I get the image corruption on both IPEX v2.1.10+xpu and v2.0.120+xpu.
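For reference, a minimal sketch of such a reproducer (assuming diffusers' AutoPipelineForText2Image resolves warp-diffusion/wuerstchen to the combined pipeline, and with the torch.nn.functional.interpolate CPU-fallback patch shown later in this thread applied first, since XPU lacks bicubic support):

import torch
import intel_extension_for_pytorch  # registers the xpu device
from diffusers import AutoPipelineForText2Image

# "warp-diffusion/wuerstchen" is used here because the warp-ai card did not resolve on this setup
pipe = AutoPipelineForText2Image.from_pretrained(
    "warp-diffusion/wuerstchen", torch_dtype=torch.bfloat16
).to("xpu")

image = pipe(
    prompt="Anthropomorphic cat dressed as a pilot",
    height=1024,
    width=1024,
).images[0]
image.save("wuerstchen_xpu.png")  # corrupted on both v2.0.120+xpu and v2.1.10+xpu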
Can you show me the image you get without corruption on v2.0.120+xpu? We should also note the resolution of the images. It seems the outputs are 1024x1024.
> Can you show me the image you get without corruption on v2.0.120+xpu?
The final image is corrupted on all of them. v2.0.120+xpu generates the intermediate prior-stage latent outputs without corruption but gets corrupted at the decoder stage. v2.1.10+xpu is corrupted at every stage.
Latent previewer for prior stage: https://huggingface.co/spaces/multimodalart/stable-cascade/tree/main/previewer
Also, this CPU-fallback patch to torch.nn.functional.interpolate is needed for Wuerstchen, since XPU doesn't support bicubic interpolation. Stable Cascade does not need the patch.
import torch
from functools import wraps

original_interpolate = torch.nn.functional.interpolate

@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None,
                recompute_scale_factor=None, antialias=False):
    # Fall back to CPU float32 for cases XPU can't handle (bicubic, antialias, explicit
    # align_corners), then move the result back to the original device and dtype.
    if antialias or align_corners is not None or mode == 'bicubic':
        return_device, return_dtype = tensor.device, tensor.dtype
        return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size,
                                    scale_factor=scale_factor, mode=mode, align_corners=align_corners,
                                    recompute_scale_factor=recompute_scale_factor,
                                    antialias=antialias).to(return_device, dtype=return_dtype)
    return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
                                align_corners=align_corners,
                                recompute_scale_factor=recompute_scale_factor, antialias=antialias)

torch.nn.functional.interpolate = interpolate
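As a quick sanity check for the patch (a sketch; after the monkey-patch, a bicubic call on an XPU tensor should round-trip through the CPU and come back on the original device and dtype):

# with the patch above applied:
x = torch.randn(1, 3, 32, 32, dtype=torch.bfloat16, device="xpu")
y = torch.nn.functional.interpolate(x, scale_factor=2, mode="bicubic")
print(y.device, y.dtype, tuple(y.shape))  # expected: xpu device, torch.bfloat16, (1, 3, 64, 64)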
> We should also note the resolution of the images. It seems the outputs are 1024x1024.
This happens at any resolution, so I just used the default one for the report. Same issue with 768x768, 768x1280, 1024x1024, 1280x1280, 1024x1536, etc.
I also tried this on CPU and got a clear HD image. Will focus debugging on the prior-stage latent outputs and compare the CPU vs GPU values.
I'm working with the team to find a simpler reproducer to identify which ops are causing the corruption. I did try this on CPU as well and got a clean HD image from it, so the problematic op is specific to the GPU.
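A sketch of one way to do that comparison (assuming float32 on both devices and that the prior pipeline's image_embeddings come back as a torch tensor; a CPU generator is used so both runs start from identical noise, and small numerical differences are expected, the interesting signal is NaNs or large blow-ups):

import torch
import intel_extension_for_pytorch  # for the xpu device on IPEX builds
from diffusers import StableCascadePriorPipeline

prompt = "Anthropomorphic cat dressed as a pilot"
embeddings = {}
for device in ("cpu", "xpu"):
    pipe = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", torch_dtype=torch.float32
    ).to(device)
    generator = torch.Generator(device="cpu").manual_seed(0)  # same starting noise on both runs
    out = pipe(prompt=prompt, num_inference_steps=20, generator=generator)
    embeddings[device] = out.image_embeddings.float().cpu()
    del pipe

print("NaNs on xpu:", torch.isnan(embeddings["xpu"]).any().item())
print("max abs diff vs cpu:", (embeddings["cpu"] - embeddings["xpu"]).abs().max().item())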
I was able to narrow down the issue further. While the latents are running through the denoising loop, they become NaN at a random step. Once they contain NaNs, every subsequent denoising iteration leaves the latents as NaNs. Confirmed that this occurs only on the GPU, not on the CPU.
Here's an example of the denoising loop for wuerstchen: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py#L375
Will continue to dive deeper into what operation is causing the NaN.
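As a sketch of one way to catch the first bad step, using the same callback(i, t, latents) hook the notebook above already uses for previews:

import torch

def nan_watch(i, t, latents):
    # Raise at the first denoising step where the latents stop being finite,
    # so the failing iteration can be inspected before everything turns to NaN.
    n_bad = (~torch.isfinite(latents)).sum().item()
    if n_bad:
        raise RuntimeError(f"step {i} (t={t}): {n_bad} non-finite latent values")

# pass into the pipeline call as: callback=nan_watch, callback_steps=1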
The NaN has been narrowed down further, to Wuerstchen_DiffNeXt's self._up_decode function, inside the if branch that handles ResBlockStageB. Note that there is a call to torch.nn.functional.interpolate there, and the reporter patched the interpolate function as shown earlier in this thread, since XPU does not support bicubic.
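One way to localize which submodule first produces non-finite outputs, sketched with forward hooks (the decoder attribute name in the usage comment follows the diffusers Wuerstchen decoder pipeline; adjust for the model under test):

import torch

def add_nan_hooks(model):
    # Attach a forward hook to every submodule. Each hook prints when a module's
    # output contains non-finite values and whether its inputs were still finite,
    # so the first culprit (finite in, non-finite out) stands out in the log.
    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            if any(torch.is_tensor(o) and not torch.isfinite(o).all() for o in outs):
                inputs_ok = all(torch.isfinite(i).all() for i in inputs if torch.is_tensor(i))
                print(f"non-finite output from {name} ({module.__class__.__name__}); inputs finite: {inputs_ok}")
        return hook
    return [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]

# e.g. on the Wuerstchen decoder pipeline: handles = add_nan_hooks(decoder.decoder)
# run the pipeline, inspect the log, then clean up with: for h in handles: h.remove()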
I know there hasn't been activity on this front for a while, and I'm not sure whether it got sidelined, but with the most recent IPEX release, v2.1.40+xpu, the corruption is still occurring and hasn't been fixed.
Still happens on IPEX 2.3.
Stable Cascade works fine on PyTorch 2.5 XPU from the PyTorch test branch.
> Stable Cascade works fine on PyTorch 2.5 XPU from the PyTorch test branch.
But I'm still getting random NaNs that don't happen on the CPU or on other GPU vendors' hardware.
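For anyone else trying the native XPU backend, a sketch of the swap (assuming a PyTorch 2.5 build from the test channel, which exposes torch.xpu directly without importing intel_extension_for_pytorch):

import torch
assert torch.xpu.is_available()  # native XPU backend in PyTorch 2.5, no IPEX import

from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("xpu")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
).to("xpu")
# Same calls as in the notebook above; images come out clean here, but random NaNs
# still appear occasionally, which does not happen on CPU or on other vendors' GPUs.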
The issue has been isolated to a specific operator, and we have zeroed in on the root cause. We will have a fix soon.
Describe the bug
Wuerstchen and the Wuerstchen-based Stable Cascade models generate corrupted images.
Might be related to my old corruption issue (https://github.com/intel/intel-extension-for-pytorch/issues/519), but this one happens at any resolution and occurs on GPU Max too.
Example of the corruption:
Wuerstchen: https://huggingface.co/warp-ai/wuerstchen
Stable Cascade: https://huggingface.co/stabilityai/stable-cascade
Versions