basujindal / stable-diffusion

Optimized Stable Diffusion modified to run on lower GPU VRAM

Multiple prompts executed at once as a batch #144

Open · lazy-nurd opened this issue 2 years ago

lazy-nurd commented 2 years ago

Thank you so much for your wonderful work. Is there a way to batch multiple prompts and run inference on them at once (assuming one image per prompt and the same output size for every prompt)? If so, it would let us serve multiple user requests to our model as a single batch. Thanks!

bitRAKE commented 1 year ago

https://github.com/basujindal/stable-diffusion/blob/e2aef31f3e1a9b9297a0fc9fbb02a90308c98699/optimizedSD/optimized_txt2img.py#L242

If you change the line linked above, removing the sorting:

data = list(chunk(data, batch_size))

... you'll get the desired effect: gather your prompts in a file, one per line, and the batches will produce one image per line, in order. 💯
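
For context, a before/after sketch; the "before" lines are reconstructed from how this thread later describes the original code (the prompt list multiplied by batch_size and sorted), so the repository's exact line may differ:

# before (reconstructed): after duplicating and sorting, each batch is
# one prompt repeated batch_size times
data = batch_size * list(data)
data = list(chunk(sorted(data), batch_size))

# after: one entry per line of the prompt file, different prompts per batch
data = list(chunk(data, batch_size))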

bitRAKE commented 1 year ago

Well, I almost had it. That'll go through the prompts opt.n_samples times!

Change the else: branch to the following:

else:
    print(f"reading prompts from {opt.from_file}")
    with open(opt.from_file, "r") as f:
        data = f.read().splitlines()            # one prompt per line
        data = list(chunk(data, batch_size))    # batches of different prompts

... it'll wrap the prompt list around to complete the final batch if the prompt count isn't a multiple of batch_size. You could also multiply the list by opt.n_iter and then chunk() for smoother results, but then set opt.n_iter = 1 afterward:

    print(f"reading prompts from {opt.from_file}")
    with open(opt.from_file, "r") as f:
        data = f.read().splitlines()
        data = opt.n_iter * list(data)
        opt.n_iter = 1
        data = list(chunk(list(data), batch_size))
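
To sanity-check what that multiplication produces, here is a small sketch with illustrative values; chunk() is copied here as the script defines it, so the snippet runs standalone:

from itertools import islice

def chunk(it, size):  # the chunk() helper as defined in the script
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

prompts = ["p1", "p2", "p3", "p4", "p5"]  # 5 illustrative prompts
data = 2 * prompts                        # as if opt.n_iter == 2 -> 10 entries
print(list(chunk(data, 2)))               # batch_size == 2
# [('p1', 'p2'), ('p3', 'p4'), ('p5', 'p1'), ('p2', 'p3'), ('p4', 'p5')]

Note how the prompts flow across batch boundaries, so every batch stays full.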

There are just a lot of possible usage scenarios.

Perhaps there is enough here to get you exactly what you need.

bitRAKE commented 1 year ago

chunk() doesn't work the way I thought it did. So, I wrote one that does: 😄

import numpy as np

def list_chunk(L, N):
    # irregular chunk with complete coverage: yields the items of L in order,
    # padding past the end with random picks so the total length is a
    # multiple of N (random overflow)
    for x in range((len(L) + N - 1) // N):  # number of batches, rounded up
        for y in range(N):
            k = x * N + y
            if k >= len(L):                  # past the end of the list
                k = np.random.randint(len(L))
            yield L[k]
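
A quick illustration of how it behaves (the prompt names are placeholders; chunk() is the script's helper, shown in the sketch earlier in the thread):

prompts = ["a", "b", "c", "d", "e"]
padded = list(list_chunk(prompts, 3))   # e.g. ['a', 'b', 'c', 'd', 'e', 'b']
batches = list(chunk(padded, 3))        # e.g. [('a', 'b', 'c'), ('d', 'e', 'b')]

Because list_chunk yields individual items (padding the tail with random picks), the result still goes through chunk() to form the batches, and every batch comes out full.
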
lazy-nurd commented 1 year ago

Does that work? It looks like the model still executes each prompt separately: https://github.com/basujindal/stable-diffusion/blob/e2aef31f3e1a9b9297a0fc9fbb02a90308c98699/optimizedSD/optimized_txt2img.py#L255

What I want is to execute the model once with, say, 4 prompts at the same time. Thanks.

bitRAKE commented 1 year ago

data is a list, but a list of what?

Batching is possible because data is a list of iterables, each holding batch_size prompts. Currently, by multiplying the prompt list by batch_size and sorting, we are left with batches that each contain the same prompt repeated batch_size times.

Without the sorting, each batch is composed of batch_size different prompts. Try it and you will see.

Here is a little example to try:

batch = 3
xlist = [1, 2, 3, 4, 5]
xlist = xlist * batch
v = list(chunk(xlist, batch))  # chunk() as in the script (see the sketch above)
for w in v:
    for x in w:
        print(x, end='')
    print()

123
451
234
512
345

Each line is analogous to a batch, and each number appears in order. If sorted (the current behavior), each line would instead be the same number repeated.
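
For contrast, the sorted case (the script's current behavior), continuing the example above:

v = list(chunk(sorted(xlist), batch))
for w in v:
    for x in w:
        print(x, end='')
    print()

111
222
333
444
555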