borisdayma / dalle-mini

DALL·E Mini - Generate images from a text prompt
https://www.craiyon.com
Apache License 2.0

Parallelizing "super conditioning" sampling speeds up inference by 20-40% #247

drdaxxy opened 2 years ago

drdaxxy commented 2 years ago

The current implementation of DalleBart._sample() performs "super conditioning" by running DalleBart.decode() twice in a row:

https://github.com/borisdayma/dalle-mini/blob/a72705f46ce29a45d1c56c40b39e12476bfa6584/src/dalle_mini/model/modeling.py#L1838-L1856

Allowing JAX to schedule these two calls in parallel makes inference much faster, especially (but not only) for small batches: 20%-40% lower seconds per image on my hardware, with no downside I can see (a couple of MB more memory used on small batches, a couple hundred MB less on larger ones, and the output should be unaffected).

I hacked this in by vmap()-ing a few functions to stack all the _uncond-suffixed data on top of the input prompt equivalents.
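
Roughly, the idea looks like this (a self-contained toy sketch of the pattern; `decode_step`, `encoder_cond`, `encoder_uncond` etc. are stand-ins I made up, not the actual dalle-mini code):

```python
# Toy sketch, not dalle-mini's API: decode_step stands in for DalleBart.decode().
import jax
import jax.numpy as jnp

vocab, hidden, enc_len, dec_len = 16384, 64, 32, 8

def decode_step(encoder_state, decoder_ids, params):
    # Stand-in "decoder": embed tokens, mix in the encoder state, project to logits.
    h = jnp.take(params["embed"], decoder_ids, axis=0) + encoder_state.mean(axis=0)
    return h @ params["proj"]  # (dec_len, vocab)

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
params = {
    "embed": jax.random.normal(k1, (vocab, hidden)),
    "proj": jax.random.normal(k2, (hidden, vocab)),
}
encoder_cond = jax.random.normal(k3, (enc_len, hidden))    # encoding of the prompt
encoder_uncond = jax.random.normal(k4, (enc_len, hidden))  # encoding of ""
decoder_ids = jnp.zeros((dec_len,), dtype=jnp.int32)

# Current (serial) pattern: two decode calls, one after the other.
logits_cond = decode_step(encoder_cond, decoder_ids, params)[-1]
logits_uncond = decode_step(encoder_uncond, decoder_ids, params)[-1]

# Parallel pattern: stack the _uncond data on top of the prompt data and vmap a
# single call over the new leading axis, so XLA can schedule both passes together.
stacked = jnp.stack([encoder_cond, encoder_uncond])
both = jax.vmap(decode_step, in_axes=(0, None, None))(stacked, decoder_ids, params)
logits_cond_p, logits_uncond_p = both[0, -1], both[1, -1]
```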

You're probably quicker at cleanly integrating this simple change than I am, hence no pull request, but here's an inference notebook demonstrating it[^1], no edits required, with more info and benchmark results.

[^1]: Vectorization, as an addendum, achieves avocado armchair assembly accuracy amplification far beyond 40% over the "avocado armchair" baseline, also proven in this notebook 😁

borisdayma commented 2 years ago

That's cool! What happens if you max out your GPU memory? I imagined that 2 images with no super conditioning (scale of 1) cost the same as 1 image with super conditioning (scale > 1). Maybe even more at large scale, but I could be wrong.

Very impressive notebook btw 🙂

drdaxxy commented 2 years ago

First off: I found that making many predictions for few prompts, vectorized across PRNG keys, is much faster: I reach 1 s/image in float32 with 30 predictions for one prompt (but e.g. 10 keys and 3 prompts is close enough). I bet "pmap over few prompts, vmap over many keys" is faster on TPUs too, if you can make that work. This only helped after I replaced the lax.while_loop with fori_loop/scan (valid since our sequence length is fixed).
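
For the keys part, here's a minimal sketch of that pattern under the same caveat (a toy decoder stands in for the real model; `sample_one` and the shapes are only illustrative):

```python
import jax
import jax.numpy as jnp

seq_len, vocab = 256, 16384

def sample_one(key):
    # Fixed-length autoregressive loop via lax.scan instead of lax.while_loop;
    # valid here because the number of generated image tokens is constant.
    def step(carry, _):
        key, prev_tok = carry
        key, subkey = jax.random.split(key)
        logits = jnp.zeros((vocab,)).at[prev_tok].add(1.0)  # stand-in for decode()
        tok = jax.random.categorical(subkey, logits).astype(jnp.int32)
        return (key, tok), tok
    _, tokens = jax.lax.scan(step, (key, jnp.int32(0)), None, length=seq_len)
    return tokens

# Many predictions for one prompt: vmap the whole sampler across PRNG keys
# (on TPUs this could sit under a pmap over the few prompts).
n_predictions = 8
keys = jax.random.split(jax.random.PRNGKey(0), n_predictions)
image_tokens = jax.jit(jax.vmap(sample_one))(keys)  # (n_predictions, seq_len)
```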

Anyway, about the original subject, without this optimization:


> That's cool! What happens if you max out your GPU memory?

Tl;dr: On throughput-optimized batches on my system, p_parallel_cond_generate is 10%-20% faster.

In float16, p_generate grows by 216 MiB/item and p_parallel_cond_generate by 196 MiB/item. With batch size 80, p_generate takes 0.812 s/item in 23,839 MiB while p_parallel_cond_generate takes 0.731 s/item in 22,246 MiB, a 10% speedup.

In float32, p_generate grows by 432 MiB/item and p_parallel_cond_generate by 393 MiB/item. With batch size 28, p_generate takes 1.732 s/item in 23,591 MiB while p_parallel_cond_generate takes 1.449 s/item in 22,564 MiB, a 16% speedup. With batch size 30, p_generate doesn't fit[^1], but p_parallel_cond_generate takes 1.398 s/item in 23,352 MiB, a 3.5% speedup vs. parallel n=28 and a 19% speedup vs. serial n=28.

[^1]: I'm also running a Windows desktop (idle, but still), so free VRAM varies. In float32 I think another item helps when there's enough room. But if space is slightly too tight, the models just run slowly in whatever memory they can get rather than crashing.

(Memory per item might depend on hardware; I'm not sure. I also tried skipping the VQGAN decode; it saved some space but little time, so it didn't feel worth it in this setup.)


> I imagined that 2 images with no super conditioning (scale of 1) cost the same as 1 image with super conditioning (scale > 1). Maybe even more at large scale, but I could be wrong.

Yes, p_generate takes exactly double the time and space[^2] for me with scale > 1 vs. scale = 1, but that's more than needed 🙂 Some decode() inputs are the same, and we only need a weighted sum of the last-step logits. I haven't checked the computation graph, but maybe p_parallel_cond_generate pushes that into the model.

[^2]: ...except, if I use float16 and large batches and VQGAN, MiB/item peaks at 128 during warmup for some reason, so I can't do n=160, but I think my bottleneck is compute by then anyway.
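
(For reference, the weighted sum I mean is just the usual super-conditioning combination of the two last-step logit vectors, something like the following; variable names are mine, not the repo's:)

```python
import jax.numpy as jnp

def super_condition(logits_cond, logits_uncond, cond_scale):
    # Weighted sum of the conditional and unconditional last-step logits;
    # cond_scale = 1 reduces to the purely conditional logits.
    return logits_uncond + cond_scale * (logits_cond - logits_uncond)

# e.g. super_condition(jnp.array([2.0, 0.5]), jnp.array([1.0, 1.0]), 10.0)
```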

For the record: In float16, without super conditioning, n=128 maximizes my throughput at 0.367 s/item in 23,014 MiB during warmup and 18,944 MiB steady-state. In float32 it's 0.652 s/item at n=60 in 23,226 MiB.

> Very impressive notebook btw 🙂

Thanks! I might touch it up a little soon if I can find the time.