pamparamm / sd-perturbed-attention

Perturbed-Attention Guidance and Smoothed Energy Guidance for ComfyUI and SD Forge
MIT License

Would it be possible to get this node working with TensorRT models? #19

Closed: ManOrMonster closed this issue 3 months ago

ManOrMonster commented 4 months ago

I've noticed that PAG does not work with TensorRT models. Is this feasible? The extra speed boost would be nice to use with PAG, and some models just don't have great output without it.

pamparamm commented 4 months ago

Seems like it's not possible in a memory-efficient way, since attention in TRT engines can't be patched by ModelPatcher. I've implemented TRT + PAG support locally, but it consumes roughly 2x the VRAM. To use it, you would need to:

  1. Build a static/dynamic TRT engine of the desired model.
  2. Build a static/dynamic TRT engine of the same model with the same TRT parameters, but with fixed PAG injection baked into the selected UNET blocks (a sketch of this perturbation follows the list). [screenshot: engines]
  3. Use the special TensorRT Perturbed-Attention Guidance node with two model inputs: one for the base engine and one for the PAG engine. [screenshot: pag-trt]
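
For illustration, here is a minimal sketch of what the perturbation in step 2 amounts to inside a self-attention block: the attention map is replaced with the identity, so each token attends only to itself and the block's output is just V. The helper names below are assumptions made for this sketch, not the extension's real hooks; with TensorRT this behavior has to be baked into the engine at build time, since the compiled graph can't be patched afterwards.

```python
import torch.nn.functional as F

def normal_self_attention(q, k, v):
    # Standard scaled dot-product self-attention: softmax(QK^T / sqrt(d)) @ V.
    return F.scaled_dot_product_attention(q, k, v)

def perturbed_self_attention(q, k, v):
    # PAG perturbation: the attention map becomes the identity matrix,
    # so the block's output is simply V.
    return v
```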

The SDXL setup from the screenshots consumes ~15 GB of VRAM (a rough sketch of how the two engine outputs are combined is below). If you're still interested in this method, I'll push it upstream.
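
To make the combination in step 3 concrete, here is a rough sketch of how the two engines could be used per sampling step, assuming hypothetical callables `base_engine` and `pag_engine` that each wrap a TRT engine and return a noise prediction. This is not the node's actual API, just the underlying math; keeping both engines resident at once is where the roughly doubled VRAM comes from.

```python
def pag_denoise(base_engine, pag_engine, x, sigma, cond, uncond,
                cfg_scale=7.0, pag_scale=3.0):
    # Three UNet evaluations per step: two on the base engine, one on the PAG engine.
    eps_uncond = base_engine(x, sigma, uncond)  # unconditional prediction
    eps_cond = base_engine(x, sigma, cond)      # conditional prediction
    eps_pag = pag_engine(x, sigma, cond)        # perturbed (identity-attention) prediction

    # Classifier-free guidance first...
    cfg = eps_uncond + cfg_scale * (eps_cond - eps_uncond)
    # ...then the PAG term pushes the result away from the perturbed prediction.
    return cfg + pag_scale * (eps_cond - eps_pag)
```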

ManOrMonster commented 4 months ago

I have a 4080... I'd love to give it a try.

pamparamm commented 4 months ago

I've pushed the changes in 59f1e2f, feel free to test.

ManOrMonster commented 4 months ago

Nice! About to build the PAG model. My original TRT model has a LoRA baked in. Do I need to also bake that into the PAG model?

pamparamm commented 4 months ago

> Do I need to also bake that into the PAG model?

Yes, the workflows for creating the engines should be equivalent (except for the TensorRT Attach PAG node).

ManOrMonster commented 4 months ago

Crazy stuff.

Getting OOM errors, even after it went into low VRAM mode. I saved the CLIP and VAE out of the checkpoint and loaded those separately instead of loading the full model for CLIP and VAE... figured I'd save ~6.46 GB of VRAM that way, but it didn't help for some reason!

Changed to VAE Decode (Tiled). That stopped the OOM, but it only generates black images.

I feel like I must be doing something wrong but not sure what. I'll keep plugging away.

I also get these errors in the console:

[07/10/2024-18:25:33] [TRT] [E] 3: [executionContext.cpp::nvinfer1::rt::ShapeMachineContext::resolveSlots::2842] Error Code 3: API Usage Error (Parameter check failed at: runtime/api/executionContext.cpp::nvinfer1::rt::ShapeMachineContext::resolveSlots::2842, condition: allInputDimensionsSpecified(routine) )

ManOrMonster commented 4 months ago

Ok, started over and kept it simple. Loaded the checkpoint for VAE and CLIP and used tiled VAE decode, but the results are worse than using just the checkpoint + PAG, and take much longer to generate (over twice as long).

Edit: The results weren't actually worse; I just tested again. But it definitely took about twice as long to generate on my setup.

pamparamm commented 3 months ago

Half the usual speed is expected for PAG, since it adds an extra UNet forward pass per sampling step. I get ~5 it/s with normal PAG and ~8 it/s with TRT PAG. Or by 'twice as long to generate' do you mean the difference between normal PAG and TRT PAG (with TRT PAG being slower)?

ManOrMonster commented 3 months ago

I meant it takes over twice as long when using TRT PAG. But it goes into low VRAM mode with this method, so that's very likely why. I also have to use tiled VAE decode or it runs out of memory. Do you have a 4090?

pamparamm commented 3 months ago

Yeah, I have 24 GB of VRAM. Judging from my own tests, this method seems suitable only for 24 GB+ video cards :(.

I haven't found any other way of implementing PAG with TensorRT support (TRT really restricts Python's 'monkeypatch everything' magic), so I'll leave this method in place and close this issue as completed.