borisdayma / dalle-mini

DALL·E Mini - Generate images from a text prompt
https://www.craiyon.com
Apache License 2.0

multiprocessing #197

Open melih-unsal opened 2 years ago

melih-unsal commented 2 years ago

Hello, in the notebook the images are generated with the for loop below.

```python
for i in range(num_predictions // jax.device_count()):
    # get a new key
    key, subkey = jax.random.split(key)

    # generate images
    encoded_images = p_generate(
        tokenized_prompt, shard_prng_key(subkey),
        model.params, gen_top_k, gen_top_p, temperature, cond_scale,
    )

    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]

    # decode images
    decoded_images = p_decode(encoded_images, vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for img in decoded_images:
        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))
```

Is it possible to pass multiple subkeys to get all the images in a single run? Another possibility would be using torch.multiprocessing, but every time I tried it, it gave the error below and shut down:

UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown

drdaxxy commented 2 years ago

You don't need to touch the PRNG keys; one per device/batch is enough. For example, this processes two prompts in parallel on one device:

```python
tokenized_prompt = processor(["avocado chair", "the Eiffel tower landing on the moon"])
tokenized_prompt = replicate(tokenized_prompt)
```

I haven't dug into the code so I don't know if this has the same amount of entropy as submitting multiple batches, but if you make the prompts the same, the results certainly don't look any less randomized.
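To make the point about randomness a bit more concrete, here is a minimal toy sketch in plain JAX. It is not dalle-mini code: the `logits` values, `batch_size`, and the `sample_row` helper are made up for illustration, and `jax.random.categorical` only stands in for the sampling step inside generation. It shows that a single key already gives every row of a batch its own random draw, and how one subkey per row could be used instead if that were ever needed.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a model's next-token logits; the values are made up.
logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))

# Repeat the same "prompt" along a leading batch axis, similar in spirit
# to passing the same string several times to the processor.
batch_size = 8
batch = jnp.tile(logits, (batch_size, 1))  # shape (8, 4)

key = jax.random.PRNGKey(0)

# Option 1: one key for the whole batch.
# jax.random.categorical draws independent noise for every row.
samples_one_key = jax.random.categorical(key, batch, axis=-1)  # shape (8,)

# Option 2: one subkey per row, mapped over the batch with vmap.
subkeys = jax.random.split(key, batch_size)  # 8 subkeys


def sample_row(k, row):
    # Hypothetical helper: sample one category for a single row.
    return jax.random.categorical(k, row)


samples_per_key = jax.vmap(sample_row)(subkeys, batch)  # shape (8,)

print(samples_one_key)
print(samples_per_key)
```

Both options return one sample per row. Assuming generation samples tokens in a similar way, repeating the prompt in the batch should be enough, which would match the observation above that the results do not look less randomized.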