CompVis / zigma

A PyTorch implementation of the paper "ZigMa: A DiT-Style Mamba-based Diffusion Model" (ECCV 2024)
https://taohu.me/zigma
Apache License 2.0

How to implement torch.compile for Mamba models? #6

Closed: yyNoBug closed this 6 months ago

yyNoBug commented 7 months ago

Hi, I noticed from your README that torch.compile provides a significant speedup. However, I couldn't find where torch.compile is applied in your train_acc.py. I tried wrapping models that contain Mamba blocks with torch.compile(model), but this raises errors. Could you explain how you applied torch.compile to your zigzag model? Thanks!

dongzhuoyao commented 7 months ago

torch.compile is mainly used for the indexing operations along the zigzag path, not for the whole model.

See https://github.com/CompVis/zigma/blob/1e78944ebce400d34a12efd4baba1daad0fae9f3/dis_mamba/mamba_ssm/modules/mamba_simple.py#L55

and

https://github.com/CompVis/zigma/blob/1e78944ebce400d34a12efd4baba1daad0fae9f3/dis_mamba/mamba_ssm/modules/mamba_simple.py#L60
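A minimal sketch of that pattern follows. It is not the code at the linked lines: it assumes a single snake-like (boustrophedon) scan as a simplified stand-in for the zigzag path, and the helper names `zigzag_indices`, `reorder_tokens` and `restore_tokens` are hypothetical. Only the small token-reordering functions are wrapped with torch.compile; the Mamba block itself would stay uncompiled.

```python
import torch


def zigzag_indices(h: int, w: int) -> torch.Tensor:
    # Hypothetical snake scan over an h x w token grid:
    # even rows are read left-to-right, odd rows right-to-left.
    idx = torch.arange(h * w).view(h, w)
    idx[1::2] = idx[1::2].flip(-1)
    return idx.flatten()


def reorder_tokens(x: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, dim); gather tokens into scan order.
    return x[:, perm, :]


def restore_tokens(x: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    # argsort of a permutation is its inverse, so this undoes reorder_tokens.
    inv = torch.argsort(perm)
    return x[:, inv, :]


# Compile only the indexing helpers. torch.compile returns a wrapper;
# actual compilation happens lazily on the first call.
compiled_reorder = torch.compile(reorder_tokens)
compiled_restore = torch.compile(restore_tokens)
```

In a forward pass one would call `compiled_reorder(x, perm)` before the Mamba block and `compiled_restore(y, perm)` after it. Keeping the Mamba block out of the compiled region means torch.compile never has to trace through the custom selective-scan kernels, which is likely (though not confirmed in this thread) why compiling the whole model raised errors.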