basujindal / stable-diffusion

Optimized Stable Diffusion modified to run on lower GPU VRAM
Other

Small optimizations to CPU utilization #130

Open ArneBab opened 2 years ago

ArneBab commented 2 years ago

I tweaked the CPU code to reduce runtime by about 20%. This is not ready for merge, because it hard-codes the core counts of my CPU and only works with hyperthreading, but I wanted to share it anyway:

diff --git a/optimizedSD/optimized_txt2img.py b/optimizedSD/optimized_txt2img.py
index 6ead861..3a795a3 100644
--- a/optimizedSD/optimized_txt2img.py
+++ b/optimizedSD/optimized_txt2img.py
@@ -34,8 +34,16 @@ def load_model_from_config(ckpt, verbose=False):

 config = "optimizedSD/v1-inference.yaml"
 ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
+# try allowing hyperthreading: 6 physical CPUs => 12 virtual cores, use one less
+# assumes that the CPUs are not used optimally, might actually slow the system down
+# but it is actually faster: at 11 threads 16s instead of 22s for an iteration for two samples.
+# but 11 significantly slows down the computer.
+torch.set_num_threads(10)
+# limit inter_op threads to the physical CPUs to have more for intra-op (which on my CPUs hopefully has better caching)
+# This gets the time per iteration down to 15s
+torch.set_num_interop_threads(5)

 parser = argparse.ArgumentParser()

basujindal commented 2 years ago

Hi, thank you for sharing this optimization. I don't know a lot about hyperthreading, so would it be possible for you to write modifications for a general CPU architecture, if that is even possible? Thanks!

ArneBab commented 2 years ago

Hi, thank you for your answer! I don’t think I can write this for a general architecture right now (I don’t know the Python APIs well enough to know where to find the number of virtual and physical cores). I hope this gives someone with experience in the API the pointers they need.

What I basically did: torch.set_num_threads(int(0.8*virtual_cores)) and torch.set_num_interop_threads(int(0.8*physical_cores)).

ArneBab commented 2 years ago

The 0.8 is just empirical.
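A minimal sketch of that heuristic for an arbitrary machine, assuming the third-party psutil package is available to count physical cores (without it, the sketch falls back to treating every logical core as physical). The helper names are hypothetical, and 0.8 is just the empirical factor from above. Note that torch.set_num_interop_threads can only be called once, before any inter-op parallel work has started.

```python
import os


def count_cores():
    """Return (logical_cores, physical_cores); physical falls back to logical."""
    logical = os.cpu_count() or 1  # os.cpu_count() may return None
    try:
        import psutil  # third-party, optional (assumption: may not be installed)
        physical = psutil.cpu_count(logical=False) or logical
    except ImportError:
        # Without psutil we cannot tell physical from logical cores.
        physical = logical
    return logical, physical


def heuristic_thread_counts(logical, physical, factor=0.8):
    """Intra-op threads from logical cores, inter-op threads from physical cores."""
    intra_op = max(1, int(factor * logical))
    inter_op = max(1, int(factor * physical))
    return intra_op, inter_op


def apply_thread_counts(factor=0.8):
    """Configure torch; call early, before any model work starts."""
    import torch  # imported lazily so the helpers above work without torch

    intra_op, inter_op = heuristic_thread_counts(*count_cores(), factor=factor)
    torch.set_num_threads(intra_op)
    # Only allowed once per process, before inter-op parallel work begins.
    torch.set_num_interop_threads(inter_op)
    return intra_op, inter_op


if __name__ == "__main__":
    # 12 logical / 6 physical cores, as in the diff above
    print(heuristic_thread_counts(12, 6))
```

For the 6-core/12-thread machine in the diff this yields 9 intra-op and 4 inter-op threads, slightly below the hand-picked 10 and 5, so the factor is a starting point for benchmarking rather than a tuned value.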

ArneBab commented 2 years ago

Sidenote: I run this with a file of prompts (TODO_prompts.txt) and then call this:

cat TODO_prompts.txt | xargs -I {} nice -n 2 python optimizedSD/optimized_txt2img.py --device cpu --precision full --prompt "{}" --H 512 --W 512 --n_iter 1 --n_samples 2 --ddim_steps 75
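GNU xargs interprets quotes and backslashes in its input, so a prompt containing an apostrophe can break the one-liner above. A minimal Python sketch of the same loop, assuming a POSIX system with nice on the PATH; build_command and run_prompts are hypothetical helpers, not part of the repository. It passes each prompt as a single argv entry, so no shell quoting is involved.

```python
import subprocess
import sys
from pathlib import Path

# Same generation settings as the xargs one-liner.
BASE_ARGS = [
    "--device", "cpu", "--precision", "full",
    "--H", "512", "--W", "512",
    "--n_iter", "1", "--n_samples", "2", "--ddim_steps", "75",
]


def build_command(prompt, script="optimizedSD/optimized_txt2img.py", niceness=2):
    """Build the argv list for one prompt; no shell, so quotes in prompts are safe."""
    return ["nice", "-n", str(niceness), sys.executable, script,
            "--prompt", prompt, *BASE_ARGS]


def run_prompts(prompt_file="TODO_prompts.txt", dry_run=False):
    """Run one generation per non-empty line; return the commands built."""
    commands = []
    for line in Path(prompt_file).read_text(encoding="utf-8").splitlines():
        prompt = line.strip()
        if not prompt:
            continue  # skip blank lines
        cmd = build_command(prompt)
        commands.append(cmd)
        if not dry_run:
            # check=True stops at the first failing prompt.
            subprocess.run(cmd, check=True)
    return commands
```

Unlike xargs, which keeps going after a failed command, check=True stops at the first failure; drop it to process every prompt regardless.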

konimaki2022 commented 2 years ago

Not bad! I put this at the beginning of txt2img_gradio.py and everything runs much faster.

torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())

It returns logical CPUs (threads); in my case I don't have HT.
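os.cpu_count() reports every logical CPU on the machine and may return None. A slightly more defensive sketch (usable_cpu_count is a hypothetical name) counts only the CPUs the process is allowed to run on, which matters when the process is pinned with taskset or cpusets. On Python 3.13+, os.process_cpu_count() covers the same idea.

```python
import os


def usable_cpu_count():
    """Logical CPUs this process may actually run on (at least 1)."""
    # sched_getaffinity respects CPU affinity masks but only exists on some
    # platforms (e.g. Linux), so fall back to the machine-wide count elsewhere.
    if hasattr(os, "sched_getaffinity"):
        return max(1, len(os.sched_getaffinity(0)))
    # os.cpu_count() may return None if the count cannot be determined.
    return os.cpu_count() or 1
```

Usage would then be torch.set_num_threads(usable_cpu_count()) in place of os.cpu_count().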

ArneBab commented 2 years ago

How did you check that it runs faster? In my case the full CPU count just consumed more CPU but was slower. (Please benchmark the full image generation! A simple time python optimizedSD/... should be enough to get an idea; you’ll want to repeat it to get better info.)

In my case I had to reduce the number of CPUs because using the full number was actually slower (I guess that it was competing too much with other processes on my system and maybe with itself).
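To follow that advice, a small harness can run the whole generation several times and report the best and median wall-clock time. This is a sketch; benchmark is a hypothetical helper, and running each thread configuration through it on an otherwise idle machine gives less noisy comparisons than a single time run.

```python
import statistics
import subprocess
import sys
import time


def benchmark(cmd, repeats=3):
    """Run cmd `repeats` times; return (best, median) wall-clock seconds."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL)
        timings.append(time.perf_counter() - start)
    return min(timings), statistics.median(timings)


if __name__ == "__main__" and len(sys.argv) > 1:
    # Pass the full command line, e.g.:
    # python bench.py python optimizedSD/optimized_txt2img.py --device cpu ...
    best, median = benchmark(sys.argv[1:])
    print(f"best {best:.1f}s, median {median:.1f}s")
```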

konimaki2022 commented 2 years ago

To test on CPU only, I run:

time python3 optimizedSD/optimized_txt2img.py --prompt "david beckam, oil_painting, headshot" --H 512 --W 512 --n_iter 1 --n_samples 1 --ddim_steps 10 --turbo --precision full --device cpu

Best times over several attempts:

torch.set_num_threads(os.cpu_count()-1)
torch.set_num_interop_threads(os.cpu_count()-1)

real    4m25,051s
user    12m47,141s
sys     1m35,402s
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(os.cpu_count())

real    3m51,132s
user    11m39,489s
sys     1m16,910s

Well, I have an i5 2500K with 4 cores and no HT, so leaving out 1 core costs up to 25% of performance for each task. That's why I say it runs much faster with all cores in my case. Also, the Gradio UI has a progress animation that wants to take a CPU core if I leave one idle.
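As a rough check of the numbers above, a small sketch (hypothetical helper names) that parses bash time fields, where the comma is the locale's decimal separator, and computes the wall-clock saving between the two runs.

```python
import re

# Bash `time` prints fields like "4m25,051s"; the comma is a locale decimal separator.
_TIME_RE = re.compile(r"(?:(\d+)m)?(\d+)[.,](\d+)s")


def parse_bash_time(text):
    """Convert a bash `time` field such as '4m25,051s' to seconds."""
    match = _TIME_RE.fullmatch(text.strip())
    if match is None:
        raise ValueError(f"unrecognised time: {text!r}")
    minutes, secs, frac = match.groups()
    return int(minutes or 0) * 60 + float(f"{secs}.{frac}")


def wall_time_reduction(before, after):
    """Fraction of wall-clock time saved going from `before` to `after`."""
    return 1 - parse_bash_time(after) / parse_bash_time(before)


if __name__ == "__main__":
    # os.cpu_count()-1 threads vs os.cpu_count() threads, real times from above
    print(f"{wall_time_reduction('4m25,051s', '3m51,132s'):.1%}")
```

For the two real times it prints 12.8%: going from 265.051 s to 231.132 s saves roughly an eighth of the wall-clock time, within the "up to 25%" estimate for giving up one core out of four.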

I think it could be an optional parameter, with the default left as it is (take all cores).
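A minimal sketch of such an option, using hypothetical --threads and --interop_threads flags (not existing options of the scripts) whose defaults keep the current behaviour of using all cores. The underscore naming follows the script's existing flags such as --n_iter.

```python
import argparse
import os


def add_thread_args(parser):
    """Add hypothetical thread flags; defaults keep 'use all cores'."""
    all_cores = os.cpu_count() or 1
    parser.add_argument("--threads", type=int, default=all_cores,
                        help="intra-op threads passed to torch.set_num_threads")
    parser.add_argument("--interop_threads", type=int, default=all_cores,
                        help="inter-op threads passed to torch.set_num_interop_threads")
    return parser


def apply_thread_args(opt):
    """Apply the parsed values; call once, before any model work starts."""
    import torch  # lazy import so argument parsing works without torch

    torch.set_num_threads(opt.threads)
    torch.set_num_interop_threads(opt.interop_threads)
```

For example, --threads 10 --interop_threads 5 would reproduce the hand-tuned values from the original diff.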

ArneBab commented 2 years ago

Does os.cpu_count() return 4 in your case? Then this would be expected. I have 6 physical CPU cores but set 10 threads, because hyperthreading lets a core do some useful extra work when one thread on it would otherwise leave it idle, which happens because the parallelism of the code isn’t an exact match for what the chip can execute.

I reduce by 20%, because hyperthreading can overcommit CPUs and then the processes can block each other.

I use only the physical cores minus 1 for interop_threads (so 5 instead of 6), because I guess (yes, guess) that intra-op multithreading can make better use of virtual CPUs that run on the same physical core, since they share the same caches.

bitRAKE commented 1 year ago

One way to only adjust threading on hyperthreaded systems is:

import os
import torch

# get_num_threads defaults to physical cores, while os.cpu_count reports
# logical cores.  Only adjust thread count on hyperthreaded systems:
# (opt comes from the script's argparse options)
if opt.device == "cpu" and torch.get_num_threads() != os.cpu_count():
    torch.set_num_threads(int(os.cpu_count()*0.8))
    torch.set_num_interop_threads(int(os.cpu_count()*0.8))

Works on Windows, which I was worried about, and should work on Linux as well. I'm only seeing about a 10% speedup over using just the physical cores on a Ryzen 5950X; it seems to be memory-bound.

ArneBab commented 1 year ago

@bitRAKE Thank you! I would keep the interop-threads lower. These are likely not operating on the same memory regions, so they do not benefit as much from potentially shared caching when on the same physical CPU.

My setup would rather be:

import os
import torch

# get_num_threads defaults to physical cores, while os.cpu_count reports
# logical cores.  Only adjust thread count on hyperthreaded systems:
if opt.device == "cpu" and torch.get_num_threads() != os.cpu_count():
    physical_cores = torch.get_num_threads()
    torch.set_num_threads(int(os.cpu_count()*0.8))
    # fewer inter-op threads leave some physical cores free for other tasks like filesystem IO
    torch.set_num_interop_threads(int(physical_cores*0.8))

bitRAKE commented 1 year ago

Unfortunately, I'm seeing drastic memory thrashing under Windows: massive allocation swings of 10+ GB, and this consumes most of the time on larger images. PyTorch should (IMHO) maintain its own memory pool. When memory is released to Windows, Windows wants to clear it before any application can have it, which is an absurd security requirement. Since PyTorch is just going to use the memory again, it should hold on to it. I'll need to research further to see whether settings exist to prevent this kind of memory thrashing. I'm new to all this Python stuff, but motivated. :)