huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Can we allow making everything on gpu/cuda for scheduler? #9485

Open xiang9156 opened 1 month ago

xiang9156 commented 1 month ago

What API design would you like to have changed or added to the library? Why?

Is it possible to allow setting every tensor attribute of the scheduler to the cuda device? In https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_lcm.py, attributes like scheduler.alphas_cumprod are tensors on cpu, but scheduler.set_timesteps() allows placing scheduler.timesteps on a gpu/cuda device. Doesn't this cause a device mismatch when indexing scheduler.alphas_cumprod with scheduler.timesteps? The attached screenshot shows the pipeline indexing a cpu tensor (alphas_cumprod) with a gpu tensor (timestep). I added the following lines at the beginning of scheduler.step() to print the type and device of timestep and self.alphas_cumprod:

print("Printing scheduler.step() timestep")
print(type(timestep))
print(isinstance(timestep, torch.Tensor))
print(timestep.device)
print("Printing scheduler.step() self.alphas_cumprod")
print(type(self.alphas_cumprod))
print(isinstance(self.alphas_cumprod, torch.Tensor))
print(self.alphas_cumprod.device)

Output when running text-to-image:

Printing scheduler.step() timestep
<class 'torch.Tensor'>
True
cuda:0
Printing scheduler.step() self.alphas_cumprod
<class 'torch.Tensor'>
True
cpu

What use case would this enable or better enable? Can you give us a code example?

We are using a modified LCMScheduler (99% the same as the original LCMScheduler) for video generation; it generates frames repeatedly in a loop. Most of the time this step doesn't cause a performance issue, but we did see intermittent high cpu usage and latency for alpha_prod_t = self.alphas_cumprod[timestep], and torch.profiler and tracing output (attached) show high latency for this specific step. We are wondering if this is the performance bottleneck.
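For reference, here is a minimal standalone sketch of the pattern in question (the values and names are illustrative, not the actual LCMScheduler internals; assumes a CUDA device is available):

import torch

# Toy stand-ins for the scheduler state described above.
alphas_cumprod = torch.cumprod(torch.linspace(0.9999, 0.98, 1000), dim=0)  # lives on cpu
timestep = torch.tensor(999, device="cuda")  # 0-d cuda tensor, like the one passed to step()

# Indexing a cpu tensor with a 0-d cuda tensor forces the index value to be read
# back to the host, which blocks until all previously queued GPU work has finished.
alpha_prod_t = alphas_cumprod[timestep]
print(alpha_prod_t.device)  # the result is still a cpu tensor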

asomoza commented 1 month ago

cc: @yiyixuxu

yiyixuxu commented 1 month ago

thanks for the issue! and interesting that this performance issue is not deterministic. cc @a-r-r-o-w here: I think the cogvideox scheduler has a similar design and he is investigating a similar issue

note that for most of our popular schedulers (euler/dpm), the step calculation relies only on sigmas (which we keep on cpu and index with an int, which seems to be the best option performance-wise, with and without torch.compile: https://github.com/huggingface/diffusers/pull/6173)

we cannot initialize everything on GPU because the user's device is not known at initialization time:

# scheduler.__init__ is called at this step
pipe = DiffusionPipeline.from_pretrained(...)
# the user can then enable offloading or move to a device
pipe.enable_model_cpu_offload()  # or pipe.to("cuda")
# we don't know the device until the pipeline is called
out = pipe(...)

maybe we can see if we can update the calculation to be based on sigma?
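For contrast, a rough sketch of the sigma-style lookup described above (simplified; step_index here is just an illustrative local variable, not the real scheduler attribute):

import torch

sigmas = torch.linspace(1.0, 0.0, 50)  # schedule kept on cpu

# euler/dpm-style lookup: index the cpu schedule with a plain Python int,
# so step() never needs to read a value back from the GPU.
step_index = 3
sigma = sigmas[step_index]

# pattern from this issue: index a cpu schedule with a 0-d cuda tensor,
# which forces a host readback (and therefore a sync) on every step.
timestep = torch.tensor(3, device="cuda")
sigma_slow = sigmas[timestep]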

xiang9156 commented 1 month ago

thanks! the official diffusers documentation loads LCMScheduler as shown in the attached screenshot. Can we have an extra arg like device so it initializes on a passed-in device?

just curious: if we keep everything on cpu for the scheduler, wouldn't the unet model implicitly convert the timestep into a tensor on the gpu device during inference?
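In the meantime, one way to experiment with what such a device argument would amount to is to move the scheduler's lookup table by hand once the pipeline device is known (a workaround sketch, not an existing API; model_id is a placeholder):

import torch
from diffusers import DiffusionPipeline, LCMScheduler

model_id = "..."  # placeholder for your checkpoint
pipe = DiffusionPipeline.from_pretrained(model_id)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

# Manually place the lookup table on the pipeline device so that
# self.alphas_cumprod[timestep] in step() no longer crosses devices.
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.to(pipe.device)

Note that, as the later comments in this thread suggest, whether this actually helps depends on where the timestep index itself lives: indexing with a 0-d cuda tensor still triggers a host readback.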


a-r-r-o-w commented 1 month ago

Hi @xiang9156. I shared my investigation in this comment. Do your observations align with it? We're thinking about how to remove all cuda stream synchronization from the pipeline (which should speed things up quite a bit across all pipelines), so if you have any other insights from your own investigation(s), please feel free to chime in and we can jam on implementing the improvements together.

Another area for improvement could be removing unnecessary calls to cudaMemcpy, which looks like the next obvious bottleneck to tackle after fighting the cudaStreamSynchronize calls.
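One way to surface these implicit synchronizations while iterating on this (just a debugging aid, not part of diffusers) is PyTorch's sync debug mode:

import torch

# Warn (or raise, with "error") whenever an operation forces the host to wait on the GPU,
# e.g. .item(), .cpu(), or indexing with a cuda scalar tensor.
torch.cuda.set_sync_debug_mode("warn")

table = torch.linspace(1.0, 0.0, 1000, device="cuda")
idx = torch.tensor(3, device="cuda")
value = table[idx].item()  # the blocking device-to-host readback triggers a warning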

xiang9156 commented 1 month ago

Hi @a-r-r-o-w, I saw your investigation, thanks for doing that!! I'll also try moving everything onto cpu/gpu on my end to see the performance gain.

xiang9156 commented 1 month ago

@a-r-r-o-w I just tried a very simple script on Google Colab to compare 4 cases:

  • indexing a cpu tensor with a gpu tensor scalar
  • indexing a cpu tensor with a cpu tensor scalar (fast)
  • indexing a gpu tensor with a gpu tensor scalar
  • indexing a gpu tensor with a cpu tensor scalar (fast)

It showed me something useful: indexing a gpu tensor with a gpu tensor scalar is super slow (timings attached). I didn't have time to do torch profiling, but I guess it might be caused by synchronization; I may do that later today.
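For anyone who wants to reproduce the comparison, a small benchmark along these lines should do it (illustrative; the synchronize calls around the timed region keep async kernel launches from skewing the numbers):

import time
import torch

table_cpu = torch.linspace(0.9999, 0.0, 1000)
table_gpu = table_cpu.to("cuda")
idx_cpu = torch.tensor(500)
idx_gpu = idx_cpu.to("cuda")

def bench(label, table, idx, iters=1000):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        _ = table[idx]
    torch.cuda.synchronize()
    print(f"{label}: {(time.perf_counter() - start) / iters * 1e6:.1f} us per index")

bench("cpu table, gpu scalar", table_cpu, idx_gpu)
bench("cpu table, cpu scalar", table_cpu, idx_cpu)
bench("gpu table, gpu scalar", table_gpu, idx_gpu)
bench("gpu table, cpu scalar", table_gpu, idx_cpu)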

I tried profiling our own scheduler after setting the timesteps on cpu, and the scheduler.step() latency dropped to the microsecond level (hundreds of us). Without setting the timesteps on cpu, the scheduler.step() latency ranged from 3ms to 100+ms (intermittently, for video generation). I'm now trying to find any side effects of setting the timesteps on cpu.

xiang9156 commented 1 month ago

@a-r-r-o-w from tracing, I did see extra syncs in the controlnet forward and unet forward after I set the timestep tensor to cpu (traces attached).

a-r-r-o-w commented 1 month ago

I think it's better to maintain two copies of the timesteps in the scheduler for this (one on cuda and one on cpu). To the controlnet/unet/transformer you pass the cuda timestep, but to the scheduler you pass the cpu timestep. This way there are no cuda synchronizations at all until inference completes.

However, I'm only now noticing that even if we save time on cuda synchronization with the above method, we still have to sync at some point before/after decoding. That results in roughly the same overall inference time, so almost no performance gain. I'll post a more detailed comment soon in the PR.
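A rough sketch of what the two-copies pattern could look like inside a denoising loop (illustrative only; it assumes scheduler, unet, latents and prompt_embeds come from an already-loaded pipeline, and timesteps_gpu/timesteps_cpu are just local names, not existing attributes):

# assumes scheduler, unet, latents, prompt_embeds are already set up as in a standard pipeline
num_inference_steps = 4
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps_gpu = scheduler.timesteps            # conditions the unet/controlnet/transformer
timesteps_cpu = scheduler.timesteps.to("cpu")  # used for scheduler-side indexing

for t_gpu, t_cpu in zip(timesteps_gpu, timesteps_cpu):
    noise_pred = unet(latents, t_gpu, encoder_hidden_states=prompt_embeds).sample
    latents = scheduler.step(noise_pred, t_cpu, latents).prev_sample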

xiang9156 commented 1 month ago

Yes, I'm also keeping two copies of the timesteps; the sync did disappear.

xiang9156 commented 1 month ago

@a-r-r-o-w I think I know why my cudaMemcpyAsync takes so long: it's still the underlying async cuda execution. The cudaMemcpyAsync call ends up waiting until the previously queued async cuda work finishes, which is why it could take 100+ms at worst. Keeping two copies of the timestep and doing more profiling is what revealed this.
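A tiny illustration of that effect: the device-to-host copy only looks slow because it has to wait for the kernels already queued on the stream (sizes here are arbitrary).

import time
import torch

a = torch.randn(8192, 8192, device="cuda")
b = torch.randn(8192, 8192, device="cuda")

out = a @ b            # queued asynchronously; this line returns almost immediately
t0 = time.perf_counter()
out_cpu = out.cpu()    # the copy blocks until the queued matmul has finished
print(f"copy 'took' {(time.perf_counter() - t0) * 1e3:.1f} ms, mostly waiting on the matmul")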

github-actions[bot] commented 3 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.