CompVis / zigma

A PyTorch implementation of the paper "ZigMa: A DiT-Style Mamba-based Diffusion Model" (ECCV 2024)
https://taohu.me/zigma
Apache License 2.0

Taking inputs from multiple directions #18

Open EndingCredits opened 3 weeks ago

EndingCredits commented 3 weeks ago

Rather than a purely sequential scan through the patches, I feel it would make more sense to take inputs from multiple patches at once. Obviously you are limited in that you cannot introduce any cyclic dependencies, but you can very easily take inputs from two neighbouring patches (e.g. the ones directly to the left and above) at once. See image.

[Image: example of the two approaches]
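As a rough sketch of the two-input scan described above, the following toy recurrence lets each patch state depend on the states of its left and upper neighbours. It assumes one scalar feature per patch and fixed coefficients `a_left`, `a_up`, and `b`; these names and the function `scan_2d` are hypothetical and are not part of ZigMa or Mamba, whose selective SSM uses input-dependent, higher-dimensional parameters.

```python
import numpy as np

def scan_2d(x, a_left=0.5, a_up=0.5, b=1.0):
    """Toy 2D linear recurrence (illustrative only, not Mamba's SSM).

    x: array of shape (H, W), one scalar feature per patch.
    Each state mixes the left neighbour, the upper neighbour,
    and the patch's own input. Out-of-grid neighbours count as 0.
    """
    H, W = x.shape
    h = np.zeros((H, W), dtype=float)
    for i in range(H):
        for j in range(W):
            left = h[i, j - 1] if j > 0 else 0.0
            up = h[i - 1, j] if i > 0 else 0.0
            h[i, j] = a_left * left + a_up * up + b * x[i, j]
    return h

# A single impulse at the top-left patch spreads to every patch
# below and to the right of it.
x = np.zeros((3, 3))
x[0, 0] = 1.0
h = scan_2d(x)
```

Because a patch only reads its left and upper neighbours, there are no cyclic dependencies, and all patches on the same anti-diagonal (equal `i + j`) are independent of one another, so in principle they could be computed in parallel.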

Obviously you have the issue that this is directional, but you can apply the same idea you used for the sequential scan and run four separate scans, each starting from a different corner.
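The four-corner idea can be sketched by flipping the grid before a fixed top-left scan and flipping the result back. This is only an illustration: `scan_2d` is the same hypothetical toy recurrence as before, and averaging the four outputs is an assumption, since the way the scans would be combined is not specified here.

```python
import numpy as np

def scan_2d(x, a=0.5):
    """Toy scan from the top-left corner (left + upper neighbours)."""
    h = np.zeros_like(x, dtype=float)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            left = h[i, j - 1] if j > 0 else 0.0
            up = h[i - 1, j] if i > 0 else 0.0
            h[i, j] = a * (left + up) + x[i, j]
    return h

def four_corner_scan(x):
    """Run the same scan from each corner by flipping the grid,
    flip each result back, and average the four outputs
    (averaging is an assumed way of combining them)."""
    outputs = []
    for flip_rows in (False, True):
        for flip_cols in (False, True):
            axes = tuple(ax for ax, f in enumerate((flip_rows, flip_cols)) if f)
            xf = np.flip(x, axis=axes) if axes else x
            hf = scan_2d(xf)
            outputs.append(np.flip(hf, axis=axes) if axes else hf)
    return np.mean(outputs, axis=0)

# An impulse at the centre patch: one scan only reaches one quadrant,
# while the four scans together reach every patch.
x = np.zeros((3, 3))
x[1, 1] = 1.0
single = scan_2d(x)
combined = four_corner_scan(x)
```

With a centre impulse, `single` is zero everywhere outside the lower-right quadrant, while `combined` is nonzero at every patch, which is the directional issue and its fix in miniature.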

I feel like this has a number of advantages:

The only downside is incorporating inputs from two inputs, which is not standard Mamba practice, but architecturally this should not be difficult to implement.

If multiple inputs are unwanted, I'd also be very interested to see the result of just processing the patches in multiple parallel (disjoint) lines in each direction. Obviously this has the issue that there is no direct dependency between patches on different lines, but I think that is resolved by using multiple layers.
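The disjoint-lines variant can be illustrated with independent 1D scans along rows and columns. The sketch below is a toy under stated assumptions: `line_scan` and `layer` are hypothetical helpers, the recurrence is a fixed scalar one rather than a Mamba block, and the four directional outputs are simply summed.

```python
import numpy as np

def line_scan(x, axis, reverse=False, a=0.5):
    """Independent 1D scans h_t = a * h_{t-1} + x_t along `axis`.
    Each line is its own sequence; lines never interact."""
    x = np.moveaxis(x, axis, 0)
    if reverse:
        x = x[::-1]
    h = np.zeros_like(x, dtype=float)
    prev = np.zeros(x.shape[1:])
    for t in range(x.shape[0]):
        prev = a * prev + x[t]
        h[t] = prev
    if reverse:
        h = h[::-1]
    return np.moveaxis(h, 0, axis)

def layer(x):
    """Sum of four directional line scans: both directions along
    each of the two grid axes (summing is an assumption)."""
    return sum(line_scan(x, axis, rev)
               for axis in (0, 1) for rev in (False, True))

x = np.zeros((4, 4))
x[0, 0] = 1.0          # impulse at the top-left patch
one = layer(x)         # reaches only row 0 and column 0
two = layer(one)       # reaches every patch
```

After one layer the impulse has only travelled along its own row and column, so `one[1, 1]` is still zero; after a second layer every patch has received some of it, which matches the expectation that stacking layers restores the missing dependencies.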

What do you think?

dongzhuoyao commented 3 weeks ago

Hi, I think your idea intuitively makes sense, but I cannot see a clear way to tackle the downside you mentioned ("The only downside is incorporating inputs from two inputs, which is not standard Mamba practice, but architecturally this should not be difficult to implement."), as Mamba only supports one-dimensional token sequences by default.