huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Flax Stable Diffusion 2.1 error #5224

Closed. entrpn closed this issue 1 year ago.

entrpn commented 1 year ago

Describe the bug

When trying to load Stable Diffusion 2.1 using Flax, I am getting the following error:

Traceback (most recent call last):
  File "/home/jfacevedo/infer.py", line 120, in <module>
    run(opt)
  File "/home/jfacevedo/infer.py", line 30, in run
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
  File "/home/jfacevedo/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_flax_utils.py", line 535, in from_pretrained
    raise ValueError(
ValueError: Pipeline <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> expected {'vae', 'scheduler', 'feature_extractor', 'text_encoder', 'safety_checker', 'tokenizer', 'unet'}, but only {'vae', 'scheduler', 'text_encoder', 'tokenizer', 'unet'} were passed.
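
The minimal call that triggers the error (the same one the full script under Reproduction uses for SD 2.1):

import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    revision="bf16",
    dtype=jnp.bfloat16,
)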

Reproduction

Create a TPU VM and run the following installation:

git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install .
cd ..
pip install transformers flax
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Then run the following script (infer.py) as follows:

python infer.py --sd-version 2 --itters 3

import time
import argparse
import numpy as np
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def run(opt):
    if opt.sd_version == 1:
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16
        )
    else:
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            revision="bf16",
            dtype=jnp.bfloat16
        )

    p_params = replicate(params)
    rng = create_key(0)
    rng = jax.random.split(rng, jax.device_count())
    prompts = ["Labrador in the style of Hokusai"] * opt.batch_size
    print("prompts len:",len(prompts))

    prompt_ids = pipeline.prepare_inputs(prompts)
    prompt_ids = shard(prompt_ids)

    # Default values https://github.com/huggingface/diffusers/blob/v0.14.0/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L275
    num_inference_steps = 50
    height = opt.height 
    width = opt.width 
    guidance_scale = 7.5
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]  # shape: (prompt_ids.shape[0], 1)

    # num_inference_steps, height, and width are static Python scalars, so their positional
    # indices in _generate() are listed in static_broadcasted_argnums; guidance_scale is
    # passed as a sharded array instead.
    p_generate = pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5])

    print("Sharded prompt ids has shape:", prompt_ids.shape)
    print("Guidance shape:",g.shape)

    s = time.time()
    images = p_generate(prompt_ids, p_params, rng, num_inference_steps, height, width, g)
    images = images.block_until_ready()
    print("First inference time is:", time.time() - s)

    iters = opt.itters 
    s = time.time()
    for _ in range(iters):
        images = p_generate(prompt_ids, p_params, rng, num_inference_steps, height, width, g)
        images = images.block_until_ready()
    print("Second inference time is:", (time.time() - s)/iters)
    print("Shape of predictions is: ", images.shape)

    if opt.trace:
        trace_path = "/tmp/tensorboard"
        with jax.profiler.trace(trace_path):
            images = p_generate(prompt_ids, p_params, rng, num_inference_steps, height, width, g)
            images = images.block_until_ready()
            print(f"trace can be found: {trace_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batch-size',
        type=int,
        default=4,
        help='Number of images to generate'
    )
    parser.add_argument(
        '--sd-version',
        type=int,
        default=1,
        help='Use 1 for SD 1.4, Use 2 for SD 2.1'
    )
    parser.add_argument(
        '--width',
        type=int,
        default=512,
        help='Width'
    )
    parser.add_argument(
        '--height',
        type=int,
        default=512,
        help='Height'
    )
    parser.add_argument(
        '--itters',
        type=int,
        default=15,
        help='Number of iterations to run the benchmark.'
    )
    parser.add_argument(
        '--trace',
        action="store_true", 
        default=False, 
        help="Run a trace"
    )

    opt = parser.parse_args()
    run(opt)

Logs

The config attributes {'act_fn': 'silu', 'center_input_sample': False, 'downsample_padding': 1, 'dual_cross_attention': False, 'mid_block_scale_factor': 1, 'norm_eps': 1e-05, 'norm_num_groups': 32, 'num_class_embeds': None} were passed to FlaxUNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Traceback (most recent call last):
  File "/home/jfacevedo/infer.py", line 120, in <module>
    run(opt)
  File "/home/jfacevedo/infer.py", line 30, in run
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
  File "/home/jfacevedo/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_flax_utils.py", line 535, in from_pretrained
    raise ValueError(
ValueError: Pipeline <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> expected {'vae', 'scheduler', 'feature_extractor', 'text_encoder', 'safety_checker', 'tokenizer', 'unet'}, but only {'vae', 'scheduler', 'text_encoder', 'tokenizer', 'unet'} were passed.
I0000 00:00:1695928082.495501    4918 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

System Info

Who can help?

No response

pcuenca commented 1 year ago

Thanks a lot for the report, and sorry for the trouble @entrpn!

Workaround:

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    revision="bf16",
    dtype=jnp.bfloat16,
    safety_checker=None,
    feature_extractor=None,
)
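
(Setting those two to None just means the NSFW safety check is skipped on the generated outputs; the stabilityai/stable-diffusion-2-1 repo doesn't ship a safety checker anyway.)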

The proper solution would be to replicate something like #1395 in pipeline_flax_utils.py, which I suspect would take a little bit of effort to get right before merge. Will iterate on it in the next few days! /cc @patrickvonplaten
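
For reference, a rough sketch of that idea (hypothetical, not the actual patch; the optional_modules set and variable names here are assumptions): before raising, from_pretrained would fill in the components a pipeline can run without as None, roughly like the PyTorch loader does:

# Hypothetical sketch only, not the real pipeline_flax_utils.py code.
# expected/passed mirror the sets from the error message above; which modules
# count as optional is an assumption.
expected_modules = {"vae", "scheduler", "feature_extractor", "text_encoder",
                    "safety_checker", "tokenizer", "unet"}
optional_modules = {"safety_checker", "feature_extractor"}
passed_modules = {"vae", "scheduler", "text_encoder", "tokenizer", "unet"}

missing = expected_modules - passed_modules
if missing - optional_modules:
    raise ValueError(f"Pipeline expected {expected_modules}, but only {passed_modules} were passed.")
# Otherwise, warn and load the missing optional components as None instead of failing.
init_kwargs = {name: None for name in missing}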

entrpn commented 1 year ago

Thank you @pcuenca, the workaround works for me.

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.