xhinker / sd_embed

Generate long weighted prompt embeddings for Stable Diffusion
Apache License 2.0

Support for pipe.enable_model_cpu_offload() and pipe.enable_sequential_cpu_offload() #2

Open Teriks opened 3 months ago

Teriks commented 3 months ago

Is there a way to support pipelines with CPU offloading enabled?

The library currently seems unable to handle this case. The following script reproduces the problem:

import gc
import torch
from diffusers import StableDiffusion3Pipeline
from dgenerate.extras.sd_embed.embedding_funcs import get_weighted_text_embeddings_sd3

model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)

pipe.enable_sequential_cpu_offload(device='cuda')

prompt = """A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus. 
This imaginative creature features the distinctive, bulky body of a hippo, 
but with a texture and appearance resembling a golden-brown, crispy waffle. 
The creature might have elements like waffle squares across its skin and a syrup-like sheen. 
It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting, 
possibly including oversized utensils or plates in the background. 
The image should evoke a sense of playful absurdity and culinary fantasy.
"""

neg_prompt = """\
skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
(normal quality:2),lowres,((monochrome)),((grayscale))
"""

(
    prompt_embeds
    , prompt_neg_embeds
    , pooled_prompt_embeds
    , negative_pooled_prompt_embeds
) = get_weighted_text_embeddings_sd3(
    pipe
    , prompt = prompt
    , neg_prompt = neg_prompt
)

image = pipe(
    prompt_embeds                   = prompt_embeds
    , negative_prompt_embeds        = prompt_neg_embeds
    , pooled_prompt_embeds          = pooled_prompt_embeds
    , negative_pooled_prompt_embeds = negative_pooled_prompt_embeds
    , num_inference_steps           = 30
    , height                        = 1024
    , width                         = 1024 + 512
    , guidance_scale                = 4.0
    , generator                     = torch.Generator("cuda").manual_seed(2)
).images[0]

image.save('test.png')

del prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
pipe.to('cpu')
gc.collect()
torch.cuda.empty_cache()

Result:

  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
Loading pipeline components...:  44%|████▍     | 4/9 [00:00<00:00, 10.99it/s]
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  1.62it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.06it/s]
Loading pipeline components...: 100%|██████████| 9/9 [00:03<00:00,  2.44it/s]
Traceback (most recent call last):
  File "REDACT\test2.py", line 38, in <module>
    ) = get_weighted_text_embeddings_sd3(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\dgenerate\extras\sd_embed\embedding_funcs.py", line 1187, in get_weighted_text_embeddings_sd3
    prompt_embeds_1 = pipe.text_encoder(
                      ^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\accelerate\hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\transformers\models\clip\modeling_clip.py", line 1217, in forward
    text_outputs = self.text_model(
                   ^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\transformers\models\clip\modeling_clip.py", line 699, in forward
    hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\accelerate\hooks.py", line 161, in new_forward
    args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\accelerate\hooks.py", line 356, in pre_forward
    return send_to_device(args, self.execution_device), send_to_device(
                                                        ^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\accelerate\utils\operations.py", line 187, in send_to_device
    k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "REDACT\venv\Lib\site-packages\accelerate\utils\operations.py", line 158, in send_to_device
    return tensor.to(device, non_blocking=non_blocking)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!

Process finished with exit code 1

Teriks commented 3 months ago

I believe adding a device argument to the functions that build the embeddings would fix this. It would let the user specify the device on which the tensors are created, instead of relying on pipe.device to determine it.

For example, device=None could default to the pipeline's device (pipe.device), and any other value would override it with the user-specified device.
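The proposal above can be sketched roughly as follows. This is only an illustration of the default-and-override behaviour, not code from sd_embed: the helper name resolve_embedding_device is hypothetical, and a SimpleNamespace stands in for a real pipeline so the snippet runs without loading a model. The "cpu" value is an assumed example of what an offloaded pipeline might report.

```python
from types import SimpleNamespace


def resolve_embedding_device(pipe, device=None):
    """Return the device on which prompt tensors should be created.

    Hypothetical helper, not part of sd_embed. With device=None the
    existing behaviour is kept (use pipe.device); any other value is
    taken as the user's explicit choice.
    """
    return pipe.device if device is None else device


# Stand-in for a pipeline with offloading enabled (assumed to report "cpu").
offloaded_pipe = SimpleNamespace(device="cpu")

print(resolve_embedding_device(offloaded_pipe))          # cpu
print(resolve_embedding_device(offloaded_pipe, "cuda"))  # cuda
```

In the embedding functions, the resolved value would then replace direct uses of pipe.device when token and weight tensors are moved to a device.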

Unfortunately, I do not think I have enough GPU memory to fully test SD3 this way, but I can modify the code and see how it works for the other functions.

Teriks commented 3 months ago

With the changes mentioned above, this is working for SD3 for me.

Teriks commented 3 months ago

With this change, model offloading sometimes produces black output when using stabilityai/stable-diffusion-xl-base-1.0, but not with the other SDXL models I am testing. Very peculiar.

It happens particularly with the fp16 model variant and dtype=float16.

Perhaps there is more to do, or this is a bug elsewhere.

xhinker commented 3 months ago

Hi @Teriks, see my comments here: https://github.com/xhinker/sd_embed/pull/3 Thank you!

xhinker commented 1 month ago

Hi @Teriks, I have added pipe.enable_model_cpu_offload() support for Flux1.

Teriks commented 1 month ago

Oh, I see. A device argument is probably needed elsewhere too, wherever a CLIP encoder is involved, since those encoders need to be on a GPU. I will see what happens with CPU generation later.

xhinker commented 1 month ago

I figured out how to manage the VRAM usage, and Flux1 int8/bfloat8 now uses only 14GB of VRAM with sd_embed. You may want to give it a try:

#%%
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only
import torch
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1

# model_path = "black-forest-labs/FLUX.1-schnell"
model_path = "/home/andrewzhu/storage_14t_5/ai_models_all/sd_hf_models/black-forest-labs/FLUX.1-dev_main"

transformer = FluxTransformer2DModel.from_pretrained(
    model_path
    , subfolder = "transformer"
    , torch_dtype = torch.bfloat16
)
quantize_(transformer, int8_weight_only())

pipe = DiffusionPipeline.from_pretrained(
    model_path
    , transformer = transformer
    , torch_dtype = torch.bfloat16
)

pipe.enable_model_cpu_offload()

#%%
prompt = """\
A dreamy, soft-focus photograph capturing a romantic Jane Austen movie scene, 
in the style of Agnes Cecile. Delicate watercolors, misty background, 
Regency-era couple, tender embrace, period clothing, flowing dress, dappled sunlight, 
ethereal glow, gentle expressions, intricate lace, muted pastels, serene countryside, 
timeless romance, poetic atmosphere, wistful mood, look at camera.
"""

prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(
    pipe        = pipe
    , prompt    = prompt
)
image = pipe(
    prompt_embeds               = prompt_embeds
    , pooled_prompt_embeds      = pooled_prompt_embeds
    , width                     = 896
    , height                    = 1280
    , num_inference_steps       = 20
    , guidance_scale            = 4.0
    , generator                 = torch.Generator().manual_seed(1234)
).images[0]
display(image)

Teriks commented 1 month ago

This model is too large for my hardware in this configuration (1080 Ti, 11GB), but that is okay: I have been able to run it using sequential offload, which is good enough for testing. I really need new hardware :)

One thing I noticed about the code: it might be good to have a user-specified fallback device for the text encoders when device=cpu, in case someone has MPS or another backend available.

Torch may automatically map the hard-coded 'cuda' or 'cuda:1' to some other available accelerator for compatibility, but I am not sure.

If CUDA is not available but some other accelerator is, it might fail unnecessarily.

Maybe a module-level global or another parameter would do.
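One way the fallback idea could look, as a minimal sketch rather than anything from sd_embed: the function name pick_text_encoder_device and its fallback parameter are hypothetical. It assumes only the standard torch availability checks (torch.cuda.is_available() and torch.backends.mps.is_available()) and does not settle whether torch remaps a hard-coded 'cuda' string on its own.

```python
def pick_text_encoder_device(fallback=None):
    """Choose a device for the text encoders when the pipeline reports CPU.

    Hypothetical helper, not part of sd_embed. `fallback` is the
    user-specified device; when it is None, probe the accelerator
    backends that torch exposes and fall back to "cpu".
    """
    if fallback is not None:
        return fallback  # an explicit user choice always wins

    import torch  # imported lazily; only needed when probing backends

    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


print(pick_text_encoder_device("mps"))  # mps
```

A parameter keeps the choice local to each call, while a module-level global would let a user set it once for every embedding function. Either would avoid failing when CUDA is absent but another accelerator is present.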