huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

StableDiffusionPipeline __call__ parameters format #4252

Closed · junhwanlee2316 closed 1 year ago

junhwanlee2316 commented 1 year ago

I am struggling to pass the negative_prompt_embeds parameter to the pipeline's __call__ method.

I tried to load the negative embedding from a PyTorch file (.pt) into a torch.FloatTensor using torch.load. However, I am running into the error "RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 768 but got size 77 for tensor number 1 in the list."

I found that the loaded tensor has a shape of [67, 768].

Link to the pipeline document: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds
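
If I understand the docs correctly, negative_prompt_embeds is expected to have the same (batch, sequence_length, hidden_size) shape as the encoded prompt, i.e. (1, 77, 768) for Stable Diffusion v1, since the pipeline concatenates the two along the batch dimension for classifier-free guidance. A minimal sketch of the mismatch (toy tensors, not my actual embedding):

import torch

# Toy tensors with Stable Diffusion v1 shapes: 77 tokens, 768 hidden dims.
prompt_embeds = torch.randn(1, 77, 768)      # what the pipeline encodes the prompt to
good_negative = torch.randn(1, 77, 768)      # a conforming negative_prompt_embeds
torch.cat([good_negative, prompt_embeds])    # works: shapes agree past dim 0

bad_negative = torch.randn(1, 768, 67)       # a wrongly shaped tensor
torch.cat([bad_negative, prompt_embeds])     # RuntimeError: Expected size 768 but got size 77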

Code:

from auth_token import auth_token
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from io import BytesIO
import base64 

from safetensors import safe_open
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

app = FastAPI()

app.add_middleware(
    CORSMiddleware, 
    allow_credentials=True, 
    allow_origins=["*"], 
    allow_methods=["*"], 
    allow_headers=["*"]
)

device = "cuda"
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=auth_token)
pipe.to(device)

file_path = r"C:\ReactProjects\stable_diffusion_api\assets\embeddings\FastNegativeV2.pt"
FastNegativeV2 = torch.load(file_path)

# Take the 67 per-token vectors stored in the textual-inversion file
embeddings_tensor = FastNegativeV2["string_to_param"]["*"][:67]

@app.get("/")
def generate(prompt: str):
    expanded_embeddings_tensor = embeddings_tensor.expand(768, -1, -1)
    with autocast(device):
        image = pipe(prompt=prompt, negative_prompt_embeds=expanded_embeddings_tensor, guidance_scale=8.5).images[0]

    image.save("testimage.png")
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    imgstr = base64.b64encode(buffer.getvalue())

    return Response(content=imgstr, media_type="image/png")

Log

INFO:     Application startup complete.
INFO:     127.0.0.1:50276 - "GET /docs HTTP/1.1" 200 OK
INFO:     127.0.0.1:50276 - "GET /openapi.json HTTP/1.1" 200 OK
INFO:     127.0.0.1:50276 - "GET /?prompt=girl HTTP/1.1" 500 Internal Server Error
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\uvicorn\protocols\http\h11_impl.py", line 408, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\uvicorn\middleware\proxy_headers.py", line 84, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\fastapi\applications.py", line 289, in __call__
    await super().__call__(scope, receive, send)
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\applications.py", line 122, in __call__
    await self.middleware_stack(scope, receive, send)
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\middleware\errors.py", line 184, in __call__
    raise exc
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\middleware\errors.py", line 162, in __call__
    await self.app(scope, receive, _send)
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\middleware\cors.py", line 83, in __call__
    await self.app(scope, receive, send)
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\middleware\exceptions.py", line 79, in __call__
    raise exc
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\middleware\exceptions.py", line 68, in __call__
    await self.app(scope, receive, sender)
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\fastapi\middleware\asyncexitstack.py", line 20, in __call__
    raise e
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\fastapi\middleware\asyncexitstack.py", line 17, in __call__
    await self.app(scope, receive, send)
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\routing.py", line 718, in __call__
    await route.handle(scope, receive, send)
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\routing.py", line 276, in handle
    await self.app(scope, receive, send)
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\routing.py", line 66, in app
    response = await func(request)
               ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\fastapi\routing.py", line 273, in app
    raw_response = await run_endpoint_function(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\fastapi\routing.py", line 192, in run_endpoint_function
    return await run_in_threadpool(dependant.call, **values)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\concurrency.py", line 41, in run_in_threadpool
    return await anyio.to_thread.run_sync(func, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\anyio\to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\anyio\_backends\_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\anyio\_backends\_asyncio.py", line 807, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ReactProjects\stable_diffusion_api\api.py", line 40, in generate
    image = pipe(prompt = prompt, negative_prompt_embeds= embeddings_tensor, guidance_scale=8.5).images[0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion.py", line 688, in __call__
    prompt_embeds = self._encode_prompt(
                    ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\junhw\AppData\Local\Programs\Python\Python311\Lib\site-packages\diffusers\pipelines\stable_diffusion\pipeline_stable_diffusion.py", line 452, in _encode_prompt
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 768 but got size 77 for tensor number 1 in the list.

huibing99 commented 1 year ago

Hi, I am also working on a similar project. I manually reshaped negative_prompt_embeds to (1, 77, 768), and although it runs without errors, the results are not satisfactory. I would like to know the correct usage.
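
Roughly what I did (a sketch with a placeholder tensor, not my exact code):

import torch
import torch.nn.functional as F

raw = torch.randn(67, 768)                       # placeholder for the loaded .pt vectors

# Zero-pad the token axis from 67 to 77, then add a batch dim -> (1, 77, 768).
padded = F.pad(raw, (0, 0, 0, 77 - raw.shape[0]))
negative_prompt_embeds = padded.unsqueeze(0)
assert negative_prompt_embeds.shape == (1, 77, 768)

My guess is that zero-padded raw token vectors never go through the CLIP text encoder the way a registered token would, which might be why the results look off.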

patrickvonplaten commented 1 year ago

Please make sure to load the negative vector with load_textual_inversion as explained here: https://github.com/huggingface/diffusers/issues/3198
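
For example, a minimal sketch (the file path and token name are placeholders for your own embedding):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Register the embedding as a token in the tokenizer and text encoder.
pipe.load_textual_inversion("path/to/FastNegativeV2.pt", token="FastNegativeV2")

# Mention the token in negative_prompt; the pipeline encodes it like any other word.
image = pipe(prompt="girl", negative_prompt="FastNegativeV2", guidance_scale=8.5).images[0]

This registers the vectors as a new token in the tokenizer and text encoder, so the negative prompt is encoded normally instead of the pipeline receiving raw vectors.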