huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Found 3 more it/s for Euler_a to hit 50.73 it/s on a 4090. #3950

Closed aifartist closed 1 year ago

aifartist commented 1 year ago

Describe the bug

I was helping vladmandic test his diffusers integration into sdnext and found small changes to hit 47.73 it/s in his dev branch. I decided to run a Python profile, found the nonzero() function standing out a lot, and traced it to scheduling_euler_ancestral_discrete.py, where it is called in two places that do:

step_index = (self.timesteps == timestep).nonzero().item()

There are other calls to nonzero which can also be optimized for other workloads, but I'm focusing on basic image inference. For a 20-step euler_a inference, self.timesteps is:

tensor([999.0000, 946.4211, 893.8421, 841.2632, 788.6842, 736.1053, 683.5263, 630.9474, 578.3684, 525.7895, 473.2105, 420.6316, 368.0526, 315.4737, 262.8947, 210.3158, 157.7368, 105.1579, 52.5789, 0.0000], device='cuda:0', dtype=torch.float64)

At each step, a 'timestep' with one of these same values is passed, one at a time, to the functions scale_model_input() and step(). IOW, 999 is passed for step 0, it is of course found at index [0] in the tensor, and step_index is set to 0. It then does this over and over for every step, using the slow nonzero() function each time, when a persistent step_index and a "+1" could have been used (see the sketch below). I'm not saying this is the best way, but that is what I tested, and the perf diff was quite large: an additional 3 it/s.
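
For illustration, here is a minimal, hypothetical sketch of that "keep a step index and +1" idea (not the actual diffusers scheduler code): locate the index once on the first call, then just increment.

import torch

class StepIndexCounter:
    # Hypothetical helper: trades the per-step (timesteps == t).nonzero()
    # search for an index that is located once and then incremented.
    def __init__(self, timesteps: torch.Tensor):
        self.timesteps = timesteps
        self._step_index = None

    def index_for(self, timestep) -> int:
        if self._step_index is None:
            # First call: locate the starting timestep once.
            self._step_index = (self.timesteps == timestep).nonzero().item()
        else:
            # Every subsequent call: simply advance the counter.
            self._step_index += 1
        return self._step_index

timesteps = torch.linspace(999, 0, 20, dtype=torch.float64)
counter = StepIndexCounter(timesteps)
print([counter.index_for(t) for t in timesteps])   # [0, 1, 2, ..., 19]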

ncalls  tottime  percall  cumtime  percall  filename:lineno(function)
   200    1.223    0.006    1.223    0.006  {method 'nonzero' of 'torch._C._TensorBase' objects}
     5    0.150    0.030    0.150    0.030  {method 'cpu' of 'torch._C._TensorBase' objects}

NOTE: Others have complained about the performance of nonzero() before. One example: https://discuss.pytorch.org/t/how-to-make-torch-nonzero-faster/119993
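
As a rough illustration of that cost, here is a hypothetical micro-benchmark of the same lookup on GPU vs CPU (it assumes a CUDA device is available; absolute numbers will vary with hardware and driver):

import time
import torch

# Repeat the scheduler's per-step lookup 200 times (matching the call count
# in the profile above) against a GPU copy and a CPU copy of the schedule.
# The GPU version pays for a host/device sync on every .item().
timesteps_gpu = torch.linspace(999, 0, 20, dtype=torch.float64, device="cuda")
timesteps_cpu = timesteps_gpu.cpu()

def lookup(timesteps, t):
    return (timesteps == t).nonzero().item()

for name, ts in [("gpu", timesteps_gpu), ("cpu", timesteps_cpu)]:
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(10):      # 10 passes x 20 timesteps = 200 lookups
        for t in ts:
            lookup(ts, t)
    torch.cuda.synchronize()
    print(f"{name}: {(time.perf_counter() - start) * 1000:.2f} ms for 200 lookups")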

I have a 4090 and i9-13900K on Ubuntu 22.04.

Reproduction

Here is a script which runs a diffusers pipeline.

import torch
from diffusers import DiffusionPipeline
from diffusers import EulerAncestralDiscreteScheduler

import cProfile
import pstats
import io
from pstats import SortKey

path = 'stabilityai/stable-diffusion-2-1-base'
prompt = "Women standing on a mountain top"

torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

with torch.inference_mode():
    pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False)
    pipe.to('cuda')
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.unet.to(device='cuda', dtype=torch.float16, memory_format=torch.channels_last)

    for bi in range(7):
        if bi == 2:      # Start profiler on 3rd image
            ob = cProfile.Profile()
            ob.enable()
        images = pipe(prompt=prompt, width=512, height=512, num_inference_steps=20, num_images_per_prompt=1).images
    ob.disable()
    sec = io.StringIO()
    sortby = SortKey.TIME
    ps = pstats.Stats(ob, stream=sec).sort_stats(sortby)
    ps.print_stats()
    print(sec.getvalue()[0:1000])       # Close enough

Logs

No response

System Info

Who can help?

No response

vladmandic commented 1 year ago

my $0.02...

this impacts quite a lot of schedulers, not just euler a - seems that code is just carried over from the original implementation

anyhow, most schedulers build interpolated timesteps based on the desired number of inference steps
and they do that incrementally, finding the index of the current timestep in that schedule using torch.nonzero()

in theory, this guards against borderline cases to prevent the scheduler getting stuck, but if the scheduler is initialized with a sufficiently high num_train_timesteps (default is 1000), that is not a problem

bigger problem is that the scheduler is stateless once initialized, so it doesn't know which step it processed last
and adding state means that state needs to be reinitialized somewhere (and definitely without requiring user-code changes), perhaps by keeping the step index as self.steps_index and initializing it when num_train_timesteps - 1 == timesteps (need to check the math, if the starting diff is always 1)?

and a simpler quick workaround:
it seems that torch.nonzero() is much slower on gpu than on cpu (according to torch itself),
so why not perform timestep = timestep.to(self.timesteps.device) after the call to nonzero instead of just before?
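
A hypothetical sketch of that spirit (illustrative names, not the actual diffusers change): keep a CPU copy of the schedule for the index lookup, and only move timestep to the scheduler's device afterwards.

import torch

def find_step_index(timesteps_cpu: torch.Tensor, timestep) -> int:
    # Hypothetical helper: do the equality test and nonzero() entirely on CPU,
    # keeping the slow GPU nonzero()/.item() sync out of the hot loop.
    t = timestep.cpu() if torch.is_tensor(timestep) else torch.tensor(timestep)
    return (timesteps_cpu == t).nonzero().item()

# Illustrative usage inside a scheduler-like object:
#   self._timesteps_cpu = self.timesteps.cpu()        # once, in set_timesteps()
#   step_index = find_step_index(self._timesteps_cpu, timestep)
#   timestep = timestep.to(self.timesteps.device)     # move to device after the lookup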

patrickvonplaten commented 1 year ago

Very cool issue :heart: Investigating!

jfischoff commented 1 year ago

@vladmandic I think some of the samplers this affects are stateful, such as UniPCMultistepScheduler (and I think all the multistep samplers), which store a last_sample and use nonzero as described.

Assuming my understanding is correct, I wonder if the fix for the stateful samplers is to increment the step_index like @aifartist attempted.

aifartist commented 1 year ago

bigger problem is that scheduler is stateless once initialized, so it doesn't know which is the last step it processed and adding state means that state needs to be reinitialized somewhere (and definitely without requiring usercode changes) perhaps by keeping steps_index as self.steps_index and initializing it when num_train_timesteps - 1 == timesteps (need to check math if starting diff is always 1)?

I'm not sure what this means. There is self.is_scale_input_called as one example of state. Now that I've discovered where the two functions are called in pipeline_stable_diffusion.py:__call__(), it seems that it'd be clean to keep the step_index in the class. I do need to understand the "warmup steps" stuff, which I've never encountered before.

patrickvonplaten commented 1 year ago

@yiyixuxu could you give this PR a try? As mentioned by @vladmandic here: https://github.com/huggingface/diffusers/issues/3950#issuecomment-1621765217 we can probably speed-up all of our k-diffusion samplers. Could you try to run some benchmarking on our end and see how we can speed it up? :-)

patrickvonplaten commented 1 year ago

@yiyixuxu let me know if you need help here

kevint324 commented 1 year ago

@yiyixuxu could you give this PR a try? As mentioned by @vladmandic here: #3950 (comment) we can probably speed-up all of our k-diffusion samplers. Could you try to run some benchmarking on our end and see how we can speed it up? :-)

@patrickvonplaten I'm suffering from same issue caused by the non-zero. Where is the PR you are talking about? I'd like to give it a try.

Thanks a million

aifartist commented 1 year ago

I tested the PR. Here are the it/s before and after for a 512x512 image with 20 steps on the base sd2.1 model. This is on my 4090 running Ubuntu 22.04 with an i9-13900K.

BEFORE:

100%|█████████████| 20/20 [00:00<00:00, 48.64it/s]
100%|█████████████| 20/20 [00:00<00:00, 48.63it/s]
100%|█████████████| 20/20 [00:00<00:00, 48.63it/s]
100%|█████████████| 20/20 [00:00<00:00, 48.64it/s]

AFTER:

100%|█████████████| 20/20 [00:00<00:00, 51.50it/s]
100%|█████████████| 20/20 [00:00<00:00, 51.53it/s]
100%|█████████████| 20/20 [00:00<00:00, 51.46it/s]
100%|█████████████| 20/20 [00:00<00:00, 51.44it/s]

aifartist commented 1 year ago

@patrickvonplaten I believe I found something even bigger with this, which increases performance by 36.3 it/s! Really. :-) With both your 'graph break' fix from yesterday and this fix, torch.compile went from 62.8 it/s to 99.1 it/s. If I use 0.19.1 I get 62. If I use 0.19.2 without a fix for this bug I also get 62. But with both I get 99. I've repeatedly confirmed this. FYI, I'm using 'max-autotune'.

Because of this I wonder if the priority of getting this out increases and whether we should look at other samplers?

aifartist commented 1 year ago

@patrickvonplaten I truly believe that as performance has improved, it/s is becoming useless as a metric. Even if you take out the post-processing like upscaling, image saving, etc., the it/s does not correlate with the actual image generation time. With batch size 3 I get up to an effective 174 it/s (58 times 3). Yet the image-gen pipeline time is only 300ms, when it should be much faster considering the 450ms time when I'm closer to 60 it/s.

I need to do more debugging, but for people with 4090s, reporting "ZERO" seconds as tqdm(?) does isn't useful, and the it/s smell funny when things get this fast.

aifartist commented 1 year ago

@patrickvonplaten I have found the problem with it/s and it is a real problem which does get worse as things get faster. The short version:

1) tqdm is measuring performance for what is an async operation (the multistep UNet for Stable Diffusion).
2) There can still be 100ms of work left to complete at the end of the last UNet step.
3) As the time for the async UNet portion continues to drop, the it/s zoom up, but the work isn't really finished.

When running slower, without compilation, the remaining time after the last step might only be 18ms. Of course, this also skews the results, but 18 is a small number in proportion to something like the UNet finishing in 400ms instead of 418ms. But with compilation things are pushed even faster, and it appears that the post-UNet remaining time is even more significant. Also, this might make the VAE look like the problem, because it has to eat the cost of the incomplete async work. Yes, compile makes the async portion of the UNet drop from 400ms to 200ms, but who cares when reality is closer to 300ms.

torch.compile does make things faster but things need to be measured correctly to characterize HOW MUCH faster. I always modify any test harness I use to add a true timing of the Diffusers pipeline.
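
For example, a minimal timing sketch along those lines (it assumes the pipe and prompt from the repro script above; the synchronize calls are the point, so leftover async work gets counted):

import time
import torch

# Hypothetical end-to-end timing: synchronize before and after the pipeline
# call so the measured wall time includes any async GPU work still queued
# after the last UNet step (VAE decode, safety checker, etc.).
torch.cuda.synchronize()
start = time.perf_counter()
images = pipe(prompt=prompt, width=512, height=512, num_inference_steps=20).images
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"pipeline wall time: {elapsed_ms:.1f} ms")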

outrun32 commented 1 year ago

@patrickvonplaten I believe I found something even bigger with this which increases performance by 36.3 it/s! Really. :-) With both your 'graph break' fix yesterday and this fix torch.compile went from 62.8 it/s to 99.1 it/s. If I use 0.19.1 I get 62. If I use 0.19.2 without a fix for this bug I also get 62. But with both I get 99. I've repeatedly confirmed this. FYI, I'm using 'max-autotune'

Because of this I wonder if the priority of getting this out increases and whether we should look at other samplers?

I've been trying to replicate your results on my machine with an A100. It doesn't seem to work for me and I get 61 it/s (which is still decent but not as major an improvement as yours). Now, I'm not sure if this is a GPU thing or if I'm doing something wrong, but either way, I would be glad if you shared more details, like the code you used.

aifartist commented 1 year ago

@patrickvonplaten I believe I found something even bigger with this which increases performance by 36.3 it/s! Really. :-) With both your 'graph break' fix yesterday and this fix torch.compile went from 62.8 it/s to 99.1 it/s. If I use 0.19.1 I get 62. If I use 0.19.2 without a fix for this bug I also get 62. But with both I get 99. I've repeatedly confirmed this. FYI, I'm using 'max-autotune' Because of this I wonder if the priority of getting this out increases and whether we should look at other samplers?

I've been trying to replicate your results on my machine with A100. It doesn't seem to work for me and I get 61 it/s (which is still decent but not as major improvement as yours). Now, I'm not sure if this is a GPU thing or I'm doing something wrong, but either way, I would be glad if you shared more details like the code you used.

Did you upgrade diffusers to 0.19.2 AND apply the fix for this bug that you can find in the PR? Warning: the 99.1 it/s looks like 90 or more percent faster, but the actual image generation time is only 20% faster. Diffusers is failing to do a final tqdm.clear at the end, which would take into account the VAE and things like the NSFW check. Technically the "iterations" are only the steps of the UNet, but people have been relying on it/s for an idea of their perf. A1111 handles this with the tqdm.clear() after the entire process is finished; hence its it/s correlates with time. Personally I would do away with tqdm and just report some stats like "{n} images in %f seconds".

patrickvonplaten commented 1 year ago

Hey @aifartist,

it would be amazing to see a PR that can achieve such speed ups :-)

aifartist commented 1 year ago

@patrickvonplaten

Hey @aifartist,

it would be amazing to see a PR that can achieve such speed ups :-)

As I mentioned above, it/s turns out to be quite misleading as a measurement of image generation performance. tqdm doesn't even accurately measure "iteration" performance, much less image generation time. A1111, but not diffusers, accounts for this by calling tqdm close() at the very end of the ENTIRE process. But I used a basic diffusers pipeline and this gave me the very misleading 99 it/s. So instead of a 90% perf increase when torch.compile was used, it was only 21%.

I believe what is happening here is that the combination of your graph-break fix and what I found here exacerbates a problem that was present previously but only now stood out enough for me to notice. While these things do speed things up, they also increase the amount of incomplete async work left over at the exit from the UNet steps within:

with self.progress_bar(total=num_inference_steps) as progress_bar:

I have never understood why the powers that be decided, for the average artist, that the geeky internals of denoising steps per second were the best way to report performance, particularly because more GPU work still needs to be done (VAE, NSFW check, etc.).

patrickvonplaten commented 1 year ago

I think we can close this with #4347 being merged no?

vladmandic commented 1 year ago

I believe so