saharmor / dalle-playground

A playground to generate images from any text prompt using Stable Diffusion (past: using DALL-E Mini)
MIT License

[OOM] Does the maker know what those 'partial(jax.pmap...' actually do? #28

Open DavidHiggis opened 2 years ago

DavidHiggis commented 2 years ago

Or just mindlessly copypasta?

I tried to solve the OOM problem and found that those replicate(...) calls (flax.jax_utils.replicate) are the main cause. That replicate call literally doubled the RAM usage and made mega-fp16 crash on a 12 GB Colab VM.
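A minimal sketch of the memory arithmetic behind "doubled". It assumes replicate behaves like stacking one copy of every parameter array per local device; replicate_like below is a hypothetical stand-in, not the flax function, and the parameter dict is made up. While the original params are still referenced, the replicated tree adds n more copies, so on a single device the footprint is twice the model size.

```python
import jax
import jax.numpy as jnp
import numpy as np

def replicate_like(tree, n_devices):
    # Hypothetical stand-in for flax.jax_utils.replicate: one stacked copy per device
    return jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_devices), tree)

def tree_nbytes(tree):
    # Total bytes over every array leaf in the pytree
    return sum(leaf.nbytes for leaf in jax.tree_util.tree_leaves(tree))

# Hypothetical fp16 "model parameters" (about 2 MB)
params = {
    "w": np.zeros((1024, 1024), dtype=np.float16),
    "b": np.zeros((1024,), dtype=np.float16),
}
n = jax.local_device_count()  # 1 on a single-GPU Colab VM
replicated = replicate_like(params, n)

print(tree_nbytes(params))                            # original copy
print(tree_nbytes(replicated))                        # n stacked copies
print(tree_nbytes(params) + tree_nbytes(replicated))  # both alive at once
```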

So far, this is what I know:

  1. To remove 'replicate', the function decorators such as @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(0)) and @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6, 7)) must also be removed. Here's a non-replicated example of p_decode:
```python
import numpy as np
from PIL import Image

from vqgan_jax.modeling_flax_vqgan import VQModel

# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)


# Without @partial(jax.pmap, ...), 'indices' must have shape (?, 256), not (1, ?, 256)
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)


# encoded_images: image-token sequences from DalleBart.generate() (BOS already removed),
# reshaped below to one row of 256 tokens per image
def generate_images_from_ndarray(encoded_images):
    encoded_images = encoded_images.reshape((-1, 256))
    images = []
    # decode images
    decoded_images = p_decode(encoded_images, vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for img in decoded_images:
        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))
    return images
```
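The "(?, 256), not (1, ?, 256)" rule can be illustrated with a toy function. This is a sketch only: decode_stub, the codebook and the token array are hypothetical stand-ins for vqgan.decode_code and its params. By default jax.pmap splits every argument along a leading device axis, so the tokens and the params must both be stacked first; jax.jit compiles the same function for one device, with no device axis and no stacked copy of the params.

```python
import jax
import jax.numpy as jnp

def decode_stub(indices, codebook):
    # Hypothetical stand-in for vqgan.decode_code: one codebook vector per token
    return jnp.take(codebook, indices, axis=0)

codebook = jnp.arange(16384 * 4, dtype=jnp.float32).reshape(16384, 4)  # hypothetical params
indices = jnp.zeros((2, 256), dtype=jnp.int32)  # 2 images x 256 tokens: shape (?, 256)

n = jax.local_device_count()

# pmap: inputs need a leading device axis, i.e. shape (n, ?, 256),
# and the params must be stacked (replicated) as well
p_decode = jax.pmap(decode_stub)
out_p = p_decode(jnp.stack([indices] * n), jnp.stack([codebook] * n))

# jit: same function, plain (?, 256) input, params used as-is
j_decode = jax.jit(decode_stub)
out_j = j_decode(indices, codebook)

print(out_p.shape)  # leading axis of size n
print(out_j.shape)  # (2, 256, 4)
```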
  2. The code above works. Following the same rule, (?, 256) rather than (1, ?, 256), I removed the decorator above p_generate and changed tokenize_prompt(self, prompt: str) to:

    ```python
    def tokenize_prompt(self, prompt: str):
        tokenized_prompt = self.processor([prompt])
        # replicate adds a leading device axis; [0] keeps only the first device's copy
        zet = replicate(tokenized_prompt)
        for ky in zet:
            zet[ky] = zet[ky][0]
        return zet
    ```

    and generate_images to:

    ```python
    def generate_images(self, prompt: str, num_predictions: int):
        tokenized_prompt = self.tokenize_prompt(prompt)

        # create a random key
        seed = random.randint(0, 2 ** 32 - 1)
        key = jax.random.PRNGKey(seed)

        # generate images
        images = []
        # get a new key
        key, subkey = jax.random.split(key)

        encoded_images = p_generate(
            tokenized_prompt,
            subkey,
            self.params,
            GEN_TOP_K,
            GEN_TOP_P,
            TEMPERATURE,
            COND_SCALE,
            self.model
        )

        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]

        # decode images
        decoded_images = p_decode(self.vqgan, encoded_images[0], self.vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        for img in decoded_images:
            images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))

        return images
    ```

    But no luck so far.
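On a single device, the replicate-then-[0] pattern in tokenize_prompt above is a round trip: it adds a device axis of size 1 and strips it again, so the processor output could in principle be passed through unchanged. A sketch under that assumption; replicate_like is a hypothetical stand-in for flax.jax_utils.replicate, and the token dict (one prompt, 64 positions) is made up.

```python
import jax
import jax.numpy as jnp
import numpy as np

def replicate_like(tree, n_devices):
    # Hypothetical stand-in for flax.jax_utils.replicate: one stacked copy per device
    return jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_devices), tree)

def first_copy(tree):
    # Same idea as the `zet[ky] = zet[ky][0]` loop (and flax.jax_utils.unreplicate)
    return jax.tree_util.tree_map(lambda x: x[0], tree)

# Hypothetical processor output: batch of 1 prompt, 64 token positions
tokenized_prompt = {
    "input_ids": np.arange(64, dtype=np.int32).reshape(1, 64),
    "attention_mask": np.ones((1, 64), dtype=np.int32),
}

round_trip = first_copy(replicate_like(tokenized_prompt, 1))
for key in tokenized_prompt:
    # Single device: shapes and values come back unchanged
    print(key, round_trip[key].shape)
```

The only lasting effect of the pair is the extra copy made along the way, which is exactly the memory the thread is trying to save.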

If the maker of this actually understands his code (I guess not; otherwise the line 'for i in range(max(num_predictions // jax.device_count(), 1))' would have been optimized away, given the 'dalle_model.generate_images("warm-up", 1)' call in app.py), please tell us how to remove this pmap/batch machinery, which is made for multi-device VMs and NOT for a standalone VM environment like the free-tier Google Colab GPU.
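The loop bound quoted above can be worked through in plain Python. This sketch only reproduces the arithmetic and does not call JAX; it assumes each pmapped generate call yields one image per device, and both helper names are hypothetical.

```python
def loop_iterations(num_predictions: int, device_count: int) -> int:
    # Loop bound quoted in the issue: max(num_predictions // jax.device_count(), 1)
    return max(num_predictions // device_count, 1)

def images_produced(num_predictions: int, device_count: int) -> int:
    # Assumes one image per device per pmapped generate call
    return loop_iterations(num_predictions, device_count) * device_count

# Warm-up call generate_images("warm-up", 1) on a single-GPU VM
print(loop_iterations(1, 1), images_produced(1, 1))  # 1 1
# The same call on an 8-core TPU VM
print(loop_iterations(1, 8), images_produced(1, 8))  # 1 8
# 9 predictions on one GPU: nine sequential generate calls
print(loop_iterations(9, 1), images_produced(9, 1))  # 9 9
```

On one device the loop simply runs num_predictions times, so the warm-up call costs exactly one generate pass.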

YukiSakuma commented 2 years ago

I'm trying to solve it too, but on Kaggle (https://github.com/borisdayma/dalle-mini/issues/204). The replicate function is the culprit: I tried loading the Mega fp16 model first and then calling replicate (before loading the VQGAN model), but it still OOMs.

ghost commented 2 years ago

If this code works, you could put it in a pull request and have it merged. I, for one, would love less VRAM usage, especially since my card is a gigabyte short of the twelve required to run mega_full.