facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.
Other
931 stars 67 forks

Stable Diffusion #4

Closed lalalune closed 1 year ago

lalalune commented 1 year ago

This work looks incredible. Is there an ETA for releasing the SD patch? Would we be reinventing the wheel by attempting to add the patch ourselves, and would you accept a PR, or is there already a plan for that?

dbolya commented 1 year ago

I haven't had time to finalize the Stable Diffusion experiments, mostly because I've been busy and there are a lot of variables at play (where and how to apply ToMe to get the best results). I plan to have some experiments done in the next two weeks, after which I'll release a patch.

But in the meantime, if you or anyone else wants to play around with ToMe + Stable Diffusion and open a PR for it, that would be great. It's looking like the patch will have multiple options anyway, depending on what you want stylistically / computationally. So, if you get good results with ToMe, I'd be happy to merge patches.

Birch-san commented 1 year ago

I had a stab at weaving ToMe into CompVis stable-diffusion:
https://github.com/Birch-san/stable-diffusion/compare/ca4c1d96f252845634f9df6bc34eedf3c4449b1c...bb7ae37f0c0ab56279762cb285ba7e67fe41bd01

any idea what I did wrong?

unm_idx's shape is (4, 1). I cannot expand it to (2, 2044, 320) (i.e. [n, t1 - r, c]):

so either t1 is wrong (implying x is wrong), or unm_idx is wrong (implying metrics is wrong).

x is the model output. I assume I can consider that known-good.

so, did I compute metrics wrong? i.e. the mean of the key for cross-attention.

k begins its life with shape (2, 4096, 320),
gets rearranged to (16, 4096, 40),
then k.mean(1) has shape (16, 40).

====

I get a little further if I instead use k.mean(0), whose shape is (4096, 40):
https://github.com/Birch-san/stable-diffusion/commit/53e15dd9006f08823cdab0d277b370adcf5f60f1

the self-attention layer completes.
next, I get a similar problem on the cross-attention layer:

The expanded size of the tensor (2042) must match the existing size (35) at non-singleton dimension 1.  Target sizes: [2, 2042, 320].  Tensor sizes: [35, 1]

during cross-attention: k begins its life with shape (2, 77, 320).
gets rearranged to (16, 77, 40).
k.mean(0) has shape (77, 40)

this tiny k.mean is why unm_idx's dim1 failed to match the model output's dim1.
the approach worked for self-attention, but maybe that was just a fluke.

any idea what I'm doing wrong?
would be great to get this working to speed up attention on Mac — we haven't had access to Flash Attention or CUDA-optimized attentions.

dbolya commented 1 year ago

You want to take the mean over attn heads. So when k starts out as (2, 77, 320), you want to extract the attn heads: (2, 77, h, 320 // h) and then take the mean over the h dimension: k.mean(-2).

Alternatively you can forgo the mean entirely and just use the (2, 77, 320) tensor. Slower, but could be more accurate.
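A minimal sketch of that reshape and mean, assuming PyTorch and a key laid out as `[batch, tokens, heads * dim_head]` with the head index as the outer factor of the channel axis (the `(h d)` layout the CompVis attention uses when it rearranges heads). `head_mean_metric` is a hypothetical helper name, not part of ToMe:

```python
import torch


def head_mean_metric(k: torch.Tensor, heads: int) -> torch.Tensor:
    """Average a key tensor over attention heads to get a ToMe similarity metric.

    k: [batch, tokens, heads * dim_head]  ->  returns [batch, tokens, dim_head]
    """
    b, t, c = k.shape
    assert c % heads == 0, "channel dim must split evenly into heads"
    # Split channels into (heads, dim_head), then average over the heads axis.
    return k.reshape(b, t, heads, c // heads).mean(dim=-2)


# Self-attention key at 512x512 (64x64 latent = 4096 tokens), 8 heads of 40 dims.
k_self = torch.randn(2, 4096, 320)
print(head_mean_metric(k_self, heads=8).shape)  # torch.Size([2, 4096, 40])
```

Because the token axis survives, the metric keeps one row per token that may be merged. Rearranging heads into the batch first and then averaging over dim 1 averages over tokens instead, which is where the (16, 40) shape came from.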

Birch-san commented 1 year ago

thanks! I've changed it as such:
https://github.com/Birch-san/stable-diffusion/commit/42324b8ab685512292b4842fdacdc305ca1a242c

i.e. I reshape the key, from
[batch_size, token_count, dim_head * heads] to:
[batch_size, token_count, heads, dim_head]

then I compute k.mean(-2), giving metric with shape
[batch_size, token_count, dim_head]

for self-attention, this turns the key [2, 4096, 320] into:
[2, 4096, 40]
for cross-attention, this turns the key [2, 77, 320] into:
[2, 77, 40].

as for forgoing the mean… sure, will look into that once this is working (but speed is better for Mac users!)

====

hm, this hasn't really changed the situation sadly.
I come out of cross-attention with metrics of shape [2, 77, 40]. I enter bipartite_soft_matching, and it makes an unm_idx which is similarly small, at [2, 35, 1]:

with torch.no_grad():
    metric = metric / metric.norm(dim=-1, keepdim=True)
    a, b = metric[..., ::2, :], metric[..., 1::2, :]
    scores = a @ b.transpose(-1, -2)
metric.shape
torch.Size([2, 77, 40])
a.shape
torch.Size([2, 39, 40])
b.shape
torch.Size([2, 38, 40])
scores.shape
torch.Size([2, 39, 38])
edge_idx.shape
torch.Size([2, 39, 1])
unm_idx.shape
torch.Size([2, 35, 1])

then, merge_wavg goes bang in the same place as before:

def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
    src, dst = x[..., ::2, :], x[..., 1::2, :]
    n, t1, c = src.shape
    unm_idx_ex = unm_idx.expand(n, t1 - r, c)
x.shape
torch.Size([2, 4092, 320])
src.shape
torch.Size([2, 2046, 320])
dst.shape
torch.Size([2, 2046, 320])
unm_idx.shape
torch.Size([2, 35, 1])
(n, t1 - r, c)
(2, 2042, 320)
The expanded size of the tensor (2042) must match the existing size (35) at non-singleton dimension 1.  Target sizes: [2, 2042, 320].  Tensor sizes: [2, 35, 1]

it feels like I'm missing something when it comes to cross-attention.
x's shape is way bigger (latents) than metric's shape (CLIP text embeddings).

is it wrong to compute metrics from k? or is it wrong to run the merge upon x?
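The mismatch can be reproduced in isolation. The sketch below (assuming PyTorch; `matching_indices` is a hypothetical name) mirrors the matching steps from the debugger session above: `unm_idx` is built from `metric`'s token axis, so it can only index a tensor whose token axis has the same length. A 77-token text metric yields 35 unmatched indices, while a 4092-token `x` needs 2042.

```python
import torch


def matching_indices(metric: torch.Tensor, r: int):
    """Bipartite matching step: returns (unm_idx, src_idx, dst_idx).

    metric: [batch, tokens, channels]. All returned indices refer to the token
    axis of this metric, split into even (set A) and odd (set B) positions.
    """
    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]   # alternate tokens into A and B
        scores = a @ b.transpose(-1, -2)                   # cosine similarity, A x B
        node_max, node_idx = scores.max(dim=-1)            # best B partner for each A token
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx = edge_idx[..., r:, :]                     # A tokens kept as-is
        src_idx = edge_idx[..., :r, :]                     # the r most similar A tokens, merged
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)  # their B partners
    return unm_idx, src_idx, dst_idx


r = 4
text_metric = torch.randn(2, 77, 40)      # cross-attention key metric: 77 prompt tokens
unm_idx, _, _ = matching_indices(text_metric, r)
print(unm_idx.shape)                      # torch.Size([2, 35, 1]), i.e. 39 A tokens - r

x = torch.randn(2, 4092, 320)             # latent tokens the merge is applied to
n, t1, c = x[..., ::2, :].shape
print((n, t1 - r, c))                     # (2, 2042, 320): the shape unm_idx would need
```

Computing the metric from a tensor that shares its token axis with the tensor being merged (for self-attention, the image tokens) keeps the two consistent.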

====

I've also noticed that after running merge_wavg() for self-attention, the returned x has a different shape from the x I passed in:

# before merge_wavg()
x.shape
torch.Size([2, 4096, 320])
# after merge_wavg()
x.shape
torch.Size([2, 4092, 320])

is that expected? my self.r = 4, so maybe that means 4 tokens were merged.
but it confuses the rearrange that happens downstream:

https://github.com/Birch-san/stable-diffusion/blob/42324b8ab685512292b4842fdacdc305ca1a242c/ldm/modules/attention.py#L349

Error while processing rearrange-reduction pattern "b (h w) c -> b c h w".
Input tensor shape: torch.Size([2, 4092, 320]). Additional info: {'h': 64, 'w': 64}.
Shape mismatch, 4092 != 4096

am I missing some countermeasure? I didn't see any accommodation being made for a change in size in the timm or swag examples.

dbolya commented 1 year ago

> it feels like I'm missing something when it comes to cross-attention.

Ah, my bad. I forgot to mention that you probably shouldn't do this for cross-attn. The keys for cross-attn are computed over the prompt, not the image, so for cross-attn you would have to merge the prompt (context), not x. But the prompt is only a small number of tokens, so it's not worth it.

> so maybe that means 4 tokens were merged.

Yeah, this is exactly what's happening. The speed / memory savings in ToMe come from having fewer tokens to compute on.

Of course, we need those tokens to get an image in the end. There are a few ways of dealing with this. Either you do some computation with the reduced set of tokens and then "unmerge" afterward, or you put this in a spot that doesn't need to be unmerged.

For the examples, I did the latter: I put this inside self-attn rather than the transformer block. Then, instead of merging x, I merged k and v. This still reduces memory consumption and increases speed, but since k and v are used for attn and then discarded, you don't need to unmerge.

For that setting I recommend a really high value for r. The default bipartite algorithm can go up to r = half the current number of tokens. For greater values of r, you can use the random algorithm I include in the merge.py file, which supports any value of r (up to the total number of tokens).
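The half-the-tokens cap follows from the even/odd split: only the even-position half can be merged away. A random split avoids the cap by picking `r` source tokens at random and treating all remaining tokens as destinations. Below is a from-scratch sketch of that idea in PyTorch; `random_bipartite_merge` is a hypothetical name and this is not the code in ToMe's `merge.py`:

```python
import torch


def random_bipartite_merge(metric: torch.Tensor, r: int):
    """Return merge(x) that removes r tokens from x; any 0 <= r < tokens is allowed.

    metric: [batch, tokens, channels], used only to decide which tokens to merge.
    """
    bsz, n, _ = metric.shape
    r = max(0, min(r, n - 1))                  # keep at least one destination token
    if r == 0:
        return lambda x, mode="mean": x        # nothing to merge

    # Random permutation per batch element: first r positions are sources.
    perm = torch.rand(bsz, n, device=metric.device).argsort(dim=1)[..., None]
    src_pos, dst_pos = perm[:, :r], perm[:, r:]

    def split(x: torch.Tensor):
        c = x.shape[-1]
        src = x.gather(1, src_pos.expand(bsz, r, c))
        dst = x.gather(1, dst_pos.expand(bsz, n - r, c))
        return src, dst

    with torch.no_grad():
        m = metric / metric.norm(dim=-1, keepdim=True)
        src_m, dst_m = split(m)
        # Each source token goes to its most similar (cosine) destination token.
        dst_idx = (src_m @ dst_m.transpose(-1, -2)).argmax(dim=-1)[..., None]

    def merge(x: torch.Tensor, mode: str = "mean") -> torch.Tensor:
        src, dst = split(x)
        c = x.shape[-1]
        # Reduce each source into its destination; destinations keep their own value too.
        return dst.scatter_reduce(1, dst_idx.expand(bsz, r, c), src, reduce=mode)

    return merge


k = torch.randn(2, 4096, 320)
merge = random_bipartite_merge(k, r=3072)     # more than half of the tokens
print(merge(k).shape)                         # torch.Size([2, 1024, 320])
```

The trade-off is that the random split ignores spatial structure when choosing sources, whereas the even/odd split guarantees that every destination has a neighbour in the other set.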

Birch-san commented 1 year ago

thanks for clarifying!

okay, will avoid doing this for cross-attn. yes, unlikely to be able to produce a speed-up over so few tokens. shame though, since prompts under the length limit can be expected to have plenty of embeddings for end-of-sentence tokens, which are prime candidates for merging (well, maybe we can merge away the end-of-sentence embeddings with a mechanism far simpler than ToMe, using knowledge of how many tokens long the original prompt was).

> put this inside of self attn rather than the transformer block. Then instead of merging x, I merged k and v.

I've had a stab at this:
https://github.com/Birch-san/stable-diffusion/commit/434a36b54f483aff66df22a215a427385e576fde

this successfully makes k smaller by 1024 tokens:

k.shape # (before)
torch.Size([2, 4096, 320])
k.shape # (after)
torch.Size([2, 3072, 320])

but I'm unsure what to do about the self._tome_info.size tensor that they share.
after merging k, I try to merge v, but merge_wavg complains:

The size of tensor a (4096) must match the size of tensor b (3072) at non-singleton dimension 1
x = merge(x * size, mode="sum")
x.shape
torch.Size([2, 4096, 320])
size.shape
torch.Size([2, 3072, 1])

am I supposed to change these two merges into one merge? perhaps by concatenating k and v first?

or am I supposed to give them separate size tensors?

dbolya commented 1 year ago

A couple of things: you should generate only one merge function, using k, and then use that same function for both k and v. The keys and values are paired together in attention, so whatever you do to one you need to do to the other.

Then, since k and v are discarded afterward there is no need to keep track of token size. In fact there is no need to use wavg at that point. You can just directly pass k and v into the generated merge function with the mode set to "mean" (which I think is the default).

> The size of tensor a (4096) must match the size of tensor b (3072) at non-singleton dimension 1

This error comes from using k's size for v when they should both be using the input size (which will always be None, so there's no point in keeping track).
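A minimal sketch of the advice above, assuming PyTorch. It builds a single even/odd bipartite merge function from `k` (using `k` itself as the metric, which the thread notes is an option), applies it to both `k` and `v` with `mode="mean"`, and tracks no size tensor because `k` and `v` are discarded after attention. `bipartite_merge_fn` and `self_attention_with_kv_merge` are hypothetical names, and the plain softmax attention stands in for whichever attention kernel the model actually uses:

```python
import torch


def bipartite_merge_fn(metric: torch.Tensor, r: int):
    """Even/odd bipartite soft matching; returns merge(x) for [batch, tokens, ch]."""
    t = metric.shape[1]
    r = min(r, t // 2)
    if r <= 0:
        return lambda x, mode="mean": x
    with torch.no_grad():
        m = metric / metric.norm(dim=-1, keepdim=True)
        scores = m[..., ::2, :] @ m[..., 1::2, :].transpose(-1, -2)
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx, src_idx = edge_idx[..., r:, :], edge_idx[..., :r, :]
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode: str = "mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
        return torch.cat([unm, dst], dim=1)

    return merge


def self_attention_with_kv_merge(q, k, v, heads: int, r: int):
    """q, k, v: [batch, tokens, heads * dim_head]. Only k and v are merged."""
    merge = bipartite_merge_fn(k, r)      # one merge function, computed from k
    k, v = merge(k), merge(v)             # same pairing applied to keys and values
    b, tq, c = q.shape
    d = c // heads
    # [batch, tokens, heads * d] -> [batch, heads, tokens, d]
    q, k, v = (x.reshape(b, -1, heads, d).transpose(1, 2) for x in (q, k, v))
    attn = (q @ k.transpose(-1, -2)) * d ** -0.5   # [batch, heads, tq, tk - r]
    out = attn.softmax(dim=-1) @ v
    return out.transpose(1, 2).reshape(b, tq, c)  # query count is unchanged


q = k = v = torch.randn(2, 256, 64)
print(self_attention_with_kv_merge(q, k, v, heads=4, r=64).shape)  # torch.Size([2, 256, 64])
```

Queries are never merged, so the output already has one row per input token and no unmerge step is needed. With a 4096-token latent and r=1024, the keys and values would shrink to 3072 tokens inside attention.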

lalalune commented 1 year ago

@Birch-san I have submitted a patch PR here: https://github.com/lalalune/stable-diffusion/commit/f7bb69b806434423df9dfe6a0b949c2317c2b261

I followed the advice from @dbolya above. The error was coming from setting self._tome_info.size: you had r=1024, so size (4096) - r (1024) = 3072, hence the numbers you were seeing.

Setting r=2048, I'm consistently getting a 20% speed increase... quite good for not a lot of code, but not 200%. I'm wondering whether this could be improved?

dbolya commented 1 year ago

@lalalune yeah it's not much because the version I described only speeds up a small portion of the model. I'm playing around with other techniques but it's difficult not to affect the fidelity of the generated images.

FYI if you use the random bipartite matching function or whatever I called it, you can specify any number for r.

The strategy I used for the 4k images was to set a maximum number of tokens in k or v. If there are more, then you set r to be that difference.

In a 4k image, there are 32,000 tokens at the start, which means attn is a 32000x32000 matrix. I think I set the max to 8-12k which is how I was able to fit that in memory.
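The cap-based schedule is simple arithmetic. The sketch below is plain Python; `r_for_token_cap` and `attn_entries` are hypothetical helpers, and the cap of 8,000 is taken from the low end of the 8-12k range mentioned above. Because queries are not merged, the score matrix shrinks from N x N to N x cap, not to cap x cap.

```python
def r_for_token_cap(num_tokens: int, max_tokens: int) -> int:
    """Number of k/v tokens to merge away so that at most max_tokens remain."""
    return max(0, num_tokens - max_tokens)


def attn_entries(num_queries: int, num_keys: int) -> int:
    """Entries in one head's attention score matrix (queries x keys)."""
    return num_queries * num_keys


n, cap = 32_000, 8_000
r = r_for_token_cap(n, cap)
print(r)                                             # 24000
print(attn_entries(n, n))                            # 1024000000
print(attn_entries(n, n - r))                        # 256000000
print(attn_entries(n, n) // attn_entries(n, n - r))  # 4
# r = 24000 exceeds n // 2 = 16000, so even/odd bipartite matching cannot reach it;
# this is the case where a random split that allows larger r is needed.
print(r > n // 2)                                    # True
```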

Birch-san commented 1 year ago

thanks @lalalune for the patch. I've incorporated it partially, and have exposed some useful params by which to determine the token reducing schedule.

https://github.com/Birch-san/stable-diffusion/pull/3

thanks also @dbolya; you've been hugely helpful! I've managed a 12% speed increase on a 512x512 image by applying kth_bipartite_soft_matching(k=4) when token_count >= 4096, and disabling ToMe for smaller matrices.

hopefully can get a more dramatic perf difference on bigger images. will try that next.
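The gating in that schedule (ToMe only when token_count >= 4096) can be expressed as a small check. This sketch assumes Stable Diffusion v1 geometry: the VAE downsamples by 8, so a 512x512 image becomes a 64x64 latent, and the U-Net's attention layers run at that resolution and at successive halvings of it. Under those assumptions a 4096 threshold enables ToMe only at the highest-resolution level. `latent_tokens` and `use_tome` are hypothetical names, not taken from the linked PR.

```python
def latent_tokens(image_side: int, vae_factor: int = 8, unet_downsample: int = 1) -> int:
    """Tokens seen by an attention layer for a square image (assumes SD v1 geometry)."""
    side = image_side // vae_factor // unet_downsample
    return side * side


def use_tome(token_count: int, min_tokens: int = 4096) -> bool:
    """Apply ToMe only to layers whose token count reaches the threshold."""
    return token_count >= min_tokens


for down in (1, 2, 4, 8):                # U-Net levels: 64x64, 32x32, 16x16, 8x8 at 512px
    t = latent_tokens(512, unet_downsample=down)
    print(down, t, use_tome(t))
# 1 4096 True
# 2 1024 False
# 4 256 False
# 8 64 False

print([latent_tokens(1536, unet_downsample=d) for d in (1, 2, 4, 8)])  # [36864, 9216, 2304, 576]
```

At 1536x1536 the same threshold would cover two levels, which is one reason the relative gain can grow with image size.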

Birch-san commented 1 year ago

was able to get it generating 1536x1536 (well, 2048x2048, but Metal crashes due to a Mac bug) with this schedule:
https://github.com/Birch-san/stable-diffusion/commit/72184577a40715efbb159f502f04d4904f2949a6

result:
https://twitter.com/Birchlabs/status/1585782506384941056

yuvalkirstain commented 1 year ago

@dbolya Do you still plan to enable this for stable diffusion?

dbolya commented 1 year ago

@yuvalkirstain Yeah, I've just been unexpectedly busy (wrote 2 other papers in the meantime 💀).

I did decide on an implementation (the one I use in Appendix F of the final version of the paper), but haven't gotten around to uploading it as a patch.

It's literally 4 lines of code tho. Replace these lines with

m, u = tome.merge.bipartite_soft_matching(x, r=int(x.shape[1]*ratio_to_merge))

x = u(self.attn1(m(self.norm1(x)))) + x
x = u(self.attn2(m(self.norm2(x)), context=context)) + x
x = u(self.ff(m(self.norm3(x)))) + x

Set ratio_to_merge to something between 0 and 0.5, higher means faster / less memory. Check Appendix F of the paper to see what increasing this number does (0.5 can get quite cursed). Of course, you need to import tome as well.
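To see why the three residual lines keep the output the same shape as the input, here is a from-scratch sketch of a matching function that also returns an unmerge (assuming PyTorch). It follows the even/odd bipartite scheme discussed in this thread, but it is an illustration, not ToMe's `merge.py`. Each sublayer runs on the merged tokens, and unmerge copies each destination's output back to every position merged into it, so the residual addition sees the original token count. `bipartite_merge_unmerge` and `ToyBlock` are hypothetical, and the `nn.Linear` layers stand in for attn1, attn2 (whose `context` argument is omitted) and ff.

```python
import torch
from torch import nn


def bipartite_merge_unmerge(metric: torch.Tensor, r: int):
    """Even/odd bipartite matching returning (merge, unmerge) for [batch, tokens, ch]."""
    t = metric.shape[1]
    r = min(r, t // 2)
    if r <= 0:
        return (lambda x: x), (lambda x: x)
    with torch.no_grad():
        m = metric / metric.norm(dim=-1, keepdim=True)
        scores = m[..., ::2, :] @ m[..., 1::2, :].transpose(-1, -2)
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        unm_idx, src_idx = edge_idx[..., r:, :], edge_idx[..., :r, :]
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor) -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce="mean")
        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = x.shape
        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))   # merged tokens copy their dst
        out = torch.zeros(n, t, c, device=x.device, dtype=x.dtype)
        out[..., 1::2, :] = dst                                   # odd positions: set B
        out.scatter_(-2, (2 * unm_idx).expand(n, unm_len, c), unm)  # unmerged even positions
        out.scatter_(-2, (2 * src_idx).expand(n, r, c), src)        # merged even positions
        return out

    return merge, unmerge


class ToyBlock(nn.Module):
    """Stand-in transformer block: Linear layers replace attn1, attn2 and ff."""

    def __init__(self, dim: int, ratio_to_merge: float = 0.3):
        super().__init__()
        self.norm1, self.norm2, self.norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn1, self.attn2, self.ff = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)
        self.ratio_to_merge = ratio_to_merge

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        m, u = bipartite_merge_unmerge(x, r=int(x.shape[1] * self.ratio_to_merge))
        x = u(self.attn1(m(self.norm1(x)))) + x
        x = u(self.attn2(m(self.norm2(x)))) + x
        x = u(self.ff(m(self.norm3(x)))) + x
        return x


block = ToyBlock(dim=320)
print(block(torch.randn(2, 4096, 320)).shape)  # torch.Size([2, 4096, 320])
```

As in the four-line patch, the matching is computed once from the block input and reused for all three sublayers, so the merge cost is paid once per block.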

JustinMeans commented 1 year ago

@dbolya I appreciate your work! Any idea how to implement this in Diffusers? I've tried to port your 4 lines of code for the CompVis repo to the Diffusers implementation here, but I'm hitting issues with the reduction dimension sizes (Python is not my strong suit):

line 59, in bipartite_soft_matching
    node_max, node_idx = scores.max(dim=-1)
IndexError: max(): Expected reduction dim 3 to have non-zero size.

Mxbonn commented 1 year ago

> I did decide on an implementation (which is what I use in the final version of the paper (Appendix F)), but haven't gotten around to uploading it as a patch.

Any reason why this appendix was dropped in the new final version?

dbolya commented 1 year ago

Good news! I've written an entirely new short paper about ToMe for Stable Diffusion, improving heavily over the naïve approach I described above. Try it out at this GitHub link! For any more discussion on ToMe for Stable Diffusion, head over there.

> Any reason why this appendix was dropped in the new final version?

Some of the authors were wary of using external models like Stable Diffusion due to their training data (understandable, considering the current legal issues surrounding the model). Also, the initial results I had for ToMe + SD were okay, but not super great. I wanted to post a version that people could actually use comfortably, hence the new experiments above.