melih-unsal opened this issue 2 years ago
You don't need to touch the PRNG keys, one per device/batch is enough. For example, this processes two prompts in parallel on one device:
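```python
from flax.jax_utils import replicate
# tokenize both prompts as one batch, then copy that batch to every local device
tokenized_prompt = processor(["avocado chair", "the Eiffel tower landing on the moon"])
tokenized_prompt = replicate(tokenized_prompt)
```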
I haven't dug into the code so I don't know if this has the same amount of entropy as submitting multiple batches, but if you make the prompts the same, the results certainly don't look any less randomized.
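A toy sketch of why one key per batch can be enough. It assumes the model samples tokens with something like `jax.random.categorical`; the real dalle-mini decoding is more involved, and `sample_batch` and the all-zero logits are hypothetical stand-ins. A single key drives the whole batch, yet two rows with identical logits (i.e. identical prompts) still get independent draws.

```python
import jax
import jax.numpy as jnp

def sample_batch(key, logits):
    # Hypothetical stand-in for a sampling step: logits has shape
    # (batch, seq_len, vocab). One key covers the whole batch; the noise is
    # drawn with the full batch shape, so each row gets its own sample.
    return jax.random.categorical(key, logits, axis=-1)

key = jax.random.PRNGKey(0)
# Two identical "prompts": both rows have exactly the same logits.
logits = jnp.zeros((2, 16, 1000))
tokens = sample_batch(key, logits)
print(tokens)  # two rows of 16 token ids each, which differ from each other
```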
Hello, in the notebook the images are generated with the following for loop:
```python
for i in range(num_predictions // jax.device_count()):
    # get a new key
    key, subkey = jax.random.split(key)
    ...  # generate images with subkey
```
Is it possible to pass multiple subkeys so that all the images are generated in a single run? Another option would be `torch.multiprocessing`, but every time I tried it, it failed with this warning and shut down:

```
UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
```
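On the single-run question, a minimal sketch under the assumption that generation can be written as a pure function of a key: `jax.random.split(key, n)` returns `n` subkeys at once, and `jax.vmap` maps the generation function over them in one call, with no process pool involved. `generate_one`, the logits, and `num_predictions = 4` are hypothetical placeholders for the notebook's model call and settings.

```python
import jax
import jax.numpy as jnp

def generate_one(key, logits):
    # Hypothetical stand-in for one generation call: draws one token
    # sequence of length logits.shape[0] using the given key.
    return jax.random.categorical(key, logits, axis=-1)

num_predictions = 4
key = jax.random.PRNGKey(42)
# One split produces all the subkeys at once instead of one per loop iteration.
subkeys = jax.random.split(key, num_predictions)
# Same "prompt" for every prediction: 8 positions over a 256-token vocabulary.
logits = jnp.zeros((8, 256))
# Map over the key axis only; the logits are shared (in_axes=None).
samples = jax.vmap(generate_one, in_axes=(0, None))(subkeys, logits)
print(samples.shape)  # (4, 8)
```

The subkeys differ from the ones the loop's repeated `split` would produce, so the outputs won't match the loop image for image, but they are equally random. Running every prediction in one batch will likely need proportionally more device memory than the loop does.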