Birch-san opened this issue 3 months ago
Okay, I was able to get this working by putting my control-flow-dependent operations into a @script_if_tracing subroutine:
https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/
This is awesome: I can trace the UNet just once and have small sections that activate functionality at runtime based on `if my_cool_kwarg is None`. That means we don't pay the cost of unused optional functionality (we're not forced to send an all-zeros batch and do lots of no-op maths).
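For anyone who finds this later, here's a minimal sketch of the pattern (not my actual UNet code). I've used an empty-tensor sentinel rather than a real `my_cool_kwarg is None` check, since plain `torch.jit.trace` only takes tensor example inputs, and all the names are made up for illustration:

```python
import torch

@torch.jit.script_if_tracing
def maybe_scale(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # this subroutine gets scripted rather than traced, so the data-dependent
    # branch below stays in the compiled graph instead of being baked in
    if scale.numel() == 0:
        return x          # optional functionality disabled: no extra maths
    return x * scale

class Block(torch.nn.Module):
    def forward(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return maybe_scale(x, scale)

# trace once, with the optional feature disabled (empty sentinel tensor)...
traced = torch.jit.trace(Block(), (torch.randn(2, 4), torch.empty(0)))
# ...and both branches still work at runtime:
traced(torch.randn(2, 4), torch.empty(0))        # skips the scaling entirely
traced(torch.randn(2, 4), torch.tensor([2.0]))   # activates the scaling
```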
And sfast's register_custom_python_operator was a lifesaver for wrapping operations so they survive script-mode JIT.
The other way to approach this would've been "always activate the optional functionality, but send it a batch of zero". I wasn't able to try that approach, though, because torch sdpa with the Flash Attention backend currently rejects batch-of-zero inputs:
https://github.com/pytorch/pytorch/issues/133780
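As I understand the linked issue, the failure looks roughly like this (the shapes and dtype are just illustrative, and it needs a CUDA build with Flash Attention available):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# a batch-of-zero query/key/value: (batch, heads, seq, head_dim)
q = torch.randn(0, 8, 77, 64, device="cuda", dtype=torch.float16)
k = torch.randn(0, 8, 77, 64, device="cuda", dtype=torch.float16)
v = torch.randn(0, 8, 77, 64, device="cuda", dtype=torch.float16)

# restricting to the Flash Attention backend currently raises instead of
# returning an empty output (see the pytorch issue above)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```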
Thanks for the great work on stable-fast; it compiles quickly and gives a big speed boost.
Is it possible to support two different compilation graphs?
For example, swapping the SDXL UNet's AttentionProcessor to run a different algorithm depending on what kind of request the user sends us.
The current workaround is to "always run all optional functionality" and just pass zeroes for anything we're not using, but this isn't free.
So is there a way to compile the UNet with its default AttentionProcessor, then apply a different AttentionProcessor to it and compile it again?
And after that, whenever a user sends a request, we'd just apply whichever AttentionProcessor is appropriate, and it would use whichever of the previously compiled graphs matches?
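Roughly like this sketch, which is the usage I'm hoping for rather than anything stable-fast documents today. MyCustomAttnProcessor is a placeholder, and I'm assuming compile_unet from sfast.compilers.diffusion_pipeline_compiler can be applied independently to two copies of the UNet:

```python
import copy

import torch
from diffusers import StableDiffusionXLPipeline
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile_unet

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
config = CompilationConfig.Default()

# graph 1: the UNet with its default AttentionProcessor
unet_default = compile_unet(copy.deepcopy(pipe.unet), config)

# graph 2: a copy of the UNet with a different AttentionProcessor applied
unet_custom = copy.deepcopy(pipe.unet)
unet_custom.set_attn_processor(MyCustomAttnProcessor())  # placeholder processor
unet_custom = compile_unet(unet_custom, config)

def handle_request(wants_custom_attn: bool, **pipe_kwargs):
    # per request, swap in whichever pre-compiled UNet is appropriate;
    # ideally each one reuses the graph it compiled earlier
    pipe.unet = unet_custom if wants_custom_attn else unet_default
    return pipe(**pipe_kwargs)
```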