huggingface / diffusion-fast

Faster generation with text-to-image diffusion models.
https://pytorch.org/blog/accelerating-generative-ai-3/
Apache License 2.0

lora support for optimization #11

Closed. Sandeep-Narahari closed this issue 8 months ago.

Sandeep-Narahari commented 8 months ago

I was using the torch.compile optimization to speed up inference.

Here I am using a DreamBooth LoRA model that was trained on top of Juggernaut.

When I run inference after loading the LoRA with the call below, the model does not compile:

pipe.load_lora_weights(prj_path, weight_name="pytorch_lora_weights.safetensors")

Is there any way I can use this optimization with DreamBooth LoRA models?

Packages I am using:

Package Version
absl-py 2.1.0
accelerate 0.26.1
aiofiles 23.2.1
aiohttp 3.9.3
aiosignal 1.3.1
albumentations 1.3.1
alembic 1.13.1
altair 5.2.0
annotated-types 0.6.0
anyio 3.7.1
arrow 1.3.0
async-timeout 4.0.3
attrs 23.2.0
Authlib 1.3.0
autotrain-advanced 0.6.92
bitsandbytes 0.42.0
Brotli 1.1.0
cachetools 5.3.2
certifi 2023.11.17
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cmaes 0.10.0
cmake 3.28.3
codecarbon 2.2.3
colorlog 6.8.2
contourpy 1.1.1
cryptography 42.0.3
cycler 0.12.1
datasets 2.14.7
diffusers 0.21.4
dill 0.3.8
docstring-parser 0.15
einops 0.6.1
evaluate 0.3.0
exceptiongroup 1.2.0
fastapi 0.104.1
ffmpy 0.3.1
filelock 3.13.1
fonttools 4.47.2
frozenlist 1.4.1
fsspec 2023.10.0
fuzzywuzzy 0.18.0
google-auth 2.27.0
google-auth-oauthlib 1.0.0
GPUtil 1.4.0
gradio 3.41.0
gradio-client 0.5.0
greenlet 3.0.3
grpcio 1.60.0
h11 0.14.0
hf-transfer 0.1.5
httpcore 1.0.2
httpx 0.26.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.33.1
importlib-metadata 7.0.1
importlib-resources 6.1.1
inflate64 1.0.0
install 1.3.5
invisible-watermark 0.2.0
ipadic 1.0.0
itsdangerous 2.1.2
Jinja2 3.1.3
jiwer 3.0.2
joblib 1.3.1
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
lazy-loader 0.3
loguru 0.7.0
Mako 1.3.2
Markdown 3.5.2
markdown-it-py 3.0.0
MarkupSafe 2.1.4
matplotlib 3.7.4
mdurl 0.1.2
mpmath 1.3.0
multidict 6.0.4
multiprocess 0.70.16
multivolumefile 0.2.3
networkx 3.1
nltk 3.8.1
numpy 1.24.4
nvidia-cublas-cu11 11.11.3.6
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11 8.7.0.84
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu11 10.9.0.58
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu11 10.3.0.86
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu11 11.7.5.86
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu11 2.19.3
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu11 11.8.86
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
opencv-python 4.9.0.80
opencv-python-headless 4.9.0.80
optuna 3.3.0
orjson 3.9.12
packaging 23.1
pandas 2.0.3
peft 0.8.2
Pillow 10.0.0
pip 20.0.2
pkg-resources 0.0.0
pkgutil-resolve-name 1.3.10
protobuf 4.23.4
psutil 5.9.8
py-cpuinfo 9.0.0
py7zr 0.20.6
pyarrow 15.0.0
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
pybcj 1.0.2
pycparser 2.21
pycryptodomex 3.20.0
pydantic 2.4.2
pydantic-core 2.10.1
pydub 0.25.1
pygments 2.17.2
pyngrok 7.0.3
pynvml 11.5.0
pyparsing 3.1.1
pyppmd 1.0.0
python-dateutil 2.8.2
python-dotenv 1.0.1
python-multipart 0.0.6
pytorch-triton 3.0.0+901819d2b6
pytz 2023.4
PyWavelets 1.4.1
PyYAML 6.0.1
pyzstd 0.15.9
qudida 0.0.4
rapidfuzz 2.13.7
referencing 0.33.0
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
responses 0.18.0
rich 13.7.0
rouge-score 0.1.2
rpds-py 0.17.1
rsa 4.9
sacremoses 0.0.53
safetensors 0.4.2
scikit-image 0.21.0
scikit-learn 1.3.0
scipy 1.10.1
semantic-version 2.10.0
sentencepiece 0.1.99
setuptools 44.0.0
shtab 1.6.5
six 1.16.0
sniffio 1.3.0
SQLAlchemy 2.0.25
starlette 0.27.0
sympy 1.12
tensorboard 2.14.0
tensorboard-data-server 0.7.2
texttable 1.7.0
threadpoolctl 3.2.0
tifffile 2023.7.10
tiktoken 0.5.1
tokenizers 0.15.1
toolz 0.12.1
torch 2.3.0.dev20240221+cu118
torchaudio 2.2.0+cu118
torchtriton 2.0.0+f16138d447
torchvision 0.17.0
tqdm 4.65.0
transformers 4.37.0
triton 2.2.0
trl 0.7.11
types-python-dateutil 2.8.19.20240106
typing-extensions 4.9.0
tyro 0.7.0
tzdata 2023.4
urllib3 2.2.0
uvicorn 0.22.0
websockets 11.0.3
Werkzeug 2.3.6
wheel 0.34.2
xformers 0.0.24
xgboost 1.7.6
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0


sayakpaul commented 8 months ago

You should call fuse_lora() after loading the LoRA checkpoint, and only then call torch.compile.
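A minimal sketch of that order (load the LoRA, fuse it, then compile), assuming a CUDA GPU, fp16 weights, and a diffusers version that provides `load_lora_weights()` and `fuse_lora()`. The function name `build_compiled_pipeline` and its parameters `base_model_id` and `lora_dir` are hypothetical: `lora_dir` plays the role of `prj_path` from the question, and `base_model_id` should be whichever Juggernaut checkpoint the LoRA was trained on. The `channels_last` and `torch.compile` settings are similar to the setup in the linked PyTorch blog post and are optional.

```python
def build_compiled_pipeline(base_model_id: str, lora_dir: str,
                            weight_name: str = "pytorch_lora_weights.safetensors"):
    """Hypothetical helper: load base model, merge a LoRA into it, then compile the UNet."""
    # Imported lazily so this sketch can be imported without torch/diffusers installed.
    import torch
    from diffusers import DiffusionPipeline

    # Assumption: fp16 weights on a CUDA device.
    pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    # 1. Load the DreamBooth LoRA checkpoint (same call as in the question).
    pipe.load_lora_weights(lora_dir, weight_name=weight_name)

    # 2. Merge the LoRA deltas into the base model weights.
    pipe.fuse_lora()

    # 3. Only now compile the UNet. channels_last and these compile options
    #    are optional tuning choices, not requirements for LoRA support.
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
    return pipe
```

Usage would look like `pipe = build_compiled_pipeline(base_model_id, prj_path)` followed by `pipe(prompt).images[0]`. The first generation triggers compilation and is slow; later calls with the same image size and batch size typically reuse the compiled UNet. Because fusing merges the LoRA deltas into the base weights before compilation, the compiled forward pass no longer has to run separate low-rank branches, which is a likely reason compiling succeeds in this order.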

sayakpaul commented 8 months ago

Closing it?

Sandeep-Narahari commented 8 months ago

Awesome, it's working fine now.

Thanks!