Loading models from cache extra slow due to extra conversion #5460

Closed. joe-chiu closed this issue 1 year ago.

joe-chiu commented 1 year ago

Describe the bug

I have two Python environments, one on Windows and one on Linux (under WSL), both using diffusers. To avoid keeping multiple copies of the same models on disk, I am trying to make the two installations share a single diffusers model cache:

  1. Ubuntu on WSL, Python 3.10.12: the cache is a symlink pointing to an NTFS disk, i.e. the hub folder points to /mnt/c/python/StableDiffusion/hub
  2. Windows 11 Home, Python 3.11.6: the cache resides in a folder on the NTFS disk, i.e. HF_HOME=C:\python\StableDiffusion

I run the same Python code in a Jupyter notebook that simply loads an SDXL model with diffusers. On the Windows installation, the loading progress bar finishes in 1-2 seconds. On Linux, the same progress bar finishes just as quickly, but the cell then keeps executing for another ~2 minutes. When I manually interrupted the script, the traceback showed what looked like additional data format conversion code running.
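One way to narrow this down (a rough sketch, not something diffusers provides; the snapshot path below is a placeholder and has to be adjusted to the actual cache layout) is to time a plain safetensors read of one of the cached weight files from both environments, which separates slow file I/O over /mnt/c from genuine format conversion work:

import time
from safetensors.torch import load_file

# Placeholder path to a cached SDXL weight file in the shared hub cache;
# replace <hash> and the file name with whatever is actually on disk.
weights = (
    "/mnt/c/python/StableDiffusion/hub/"
    "models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/<hash>/"
    "unet/diffusion_pytorch_model.fp16.safetensors"
)

start = time.perf_counter()
state_dict = load_file(weights, device="cpu")  # raw safetensors read, no diffusers involved
print(f"loaded {len(state_dict)} tensors in {time.perf_counter() - start:.1f}s")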

I tried two additional setups:

  1. I cleared the shared cache and let the Ubuntu installation re-download and re-create it (the cache still sits behind a symlink pointing to an NTFS folder) => the additional slowdown still occurs after the model is fully downloaded. The Windows installation works perfectly with the cache created by Ubuntu, with no conversion delay.
  2. I stopped symlinking to NTFS folders and let the Ubuntu installation use its own cache on its own file system => no more slowdown, but I now have two copies of each model.

The notebook code:
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, DPMSolverMultistepScheduler
import torch

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16, variant="fp16",
    use_safetensors=True
).to("cuda")

This is the traceback for the slow model loading on Linux; when I interrupted execution about a minute in, some sort of model data conversion appeared to be running.

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[3], line 10
      2 import torch
      4 vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
      5 pipe = StableDiffusionXLPipeline.from_pretrained(
      6     "stabilityai/stable-diffusion-xl-base-1.0",
      7     vae=vae,
      8     torch_dtype=torch.float16, variant="fp16",
      9     use_safetensors=True
---> 10 ).to("cuda")

File ~/stable_diffusion_venv/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py:733, in DiffusionPipeline.to(self, torch_device, torch_dtype, silence_dtype_warnings)
    729     logger.warning(
    730         f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
    731     )
    732 else:
--> 733     module.to(torch_device, torch_dtype)
    735 if (
    736     module.dtype == torch.float16
    737     and str(torch_device) in ["cpu"]
    738     and not silence_dtype_warnings
    739     and not is_offloaded
    740 ):
    741     logger.warning(
    742         "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
    743         " is not recommended to move them to `cpu` as running them will fail. Please make"
   (...)
    746         " `torch_dtype=torch.float16` argument, or use another device for inference."
    747     )

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1160, in Module.to(self, *args, **kwargs)
   1156         return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1157                     non_blocking, memory_format=convert_to_format)
   1158     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1160 return self._apply(convert)

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
    808 if recurse:
    809     for module in self.children():
--> 810         module._apply(fn)
    812 def compute_should_use_set_data(tensor, tensor_applied):
    813     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    814         # If the new tensor has compatible tensor type as the existing tensor,
    815         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    820         # global flag to let the user control whether they want the future
    821         # behavior of overwriting the existing tensor or not.

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
    808 if recurse:
    809     for module in self.children():
--> 810         module._apply(fn)
    812 def compute_should_use_set_data(tensor, tensor_applied):
    813     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    814         # If the new tensor has compatible tensor type as the existing tensor,
    815         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    820         # global flag to let the user control whether they want the future
    821         # behavior of overwriting the existing tensor or not.

    [... skipping similar frames: Module._apply at line 810 (5 times)]

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:810, in Module._apply(self, fn, recurse)
    808 if recurse:
    809     for module in self.children():
--> 810         module._apply(fn)
    812 def compute_should_use_set_data(tensor, tensor_applied):
    813     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    814         # If the new tensor has compatible tensor type as the existing tensor,
    815         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    820         # global flag to let the user control whether they want the future
    821         # behavior of overwriting the existing tensor or not.

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:833, in Module._apply(self, fn, recurse)
    829 # Tensors stored in modules are graph leaves, and we don't want to
    830 # track autograd history of `param_applied`, so we have to use
    831 # `with torch.no_grad():`
    832 with torch.no_grad():
--> 833     param_applied = fn(param)
    834 should_use_set_data = compute_should_use_set_data(param, param_applied)
    835 if should_use_set_data:

File ~/stable_diffusion_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1158, in Module.to.<locals>.convert(t)
   1155 if convert_to_format is not None and t.dim() in (4, 5):
   1156     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1157                 non_blocking, memory_format=convert_to_format)
-> 1158 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

KeyboardInterrupt: 

Reproduction

  1. Have a Linux distribution running on Windows WSL.
  2. Use diffusers on Windows to load and cache a diffusion model (say, the Windows HF_HOME is c:\somewhere, so the cache is in c:\somewhere\hub).
  3. Use diffusers on Linux (on WSL) to load the same model, with the Linux HF_HOME pointing to the root of the virtualenv and the hub folder inside HF_HOME being a symlink to a folder on NTFS (e.g. /mnt/c/somewhere/hub); see the sketch after this list.
  4. Diffusers should load the model from the cache.

Expected: the model loads quickly from the cache.

Actual: some unexpected data conversion happens and delays model loading by ~2 minutes.
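For completeness, a rough sketch of how the shared cache is wired up on the Linux side (the paths are placeholders; HF_HOME has to be set before diffusers/huggingface_hub are imported, otherwise the default ~/.cache/huggingface location is used):

import os

# Placeholder paths for illustration only; adjust to the actual locations.
venv_root = os.path.expanduser("~/stable_diffusion_venv")
os.environ["HF_HOME"] = venv_root

# The hub folder inside HF_HOME is a symlink to the NTFS folder shared with
# Windows, created once beforehand with something like:
#     os.symlink("/mnt/c/python/StableDiffusion/hub", os.path.join(venv_root, "hub"))

# Import diffusers only after HF_HOME is set so the shared cache is picked up.
from diffusers import StableDiffusionXLPipeline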

Logs

No response

System Info

pip freeze on Linux

accelerate==0.23.0 anyio==4.0.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 asttokens==2.4.0 async-lru==2.0.4 attrs==23.1.0 Babel==2.13.0 backcall==0.2.0 beautifulsoup4==4.12.2 bleach==6.1.0 certifi==2023.7.22 cffi==1.16.0 charset-normalizer==3.3.0 comm==0.1.4 debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 diffusers==0.21.4 exceptiongroup==1.1.3 executing==2.0.0 fastjsonschema==2.18.1 filelock==3.12.4 fqdn==1.5.1 fsspec==2023.9.2 huggingface-hub==0.17.3 idna==3.4 importlib-metadata==6.8.0 ipykernel==6.25.2 ipython==8.16.1 ipywidgets==8.1.1 isoduration==20.11.0 jedi==0.19.1 Jinja2==3.1.2 json5==0.9.14 jsonpointer==2.4 jsonschema==4.19.1 jsonschema-specifications==2023.7.1 jupyter-events==0.8.0 jupyter-lsp==2.2.0 jupyter_client==8.4.0 jupyter_core==5.4.0 jupyter_server==2.8.0 jupyter_server_terminals==0.4.4 jupyterlab==4.0.7 jupyterlab-pygments==0.2.2 jupyterlab-widgets==3.0.9 jupyterlab_server==2.25.0 MarkupSafe==2.1.3 matplotlib-inline==0.1.6 mistune==3.0.2 mpmath==1.3.0 nbclient==0.8.0 nbconvert==7.9.2 nbformat==5.9.2 nest-asyncio==1.5.8 networkx==3.2 notebook_shim==0.2.3 numpy==1.26.1 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.3.52 nvidia-nvtx-cu12==12.1.105 overrides==7.4.0 packaging==23.2 pandocfilters==1.5.0 parso==0.8.3 pexpect==4.8.0 pickleshare==0.7.5 Pillow==10.1.0 platformdirs==3.11.0 prometheus-client==0.17.1 prompt-toolkit==3.0.39 psutil==5.9.6 ptyprocess==0.7.0 pure-eval==0.2.2 pycparser==2.21 Pygments==2.16.1 python-dateutil==2.8.2 python-json-logger==2.0.7 PyYAML==6.0.1 pyzmq==25.1.1 referencing==0.30.2 regex==2023.10.3 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rpds-py==0.10.6 safetensors==0.4.0 Send2Trash==1.8.2 six==1.16.0 sniffio==1.3.0 soupsieve==2.5 stack-data==0.6.3 sympy==1.12 terminado==0.17.1 tinycss2==1.2.1 tokenizers==0.14.1 tomli==2.0.1 torch==2.1.0 torchaudio==2.1.0 torchvision==0.16.0 tornado==6.3.3 tqdm==4.66.1 traitlets==5.11.2 transformers==4.34.1 triton==2.1.0 types-python-dateutil==2.8.19.14 typing_extensions==4.8.0 uri-template==1.3.0 urllib3==2.0.7 wcwidth==0.2.8 webcolors==1.13 webencodings==0.5.1 websocket-client==1.6.4 widgetsnbextension==4.0.9 zipp==3.17.0

pip freeze on Windows

absl-py==2.0.0 accelerate==0.23.0 antlr4-python3-runtime==4.9.3 anyio==4.0.0 appdirs==1.4.4 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 asttokens==2.4.0 async-lru==2.0.4 attrs==23.1.0 Babel==2.12.1 backcall==0.2.0 beautifulsoup4==4.12.2 bitsandbytes==0.41.1 bleach==6.0.0 cachetools==5.3.1 certifi==2022.12.7 cffi==1.16.0 charset-normalizer==2.1.1 click==8.1.7 colorama==0.4.6 comm==0.1.4 debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 diffusers==0.21.4 docker-pycreds==0.4.0 executing==1.2.0 fastjsonschema==2.18.0 filelock==3.9.0 fqdn==1.5.1 fsspec==2023.9.2 ftfy==6.1.1 gitdb==4.0.10 GitPython==3.1.37 google-auth==2.23.2 google-auth-oauthlib==1.0.0 grpcio==1.59.0 huggingface-hub==0.17.3 idna==3.4 importlib-metadata==6.8.0 ipykernel==6.25.2 ipython==8.15.0 ipython-genutils==0.2.0 ipywidgets==8.1.1 isoduration==20.11.0 jedi==0.19.0 Jinja2==3.1.2 json5==0.9.14 jsonpointer==2.4 jsonschema==4.19.1 jsonschema-specifications==2023.7.1 jupyter-events==0.7.0 jupyter-lsp==2.2.0 jupyter_client==8.3.1 jupyter_core==5.3.2 jupyter_server==2.7.3 jupyter_server_terminals==0.4.4 jupyterlab==4.0.6 jupyterlab-pygments==0.2.2 jupyterlab-widgets==3.0.9 jupyterlab_server==2.25.0 Markdown==3.5 MarkupSafe==2.1.2 matplotlib-inline==0.1.6 mistune==3.0.1 mpmath==1.3.0 nbclient==0.8.0 nbconvert==7.8.0 nbformat==5.9.2 nest-asyncio==1.5.8 networkx==3.0 notebook==7.0.4 notebook_shim==0.2.3 numpy==1.24.1 nvidia-cublas-cu12==12.2.5.6 nvidia-cuda-nvrtc-cu12==12.2.140 nvidia-cuda-runtime-cu12==12.2.140 nvidia-cudnn-cu12==8.9.4.25 oauthlib==3.2.2 omegaconf==2.3.0 opencv-python==4.8.1.78 overrides==7.4.0 packaging==23.1 pandocfilters==1.5.0 parso==0.8.3 pathtools==0.1.2 pickleshare==0.7.5 Pillow==9.3.0 platformdirs==3.10.0 prometheus-client==0.17.1 prompt-toolkit==3.0.39 protobuf==4.24.4 psutil==5.9.5 pure-eval==0.2.2 pyasn1==0.5.0 pyasn1-modules==0.3.0 pycparser==2.21 Pygments==2.16.1 python-dateutil==2.8.2 python-json-logger==2.0.7 pywin32==306 pywinpty==2.0.11 PyYAML==6.0.1 pyzmq==25.1.1 qtconsole==5.4.4 QtPy==2.4.0 referencing==0.30.2 regex==2023.8.8 requests==2.31.0 requests-oauthlib==1.3.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rpds-py==0.10.3 rsa==4.9 safetensors==0.3.3 scipy==1.11.3 Send2Trash==1.8.2 sentry-sdk==1.31.0 setproctitle==1.3.3 six==1.16.0 smmap==5.0.1 sniffio==1.3.0 soupsieve==2.5 stack-data==0.6.2 sympy==1.12 tensorboard==2.14.1 tensorboard-data-server==0.7.1 tensorrt-bindings==9.0.1.post12.dev4 tensorrt-libs==9.0.1.post12.dev4 terminado==0.17.1 tinycss2==1.2.1 tokenizers==0.13.3 torch==2.0.1+cu118 torchaudio==2.0.2+cu118 torchvision==0.15.2+cu118 tornado==6.3.3 tqdm==4.66.1 traitlets==5.10.1 transformers==4.33.3 typing_extensions==4.4.0 uri-template==1.3.0 urllib3==1.26.13 wandb==0.15.12 wcwidth==0.2.7 webcolors==1.13 webencodings==0.5.1 websocket-client==1.6.3 Werkzeug==3.0.0 widgetsnbextension==4.0.9 xformers==0.0.22 zipp==3.17.0

Who can help?

@sayakpaul @patrickvonplaten

sayakpaul commented 1 year ago

Thanks for the thread.

To me this seems like a very specific case of loading a pipeline. Our testing setup covers a couple of configurations, but not the ones you mentioned.

I don't think we have the bandwidth to check WSL at the moment, since none of the repository maintainers, myself included, use it. I am still cc'ing @DN6 in case he has anything additional to suggest.

joe-chiu commented 1 year ago

In answer to the question of whether I observe similar behaviour in a pure Linux environment: I tried to replicate the problem between two virtual Python environments on the same Linux instance. I pointed HF_HOME for both virtual environments at their respective $VIRTUAL_ENV, so each looks for its hub folder inside its own virtualenv. venv1 has the "real" cache folders and venv2 has its hub folder symlinked to venv1's hub folder. The problem does not reproduce in this setup. It seems to happen only when the cache is symlinked to an NTFS folder under the /mnt mount point. Does the cache/model loading code use any file metadata beyond the file content itself?
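For reference, one way to check whether the hub cache's snapshot symlinks survive on the /mnt/c (drvfs) mount (a rough sketch; the path is a placeholder for one repo in the shared cache):

import os

# Placeholder path to one cached repo's snapshots folder; adjust as needed.
snapshots = ("/mnt/c/python/StableDiffusion/hub/"
             "models--stabilityai--stable-diffusion-xl-base-1.0/snapshots")

for root, _dirs, files in os.walk(snapshots):
    for name in files:
        path = os.path.join(root, name)
        # In a normal hub cache every snapshot file is a symlink into ../../blobs;
        # if these come back as regular copies or broken links on /mnt/c, the two
        # setups are not reading the weights the same way.
        kind = "symlink" if os.path.islink(path) else "regular file"
        print(kind, path, "->", os.path.realpath(path))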

DN6 commented 1 year ago

Hi @joe-chiu. Like @sayakpaul mentioned, this issue seems like a very specific edge case that I do not think we currently have the bandwidth to resolve.

Linking you to the HF Hub documentation for some guidance on the HF cache: https://huggingface.co/docs/huggingface_hub/guides/manage-cache
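As a quick way to see what the shared cache actually contains from either environment (a minimal sketch using huggingface_hub; omitting cache_dir falls back to the default HF_HOME/hub location):

from huggingface_hub import scan_cache_dir

# Placeholder cache path; omit the argument to scan the default HF_HOME/hub.
info = scan_cache_dir(cache_dir="/mnt/c/python/StableDiffusion/hub")

for repo in info.repos:
    print(repo.repo_id, repo.repo_type, repo.nb_files, f"{repo.size_on_disk / 1e9:.2f} GB")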

joe-chiu commented 1 year ago

No worries, thank you for taking the time to consider it.