black-forest-labs / flux

Official inference repo for FLUX.1 models
Apache License 2.0

Help with performance improvements (diffusers) #132

Open · JakoLex opened this issue 3 weeks ago

JakoLex commented 3 weeks ago

I am using the diffusers library with FLUX.1-dev and FLUX.1-schnell. I took the following script from here and modified it a bit. Are there any other performance improvements I can get out of my RTX 3090 with 24 GB VRAM and 32 GB system RAM? I commented pipeline.to("cuda") out, as using pipeline.enable_sequential_cpu_offload() alone was much faster.

#%%
from diffusers import FluxPipeline, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import gc
from random import randint
from os import environ

# %%
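# Allow TF32 matmuls (fast on Ampere). The torch._inductor settings only
# take effect if the model is later compiled with torch.compile.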
torch.set_float32_matmul_precision("high")

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch.backends.cuda.matmul.allow_tf32 = True
environ["CUDA_MODULE_LOADING"] = "LAZY"

# %%
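# flush() collects Python garbage, releases cached CUDA memory, and resets
# the peak-memory counters between stages.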
def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

def bytes_to_giga_bytes(num_bytes):  # renamed arg to avoid shadowing the built-in `bytes`
    return num_bytes / 1024 / 1024 / 1024

# %%
flush()

# %%
ckpt_id = "black-forest-labs/FLUX.1-dev"
prompt = "A scenic view of the alps"

# %%
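# Stage 1: load only the CLIP and T5 text encoders (bf16) to compute the prompt embeddings.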
text_encoder = CLIPTextModel.from_pretrained(
    ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)

# %%
tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder="tokenizer")
tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder="tokenizer_2")

# %%
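# Text-encoding-only pipeline: the transformer and VAE are deliberately left unloaded.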
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    transformer=None,
    vae=None,
)

# %% 
# pipeline.to("cuda")

# %%
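# Sequential CPU offload streams submodules to the GPU one at a time,
# trading speed for a much smaller VRAM footprint.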
pipeline.enable_sequential_cpu_offload()

# %%
with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )

# %%
del text_encoder
del text_encoder_2
del tokenizer
del tokenizer_2
del pipeline

flush()

# %%
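# Stage 2: reload the pipeline with only the transformer for denoising; the prompts are already encoded.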
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    torch_dtype=torch.bfloat16,
)

# %% 
# pipeline.to("cuda")

# %%
pipeline.enable_sequential_cpu_offload()

# %%
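# output_type="latent" skips VAE decoding here so the VAE can be loaded on its own afterwards.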
print("Running denoising.")
height, width = 1360, 1360
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=33,
    max_sequence_length=256,
    guidance_scale=3.5,
    height=height,
    width=width,
    output_type="latent",
    generator=torch.Generator().manual_seed(randint(0,1000))
).images

# %%
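# Free the transformer before loading the VAE.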
del pipeline.transformer
del pipeline

flush()

# %%
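# Stage 3: decode with the standalone VAE; slicing and tiling cap the decoder's memory use.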
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to(
    "cuda"
)
vae.enable_slicing()
vae.enable_tiling()
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

# %%
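# Unpack the packed Flux latents and undo the VAE scaling/shift before decoding.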
with torch.no_grad():
    print("Running decoding.")
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")

# %%
del vae

flush()

# %%
image[0].show()

johnwick123f commented 2 weeks ago

@JakoLex The reason it was so slow when you did pipe.to('cuda') is that the model spilled into shared memory (i.e. CPU RAM), which massively slows down inference.
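
For scale: FLUX.1 is a ~12B-parameter model, so the transformer alone is roughly 24 GB in bf16 before you even count the T5 encoder, which is why the whole pipeline can't sit on a 24 GB card at once. A quick way to see how close you are to the limit, using plain torch.cuda calls (nothing Flux-specific):

import torch

# If peak allocation approaches the card's total memory, recent Windows
# drivers fall back to shared system memory instead of raising an OOM,
# which is exactly the slowdown described above.
props = torch.cuda.get_device_properties(0)
total_gib = props.total_memory / 1024 ** 3
peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
print(f"{props.name}: peak {peak_gib:.1f} GiB of {total_gib:.1f} GiB total")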

I would highly recommend doing this with Flux.1 Dev: https://gist.github.com/sayakpaul/e1f28e86d0756d587c0b898c73822c47

This should massively boost inference speed and also use much less VRAM.
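
The usual recipe along those lines is to quantize the transformer and the T5 encoder to 8-bit floats with optimum-quanto. A minimal sketch of that idea (not necessarily identical to what the gist does; it assumes optimum-quanto is installed and that there is enough system RAM to hold the bf16 weights while quantizing):

import torch
from diffusers import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Quantize the two largest components in place, then freeze the weights.
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)

# After quantization the whole pipeline should fit on a 24 GB card,
# so no sequential offload is needed.
pipe.to("cuda")

image = pipe(
    "A scenic view of the alps",
    num_inference_steps=28,
    guidance_scale=3.5,
    height=1024,
    width=1024,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("alps.png")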