basujindal / stable-diffusion

Optimized Stable Diffusion modified to run on lower GPU VRAM

Memory-efficient attention and gradio mask fixed #117

Closed: neonsecret closed this 2 years ago

MrLavender commented 2 years ago

Nice work. Applying the attention.py change to the original SD lets me do 512x512 on 8 GB; previously I could only do 448x448.

But (on the original SD anyway) the size of sim is 16, so sim[8:] and sim[:8] are more memory efficient (it makes the difference between working and failing with out-of-memory). A more general way to do this would be:

half = sim.size(dim=0) // 2
sim[:half] = sim[:half].softmax(dim=-1)
sim[half:] = sim[half:].softmax(dim=-1)

or, for maximum memory efficiency (with about a 1% performance difference for me):

for i in range(sim.size(dim=0)):
    sim[i] = sim[i].softmax(dim=-1)
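
For anyone wondering why this is safe: softmax over the last dimension is computed independently for each row, so applying it chunk-by-chunk along dim 0 gives numerically identical results while only a chunk's worth of softmax temporaries exists at a time. A quick standalone check (illustrative only; the shapes here are made up):

import torch

sim = torch.randn(16, 64, 64)        # (batch*heads, n, n) attention scores
full = sim.softmax(dim=-1)           # softmax over the whole tensor at once

chunked = sim.clone()
half = chunked.size(0) // 2
chunked[:half] = chunked[:half].softmax(dim=-1)   # first half, in place
chunked[half:] = chunked[half:].softmax(dim=-1)   # then the second half

print(torch.allclose(full, chunked))  # True: same values, lower peak memory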
Doggettx commented 2 years ago

I've found a way to split up the einsum too; I can go to insane resolutions on my card now... There might be a better way to do this, my knowledge of torch and Python is very limited (meaning almost 0 ;)

Also, I'm not quite sure if all the deletes are really needed; I have no idea when the garbage collector triggers for unused tensors, but I guess it can't hurt to force it.

def forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
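    # Process the (batch*heads) dimension in chunks of 4 so that only a
    # chunk-sized slice of the full attention matrix exists in memory at once.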
    for i in range(0, q.shape[0], 4):
        end = i + 4
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
neonsecret commented 2 years ago

I've found a way to split up the einsum too; I can go to insane resolutions on my card now... There might be a better way to do this, my knowledge of torch and Python is very limited (meaning almost 0 ;)

Also, I'm not quite sure if all the deletes are really needed; I have no idea when the garbage collector triggers for unused tensors, but I guess it can't hurt to force it.

It won't work; you are only multiplying parts instead of the whole tensor. The tensor for the einsum shouldn't be split.

Doggettx commented 2 years ago

It won't work; you are only multiplying parts instead of the whole tensor. The tensor for the einsum shouldn't be split.

Seems to work fine and gives the same results. I have no idea how einsum works though, but as far as I can see there are no side effects.

neonsecret commented 2 years ago

and memory?

Doggettx commented 2 years ago

and memory?

I went from being able to do 1920x640 to 1920x832; the einsum uses about 1/4th of the memory now. I don't have any other optimizations though, only this one (applied to the CompVis version).

neonsecret commented 2 years ago

hmm very weird

neonsecret commented 2 years ago

fucking hell it works

Doggettx commented 2 years ago

It actually works with steps of 2 as well; I can go to 1920x1024 then. It breaks at steps of 1, no idea how this stuff works hehe

Doggettx commented 2 years ago

Does seem to make it slower though

Doggettx commented 2 years ago

For comparison, I tested the same prompt/seed/settings etc. at different step sizes:

  • loop steps of 8: 7.0 it/s
  • loop steps of 4: 6.2 it/s
  • loop steps of 2: 4.7 it/s

The drop from 8 to 4 isn't too bad, but I'm not sure whether going down to 2 is worth it, unless you want to render at really high resolutions.

neonsecret commented 2 years ago

4 doesn't seem to make any difference for me. I'm going to add both options.

victorbessa96 commented 2 years ago

It would be great to have an option to decide between faster renders and really high resolutions, so perhaps a switch between 8 and 2?

JohnAlcatraz commented 2 years ago

I've found a way to split up the einsum too; I can go to insane resolutions on my card now... There might be a better way to do this, my knowledge of torch and Python is very limited (meaning almost 0 ;)

@Doggettx Wow, your code works amazingly well!

I cannot see any significant slowdown; it works great even using a step size of 1 in the for loop. I also checked that the output from the same seed is fully identical.

This is the speed I'm getting when generating a 512x512 image, using a RTX 2070 Super:

  • Default SD: 5.0 it/s | 0.39 Megapixels Max Res
  • Your modified def forward with loop steps of 8: 4.94 it/s | Didn't test Max Res
  • Your modified def forward with loop steps of 4: 4.87 it/s | 0.79 Megapixels Max Res
  • Your modified def forward with loop steps of 2: 4.78 it/s | 1.14 Megapixels Max Res
  • Your modified def forward with loop steps of 1: 4.46 it/s | 1.5 Megapixels Max Res

The resolution I can do with a for-loop step size of 1 is incredible. It's fully worth the very small reduction in speed. But ideally, the loop step size would be made a command-line option that can be set (see the sketch at the end of this comment).

So this is the code I'm using for a loop step size of 1:

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
        for i in range(0, q.shape[0], 1):
            end = i + 1
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2

        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
        del r1

        return self.to_out(r2)

With this code, I can do 1216x1216 on 8 GB VRAM. That is 4.4 times as many pixels compared to the maximum I can do with default SD. It's amazing!

To be clear, I did my testing above with default SD at half precision, not with the "optimized" version from this repo, so I was comparing default SD at half precision vs only the changed attention.py. With the other optimizations from this repo, I could surely go even higher than 1216x1216 on 8 GB VRAM now. But the other optimizations from this repo hurt speed a lot more, so I think they are not really worth doing any more now.
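
As a rough sketch of how the loop step could be made configurable (the environment variable name ATTN_STEP below is hypothetical, not an option defined by this repo):

import os

# Hypothetical: read the attention chunk size from the environment, defaulting to 2.
ATTN_STEP = max(1, int(os.environ.get("ATTN_STEP", "2")))

# Then, inside forward(), use it in place of the hard-coded constant:
#     for i in range(0, q.shape[0], ATTN_STEP):
#         end = i + ATTN_STEP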

TheEnhas commented 2 years ago

How does this translate into doing batches of images, though? One thing I tend to do is 20 512x512, 50-step generations with turbo mode. How is VRAM use with half precision plus the "loop step 1" code above on base SD compared to that? Because if it's much better or even comparable, then yeah, the old optimizations shouldn't really be used anymore, except maybe as an option to save even more VRAM on limited (i.e. 4 GB or less) GPUs, or for really big images.

JohnAlcatraz commented 2 years ago

I noticed that the "step 1" version does not actually work for me either - I didn't pay attention to what exactly the log showed. I thought it ran through to 100% and succeeded, but what it actually does is run through to 100% and then crash with an out-of-memory error at high resolutions. Lower resolutions work fine with the "step 1" code without crashes, but then I can also use the "step 2" version at a slightly higher speed.

There's probably some other code somewhere that needs to be optimized more for the "step 1" version to make sense and not crash at 100%.

So what I said above regarding "step 1" clearly being the best is not true. It's "step 2" that's the best because that actually works. The table I showed above is still accurate, just ignore the "loop steps of 1" row.

The maximum I can do now with 8 GB VRAM, using the "step 2" code, is 1.14 Megapixels, as mentioned in my previous comment. A factor of 2.91 improvement over default SD.

So this code:

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2

        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
        del r1

        return self.to_out(r2)
JohnAlcatraz commented 2 years ago

How does this translate into doing batches of images, though? One thing I tend to do is 20 512x512, 50-step generations with turbo mode. How is VRAM use with half precision plus the "loop step 1" code above on base SD compared to that? Because if it's much better or even comparable, then yeah, the old optimizations shouldn't really be used anymore, except maybe as an option to save even more VRAM on limited (i.e. 4 GB or less) GPUs, or for really big images.

Not sure what exactly you're asking about. I can of course still set --n_iter 20 and it generates 20 images; the number of images generated does not affect VRAM usage. What does affect VRAM usage is --n_samples, but I think there is no reason to ever set that higher than 1.

Doggettx commented 2 years ago
  • Default SD: 5.0 it/s | 0.39 Megapixels Max Res
  • Your modified def forward with loop steps of 8: 4.94 it/s | Didn't test Max Res
  • Your modified def forward with loop steps of 4: 4.87 it/s | 0.79 Megapixels Max Res
  • Your modified def forward with loop steps of 2: 4.78 it/s | 1.14 Megapixels Max Res
  • Your modified def forward with loop steps of 1: 4.46 it/s | 1.5 Megapixels Max Res

@JohnAlcatraz It's so weird to me that you see almost no difference; your step 1 is actually faster than on my 3090. Just wondering, what OS are you using, and which version of torch? I'm running it on Windows 11 with torch 1.12.1+cu116. I wonder if that can make a difference. I'm just running default SD as well, with some custom modifications, but those have nothing to do with the rendering part.

JohnAlcatraz commented 2 years ago

@Doggettx I'm on Windows 10, 21H1. If I knew which version of torch I'm using I'd tell you, but I have no idea how to check that; I'm a C++ programmer with no clue about Python ;) I don't usually do anything with torch; I only installed it for Stable Diffusion. So it's probably a very new version.
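
For reference, the installed torch version can be printed from Python with a one-liner like:

import torch
print(torch.__version__)  # e.g. 1.12.1+cu116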

Maybe you are not running at half precision? That is a difference in how I run it compared to fully default SD; I just add model.half(). Most forks do that by default by now.

Doggettx commented 2 years ago

I checked that to be sure; my model was still running at full precision. I set it to half, but that doesn't really affect speed; it just allows me to render at even higher res now (1920x1536 with only this change).

I think I'll just make it configurable in my version. At higher resolutions the speed difference seems to get smaller, but at low resolutions it's more than twice as slow and not really needed.

My workflow is usually to first render with one dimension at 512 (so 512x768 or something), run a normal upscaler, and then img2img the upscaled version at native res. That keeps coherence high while still allowing native rendering at high resolutions. But then it's nicer if you can pump out those low-res images fast to find a good one ;)

MrLavender commented 2 years ago

The optimization work done here in the last few hours really is awesome. Thank you all!

I know nothing about machine learning and had never heard of an einsum before today, but looking at the PyTorch docs I see this interesting note:

This function does not optimize the given expression, so a different formula for the same computation may run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) can optimize the formula for you.

https://pytorch.org/docs/stable/generated/torch.einsum.html

So maybe there are further improvements to be had in this forward() function (in speed if not memory)?
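
For anyone curious, a minimal sketch of what that could look like (assuming the opt-einsum package is installed; untested here, and the shapes are just placeholders):

import torch
import opt_einsum

# opt_einsum.contract accepts the same subscript string as torch.einsum but can
# choose a contraction path/backend for the given operands.
q = torch.randn(16, 1024, 40)
k = torch.randn(16, 1024, 40)

sim = opt_einsum.contract('b i d, b j d -> b i j', q, k)
print(sim.shape)  # torch.Size([16, 1024, 1024])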

willlllllio commented 2 years ago

This is crazy; with step=2 I can do 1088x1024 on a 6 GB card with no noticeable extra slowdown, though I do need the CUDA max_split setting for that resolution.

7flash commented 2 years ago

The only noticeable optimization in this PR is in these lines, the halving of the attention softmax, but what does that actually mean?

        sim[4:] = sim[4:].softmax(dim=-1)
        sim[:4] = sim[:4].softmax(dim=-1)

It seems to apply softmax separately to each half of the array? Does that make it faster?


CaptnSeraph commented 2 years ago

The step-2 version works for me; it helped me push my 8 GB 1070 to 896x896.

@willlllllio where do you specify the max_split? I assume you mean PYTORCH_CUDA_ALLOC_CONF, but which file should that go into, or do I need to set it each time as an environment variable?

Also, what would be the ideal max split size to set for a card with 8192 MB?

willlllllio commented 2 years ago

@theseraphim just start the process with the env var PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 (or whatever value). I don't think there's any config file for that, unless you just set it via Python's os.environ.

No clue what a good value is for this; I only did it because in my case the error showed way more memory still free than it was trying to allocate. I could always do 1024x1024 without it.
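
A minimal sketch of the os.environ route (the variable just has to be set before torch initializes CUDA):

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # imported after setting the variable so the allocator picks it up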

patrickvonplaten commented 2 years ago

Really amazing work here! We also noticed a huge drop in memory consumption in diffusers -> super nifty!

@neonsecret - would it be OK if we open a PR in https://github.com/huggingface/diffusers/tree/main/src/diffusers to allow a speed vs. required-memory trade-off, citing your amazing work from this PR?

neonsecret commented 2 years ago

Really amazing work here! We also noticed a huge drop in memory consumption in diffusers -> super nifty!

@neonsecret - would it be OK if we open a PR in https://github.com/huggingface/diffusers/tree/main/src/diffusers to allow a speed vs. required-memory trade-off, citing your amazing work from this PR?

right

ryudrigo commented 2 years ago

Just to make it clearer for future development and whoever reads this: using

sim[4:] = sim[4:].softmax(dim=-1)

assumes unet_bs is set to 1. If you want to use a higher unet batch size (which is not that much heavier now with this approach), divide it into more chunks, like in @patrickvonplaten's code or like this:

att_step = 4
for i in range(0, sim.shape[0], att_step):
    sim[i:i+att_step] = sim[i:i+att_step].softmax(dim=-1)

where sim.shape[0] is the batch size times the number of heads. I can confirm this additional chunking does not introduce any noticeable slowdown on my setup with a 3060, but it reduces memory consumption from 12 GB to 8 GB for n_samples=4 and unet_bs=8 (for more information, see issue #69).

neonsecret commented 2 years ago

nice

TheEnhas commented 2 years ago

Which version of the code was used in the update?

JohnAlcatraz commented 2 years ago

It seems like the original PR version was merged, which gives a lot less VRAM savings than the newer optimization figured out by @Doggettx later in this thread.

basujindal commented 2 years ago

It seems like the original PR version was merged, which gives a lot less VRAM savings than the newer optimization figured out by @Doggettx later in this thread.

Is there a PR for the optimization discussed here?

JohnAlcatraz commented 2 years ago

No, no one made a new PR for it yet.

You can see the exact changes, implemented in the best way, in this branch by @Doggettx: https://github.com/Doggettx/stable-diffusion/commits/main

I don't know if he intends to open a PR with them himself.

ryudrigo commented 2 years ago

I just opened a PR, but it only covers my comment -- there might be other optimizations I didn't look at.

camenduru commented 2 years ago

1 step 1216x1216 on 8 GB VRAM with 1070 O8G 🎉 Thank You, Everyone.

JohnAlcatraz commented 2 years ago

1 step 1216x1216 on 8 GB VRAM with 1070 O8G 🎉 Thank You, Everyone.

If you mean you are using the code shown here with 1 step, you will likely see it crash at 100%. But with the newest version of the optimization from @Doggettx, you will likely be able to successfully go that high or even higher.

camenduru commented 2 years ago

@JohnAlcatraz Yes, step 1

Now I just changed these two files:

https://raw.githubusercontent.com/Doggettx/stable-diffusion/main/ldm/modules/diffusionmodules/model.py
https://raw.githubusercontent.com/Doggettx/stable-diffusion/main/ldm/modules/attention.py

1920x1088 with 1070 O8G 1034.58s/it https://i.imgur.com/CbIfbHp.png 🎉🎉🎉

JohnAlcatraz commented 2 years ago

1920x1088 with 1070 O8G 1034.58s/it https://i.imgur.com/CbIfbHp.png 🎉🎉🎉

1920x1088 on 8 GB VRAM is certainly impressive!

ryudrigo commented 2 years ago

There, polished it a little bit more. Now 1024px in turbo mode takes 8117 MB and 90 seconds (total) for me.

jimovonz commented 2 years ago

Anyone else finding that with increased resolution the images are losing coherence, with multiple random occurrences of the subject elements?

JohnAlcatraz commented 2 years ago

Anyone else finding that with increased resolution the images are losing coherence, with multiple random occurrences of the subject elements?

That is a known issue with stable diffusion, yes. The model was trained at 512x512 so that's the only resolution it can do very well.

jimovonz commented 2 years ago

Anyone else finding that with increased resolution the images are losing coherence, with multiple random occurrences of the subject elements?

That is a known issue with stable diffusion, yes. The model was trained at 512x512 so that's the only resolution it can do very well.

Unfortunately this seems to make most of these higher resolution images useless - unless of course you are specifically after something more abstract....

JohnAlcatraz commented 2 years ago

Unfortunately this seems to make most of these higher resolution images useless - unless of course you are specifically after something more abstract....

These optimizations are not just about being able to generate larger resolutions, but also about being able to generate the same resolution with less VRAM, making Stable Diffusion more accessible to people with low-VRAM GPUs.

ryudrigo commented 2 years ago

Indeed! I should've talked about the normal setting. The least memory usage I can get with PR #122 for 512x512 is just under 3 GB of VRAM.

CaptnSeraph commented 2 years ago

Unfortunately this seems to make most of these higher resolution images useless - unless of course you are specifically after something more abstract....

As img2img uses the txt2img sequence (I think), you can use a lower res in txt2img to get a good seed and a good "thumbnail", then refine larger with img2img before running through goBig and gfpgan for seriously high quality and sizes (I've got photorealism at DSLR resolutions)

jimovonz commented 2 years ago

Cheers - I have been doing something similar with great results. I have been creating lower res images at 768x448 which seem to be mostly free of any obvious duplication/unwanted artifacts and then upscaling using ESRGAN up to 1920x1088 before adding more detail back in using img2img. The strength parameter needs fine tuning to get the right balance of detail - too high and you reintroduce all the same issues you were trying to avoid in the first place. 0.5 is mostly ok but sometimes you need to go lower and sometimes you can go higher with good results.


GordonFreeeman commented 2 years ago

Holy crap, this is actually working! I'm only a casual when it comes to python, or coding in general, but after fiddling with the above tweaks/fixes, I can generate incredibly high resolutions on my measly 6GB 1660 Ti (laptop card). Plus I have to run at full precision, because fp16 is broken exclusively on 1660 series cards.

  • 512x512: 1.55 it/s
  • 1024x576: 2.21 s/it
  • 1024x1024: 6.36 s/it
  • 1280x768: 6.51 s/it
  • 1408x768: 7.59 s/it
  • 1920x576: 7.74 s/it
  • 1536x960 was working (13.80 s/it), but crashed during image export when VRAM usage went from 5.4 GB to >6 GB

I ran all tests with 50 ddim_steps. Had to restart twice because for some reason, VRAM wasn't cleared up completely sometimes, and at higher res, it's going a little crazy with the iteration time. But it's still pretty mindblowing as a proof of concept.

tzayuan commented 5 months ago

Hi @MrLavender,

I would like to ask: SD loads a pretrained model, so why can the implementation of the attention mechanism be changed while the pretrained model still works correctly? Are there any techniques or details to pay attention to in this process? Thanks.