Corruption with Wuerstchen and Stable Cascade models #529
Did some more experimenting. The Prior model works fine with IPEX 2.0, but the Decoder model fails.
IPEX 2.0:
IPEX 2.1:
Here is the ipynb file I tested (Python 3.11.7 kernel), with its cells shown in order. The previewer modules are taken from here: https://huggingface.co/spaces/multimodalart/stable-cascade/tree/main/previewer

Cells 1-3 (environment setup, left commented out):

#pip install --force-reinstall torch==2.1.0a0 torchvision==0.16.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
#pip install --force-reinstall tensorboard==2.14.1 tensorflow==2.14.0 intel-extension-for-tensorflow[xpu]==2.14.0.1
#pip install git+https://github.com/kashif/diffusers.git@wuerstchen-v3 accelerate transformers typing_extensions

Cell 4 (imports):

import torch
import intel_extension_for_pytorch

Cell 5 (load the pipelines):

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("xpu")
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.bfloat16).to("xpu")

Cell 6 (latent previewer and prior callback):

from IPython.display import clear_output
from diffusers.utils import numpy_to_pil
from previewer import Previewer

previewer = Previewer()
previewer_state_dict = torch.load("previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
previewer.load_state_dict(previewer_state_dict)
previewer = previewer.eval().requires_grad_(False).to("xpu", dtype=torch.bfloat16)

def callback_prior(i, t, latents):
    output = previewer(latents)
    output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
    clear_output()
    display(output[0])

Cell 7 (generation):

num_images_per_prompt = 1
callback_steps = 1
prompt = "Anthropomorphic cat dressed as a pilot"
negative_prompt = ""

torch.xpu.empty_cache()
prior_output = prior(
    prompt=prompt,
    height=1024,
    width=1024,
    negative_prompt=negative_prompt,
    guidance_scale=4.0,
    num_images_per_prompt=num_images_per_prompt,
    num_inference_steps=20,
    callback=callback_prior,
    callback_steps=callback_steps,
)
torch.xpu.empty_cache()
decoder_output = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10
).images
torch.xpu.empty_cache()

display(decoder_output[0])
@Disty0 I will try reproducing your issue on Arc.
Seeing the same corruption on an A770.
This one still happens with IPEX 2.1.20+xpu.
@Disty0 let's focus on Wuerstchen first, since I have it on my setup with an Arc A770 and the issue should be similar to Stable Cascade. I did have to use "warp-diffusion/wuerstchen" as the model card instead of the one from warp-ai, but I get the image corruption on both IPEX v2.1.10+xpu and v2.0.120+xpu. Can you show me the image you get without corruption on v2.0.120+xpu? We should also note the resolution of the images; it seems the outputs are 1024x1024.
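For reference, a minimal sketch of the Wuerstchen run on XPU might look like the following. This is an illustration under assumptions, not the exact script used: it assumes diffusers' AutoPipelineForText2Image wrapper and the warp-ai/wuerstchen card, and "warp-diffusion/wuerstchen" can be swapped in if that is the card that resolves.

import torch
import intel_extension_for_pytorch  # registers the "xpu" device with PyTorch
from diffusers import AutoPipelineForText2Image

# Combined prior + decoder pipeline; dtype and resolution match the report above.
pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.bfloat16).to("xpu")
image = pipe(
    prompt="Anthropomorphic cat dressed as a pilot",
    height=1024,
    width=1024,
    prior_guidance_scale=4.0,
).images[0]
image.save("wuerstchen_xpu.png")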
Final image is corrupted on all of them. Latent previewer for the prior stage: https://huggingface.co/spaces/multimodalart/stable-cascade/tree/main/previewer
Also, this CPU fallback patch is applied to torch.nn.functional.interpolate:

from functools import wraps

original_interpolate = torch.nn.functional.interpolate

@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
    # Fall back to CPU (in fp32) for arguments that XPU does not handle, then move the result back.
    if antialias or align_corners is not None or mode == 'bicubic':
        return_device = tensor.device
        return_dtype = tensor.dtype
        return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
                                    align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
    else:
        return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
                                    align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)

torch.nn.functional.interpolate = interpolate
This happens at any resolution, so I just used the default one for the report.
I also tried this on CPU and got a clear HD image. I will focus debugging on the prior latent stage outputs and compare the CPU vs. GPU values.
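A rough sketch of that CPU vs. GPU comparison on the prior stage is shown below. The helper name, seed handling, and the choice to compare image embeddings are illustrative assumptions, not part of the original report.

import torch

# Run the prior with the same seed on two pipelines (one kept on CPU, one on XPU)
# and report whether NaNs appear and how far the image embeddings drift apart.
def compare_prior_latents(prior_cpu, prior_xpu, prompt, steps=20, seed=0):
    out_cpu = prior_cpu(prompt=prompt, num_inference_steps=steps,
                        generator=torch.Generator("cpu").manual_seed(seed))
    out_xpu = prior_xpu(prompt=prompt, num_inference_steps=steps,
                        generator=torch.Generator("cpu").manual_seed(seed))
    a = out_cpu.image_embeddings.float().cpu()
    b = out_xpu.image_embeddings.float().cpu()
    print("NaNs (CPU / XPU):", torch.isnan(a).any().item(), torch.isnan(b).any().item())
    print("max abs diff:", (a - b).abs().max().item())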
I'm working with the team to find a simpler reproducer to identify which ops are causing the corruption. I did try this on CPU as well and saw an HD image from it, so the op is specific to the GPU.
I was able to narrow down the issue further. While the latents are running through the denoising loop, they become NaN values at random. Once they contain NaNs, every subsequent denoising iteration leaves the latents as NaNs. Confirmed that this occurs on the GPU only and not on the CPU. Here's an example of the denoising loop for Wuerstchen: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py#L375 I will continue to dig into which operation is causing the NaNs.
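One lightweight way to catch the step at which the latents first go bad is a callback along these lines (a sketch only; it uses the older callback=/callback_steps=1 signature that the notebook above already relies on):

import torch

# Illustrative callback: report any denoising step whose latents contain NaN/Inf.
# The first step reported is where the corruption starts.
def nan_watch(i, t, latents):
    if not torch.isfinite(latents).all():
        bad = (~torch.isfinite(latents)).sum().item()
        print(f"step {i} (t={t}): {bad} non-finite latent values")

# usage, e.g.: pipe(prompt, callback=nan_watch, callback_steps=1)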
The NaN has been narrowed down further, to WuerstchenDiffNeXt's self._up_decode function, inside the if statement that uses ResBlockStageB. Note that there is a call to torch.nn.functional.interpolate there, and the developer modified the interpolate function as shown at the top of this ticket, since XPU does not support bicubic.
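To localize this kind of failure inside the model, one option is a forward-hook NaN tracer like the sketch below. The attribute holding the diffusion model differs per pipeline (e.g. pipe.decoder or pipe.prior), so the usage line is an assumption.

import torch

# Attach a hook to every sub-module and print any whose output is non-finite;
# the first module printed during a run is the earliest offender.
def add_nan_hooks(model):
    handles = []
    def make_hook(name):
        def hook(module, inputs, output):
            if torch.is_tensor(output) and not torch.isfinite(output).all():
                print(f"non-finite output from {name} ({module.__class__.__name__})")
        return hook
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call handle.remove() on each when finished

# usage, e.g.: handles = add_nan_hooks(pipe.decoder)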
I know there hasn't been activity on this front for a while; I'm not sure whether things got sidelined, but with the most recent IPEX release, v2.1.40+xpu, the corruption is still occurring and hasn't been fixed.
Still happens on IPEX 2.3.
Stable Cascade works fine on PyTorch 2.5 XPU from the PyTorch test branch, but I'm still getting random NaNs that don't happen on CPU or with other GPU vendors.
The issue has been isolated to a specific operator and we have zeroed in on the root cause. We will have a fix soon.
Describe the bug
Wuerstchen and the Wuerstchen-based Stable Cascade models generate corrupted images.
This might be related to my old corruption issue (#519), but this one happens at any resolution and occurs on GPU Max too.
Example of the corruption:
Wuerstchen: https://huggingface.co/warp-ai/wuerstchen
Stable Cascade: https://huggingface.co/stabilityai/stable-cascade
Versions