keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

[Performance] StableDiffusion on Multi-GPU #1455

Closed: innat closed this issue 1 year ago

innat commented 1 year ago

Issue

I ran the Stable Diffusion model (v1 and v2) on a single GPU and on multiple GPUs to compare execution performance, and ran into a possible issue.

Is this expected?

Also, for the multi-GPU run (2x T4) I'm using the Kaggle environment. I noticed that the memory of both GPUs is full, but only one GPU is actually being used.

Reproduce

Colab gist. Dataset.

cc. @LukeWood @ianstenbit @bhack


innat commented 1 year ago

Okay, I think I need to distribute the dataset (here, the prompts).

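```python
import tensorflow as tf

# df_train (DataFrame with a Prompt column) and strategy (e.g. a MirroredStrategy) are defined in the notebook.
dataset = tf.data.Dataset.from_tensor_slices(df_train.sample(20).Prompt.values)
dataset = dataset.batch(12)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# Split each global batch of prompts across the replicas
dataset = strategy.experimental_distribute_dataset(dataset)
```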
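As a rough illustration of what distributing the dataset does to the batch size, here is a stdlib-only sketch. It assumes 2 replicas (the 2x T4 setup) and that each global batch divides evenly between them; `split_global_batch` is a hypothetical helper, not a TensorFlow function, and TensorFlow's actual sharding order may differ.

```python
def split_global_batch(batch, num_replicas):
    """Split one global batch into equal, contiguous per-replica shards.

    Hypothetical helper for illustration only; it is not tf.distribute code.
    """
    if len(batch) % num_replicas != 0:
        raise ValueError("global batch size must be divisible by num_replicas")
    per_replica = len(batch) // num_replicas
    return [batch[i * per_replica:(i + 1) * per_replica] for i in range(num_replicas)]


prompts = [f"prompt {i}" for i in range(20)]  # stands in for df_train.sample(20).Prompt
global_batch_size = 12

# Like dataset.batch(12): 20 prompts give one batch of 12 and a final batch of 8.
batches = [prompts[i:i + global_batch_size] for i in range(0, len(prompts), global_batch_size)]

for batch in batches:
    shards = split_global_batch(batch, num_replicas=2)  # e.g. 2x T4
    print(len(batch), [len(s) for s in shards])
# → 12 [6, 6]
# → 8 [4, 4]
```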

Follow-up question: does the current Stable Diffusion implementation accept a tf.data dataset as input?

innat commented 1 year ago

@miguelCalado I found the issue you reported here, and I think we're on the same page. Did you find any workaround?

miguelCalado commented 1 year ago

Hi @innat! Yes. I'm not sure this solution is correct, but when I originally implemented the prompt-to-prompt paper I had it working with multiple GPUs (I have since changed it to comply with the keras_cv API).

You can check it in the README.md of my repo.

Solution:

  1. Add a distribution strategy according to your hardware (more details here).
  2. The models need to be loaded under this distribution strategy - Example.
  3. Replace all predict_on_batch calls with predict. I believe the strategies don't work well with predict_on_batch.
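The three steps above can be sketched in TensorFlow as follows. This is a minimal sketch assuming TensorFlow 2.x, not miguelCalado's code: a tiny, hypothetical Dense model stands in for the Stable Diffusion sub-models so the snippet runs without downloading any weights.

```python
import numpy as np
import tensorflow as tf

# Step 1: a distribution strategy for the available hardware. MirroredStrategy
# uses all visible GPUs (e.g. 2x T4) and falls back to the CPU if none is found.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)

# Step 2: create the model inside the strategy scope so its variables are
# mirrored on every replica. This tiny model is a hypothetical stand-in for
# the Stable Diffusion sub-models.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(2),
    ])

# Step 3: run inference with predict(), which distributes the input batches
# across the replicas.
inputs = np.ones((8, 4), dtype="float32")
outputs = model.predict(inputs, batch_size=8, verbose=0)
print(outputs.shape)  # → (8, 2)
```

As innat reports later in the thread, step 3 turned out not to be required; what made the difference was running the first (warm-up) call inside the strategy scope.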

Let me know if this helps you!

innat commented 1 year ago

@miguelCalado Thanks for the hint. I tried what you suggested, but unfortunately didn't see any change. Does it work in your case (ptp-tf)?

miguelCalado commented 1 year ago

Yes, I can say that in my case it worked... Maybe you could share a notebook so I can try to spot any errors?

innat commented 1 year ago

Here is the gist.
I replaced predict_on_batch with predict as you suggested (modified script). You may need to run it on multi-GPU devices (GCP or Kaggle). Sample data.

miguelCalado commented 1 year ago

I think the issue is that the models are only downloaded and loaded the first time you run inference. So doing this, for example:

```python
image_gen_model = kcv_models.StableDiffusionV2(
    512, 512, jit_compile=True
)
```

Does not load or download the model.

The solution is to:

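```python
import tensorflow as tf
from keras_cv import models as kcv_models

strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])

with strategy.scope():
    image_gen_model = kcv_models.StableDiffusionV2(
        512, 512, jit_compile=True
    )
    # Warm-up: the first call loads (and downloads) the models inside the scope.
    # generate_image is a helper defined in the shared notebook.
    _ = generate_image(
        "A" * 77, image_gen_model, image_per_prompt=2
    )
```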

This way you get both GPUs working, as demonstrated in the GIF (don't judge me, I'm not the best screen recorder :D). It's not the prettiest and probably not the best way, but I think it does the job.

I modified your notebook so you can see more examples.

Let me know how it went on the Kaggle setup @innat!

[GIF: screen recording of both GPUs in use]

Env: I used 2x A6000 GPUs from Paperspace with my custom environment.

innat commented 1 year ago

@miguelCalado Thanks. Your solution does work. Also, replacing predict_on_batch with predict is not required. The only thing I needed to adopt was calling the generate_image warm-up above within the strategy scope (still a bit puzzling to me, as I had already created the model). Anyway, thanks again. 😄

innat commented 1 year ago

Hi @miguelCalado

The approach you suggested above does put both GPUs to work, but with it the execution time changes noticeably, whether the program runs on one device or on both. Here is one demonstration.

With the warm-up call inside the scope:

```python
import tensorflow as tf
from tensorflow import keras
from keras_cv import models as kcv_models

# Tesla-T4 (2X)
strategy = tf.distribute.MirroredStrategy(["GPU:0"])  # or ["GPU:0", "GPU:1"]
keras.mixed_precision.set_global_policy("mixed_float16")
batch_size = 2

def stable_diffusion_model(input_shape):
    with strategy.scope():
        model = kcv_models.StableDiffusionV2(
            *input_shape, jit_compile=True
        )

        # Warm-up; generate_image is a helper defined in the notebook
        _ = generate_image(
            "A" * 77, model,
            image_per_prompt=strategy.num_replicas_in_sync,
            ug_scale=9
        )
    return model
```

```
50/50 [==============================] - 27s 543ms/step
50/50 [==============================] - 27s 534ms/step
50/50 [==============================] - 27s 531ms/step
50/50 [==============================] - 27s 531ms/step
50/50 [==============================] - 27s 534ms/step
```

Without the warm-up call:

```python
def stable_diffusion_model(input_shape):
    with strategy.scope():
        model = kcv_models.StableDiffusionV2(
            *input_shape, jit_compile=True
        )
    return model
```

```
50/50 [==============================] - 19s 379ms/step
50/50 [==============================] - 19s 378ms/step
50/50 [==============================] - 19s 385ms/step
50/50 [==============================] - 19s 383ms/step
50/50 [==============================] - 19s 381ms/step
```

The rest of the code is the same as above.

So when I warm up the model (and download its weights) inside with strategy.scope(), the execution time (for batch_size=2) increases from 19s to 27s. WDYT?
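The two logs are consistent with these totals: 50 steps at the logged per-step times. A quick check using the per-step values copied from the logs above (the averages and the ratio are computed here, not reported in the thread):

```python
# Per-step times (ms) copied from the two progress-bar logs above.
with_warmup_ms = [543, 534, 531, 531, 534]     # warm-up inside strategy.scope()
without_warmup_ms = [379, 378, 385, 383, 381]  # no warm-up call
num_steps = 50  # every run shows 50/50 steps


def mean(values):
    return sum(values) / len(values)


total_with_s = num_steps * mean(with_warmup_ms) / 1000
total_without_s = num_steps * mean(without_warmup_ms) / 1000
slowdown = mean(with_warmup_ms) / mean(without_warmup_ms)

print(f"{total_with_s:.1f} s vs {total_without_s:.1f} s ({slowdown:.2f}x per step)")
# → 26.7 s vs 19.1 s (1.40x per step)
```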

@LukeWood Also, as I said, warming up the model inside the strategy scope looks like an anti-pattern compared to a usual Keras model, doesn't it? It's like calling model(tf.ones(shape=..)) inside the scope, similar to calling model.text_to_image. On that note, what would be the way to run the official code example on multiple devices?

innat commented 1 year ago

@miguelCalado Apart from the above, here is another observation about runtime with different batch sizes. For example, with the approach you suggested, setting batch_size=10 takes around 54 seconds in total (using 2 GPUs), i.e. 5.4 seconds per generated image. With batch_size=30, it takes 127 seconds (4.2 seconds per image).
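Per-image time is simply total time divided by batch size; a quick check with the figures reported above:

```python
# (batch_size, total seconds) reported above for the 2-GPU runs.
runs = [(10, 54), (30, 127)]

per_image_s = {batch_size: total / batch_size for batch_size, total in runs}
for batch_size, seconds in per_image_s.items():
    print(f"batch_size={batch_size}: {seconds:.1f} s per image")
# → batch_size=10: 5.4 s per image
# → batch_size=30: 4.2 s per image
```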