CompVis / zigma

A PyTorch implementation of the paper "ZigMa: A DiT-Style Mamba-based Diffusion Model" (ECCV 2024)
https://taohu.me/zigma
Apache License 2.0
Taking inputs from multiple directions #18

Closed EndingCredits closed 2 months ago

EndingCredits commented 3 months ago

Rather than a purely sequential scan through the patches, I feel it would make more sense to take inputs from multiple patches at once. Obviously you are limited in that you cannot induce any cyclic dependencies, but you can very easily take inputs from two neighbouring patches (e.g. directly to the left and above) at once. See image.

[Image: example of the two approaches]

Obviously this is directional, but you can apply the same idea you used for the sequential scan and run four separate scans, each starting from a different corner.
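
To make the two-neighbour idea concrete, here is a minimal sketch in plain Python. It is not Mamba and not code from this repository: it uses a scalar linear recurrence with fixed, hypothetical weights `a_left` and `a_up` in place of Mamba's input-dependent state update, and it simply sums the four corner scans. All function names are made up for illustration.

```python
def two_neighbour_scan(x, a_left=0.5, a_up=0.5):
    """Scan a 2D grid starting from the top-left corner.

    Each state mixes the states to the left and above (assumed recurrence):
        h[i][j] = a_left * h[i][j-1] + a_up * h[i-1][j] + x[i][j]
    Neighbours outside the grid count as zero, so there are no cycles.
    """
    rows, cols = len(x), len(x[0])
    h = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            left = h[i][j - 1] if j > 0 else 0.0
            up = h[i - 1][j] if i > 0 else 0.0
            h[i][j] = a_left * left + a_up * up + x[i][j]
    return h


def flip(grid, vertical, horizontal):
    """Return a copy of the grid, optionally mirrored along each axis."""
    rows = grid[::-1] if vertical else grid
    return [row[::-1] if horizontal else list(row) for row in rows]


def four_corner_scan(x, a_left=0.5, a_up=0.5):
    """Run the two-neighbour scan from each corner and sum the results."""
    rows, cols = len(x), len(x[0])
    out = [[0.0] * cols for _ in range(rows)]
    for vertical in (False, True):
        for horizontal in (False, True):
            # Mirror the input, scan from the top-left, then mirror back.
            scanned = two_neighbour_scan(flip(x, vertical, horizontal), a_left, a_up)
            h = flip(scanned, vertical, horizontal)
            for i in range(rows):
                for j in range(cols):
                    out[i][j] += h[i][j]
    return out
```

In this sketch, an impulse at the centre of a 3x3 grid reaches only the cells below and to the right in a single top-left scan, whereas the summed four-corner scan leaves every cell non-zero.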

I feel like this has a number of advantages:

The only downside is incorporating inputs from two patches, which is not standard Mamba practice, but architecturally this should not be difficult to implement.

If multiple inputs are unwanted, I'd also be very interested to see the result of just processing patches along multiple parallel (disjoint) lines in each direction. Obviously the issue here is that there is no direct dependency between patches on different lines, but I think that is resolved by using multiple layers.
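
The disjoint-lines variant can be sketched the same way. Again this is an illustration under assumptions, not Mamba: each line uses a scalar recurrence with a hypothetical fixed decay `a`, and the four directional results are summed. The names `line_scan` and `layer` are invented for this example.

```python
def line_scan(seq, a=0.5):
    """1D recurrence along one line: h[t] = a * h[t-1] + seq[t]."""
    h, out = 0.0, []
    for v in seq:
        h = a * h + v
        out.append(h)
    return out


def transpose(grid):
    """Swap rows and columns of a rectangular grid."""
    return [list(col) for col in zip(*grid)]


def layer(x, a=0.5):
    """Scan every row and every column independently, in both directions."""
    rows, cols = len(x), len(x[0])
    left_right = [line_scan(r, a) for r in x]
    right_left = [line_scan(r[::-1], a)[::-1] for r in x]
    xt = transpose(x)  # columns become rows so line_scan can be reused
    top_bottom = transpose([line_scan(c, a) for c in xt])
    bottom_top = transpose([line_scan(c[::-1], a)[::-1] for c in xt])
    return [
        [left_right[i][j] + right_left[i][j] + top_bottom[i][j] + bottom_top[i][j]
         for j in range(cols)]
        for i in range(rows)
    ]
```

With an impulse at the top-left patch of a 3x3 grid, one `layer` call only reaches patches in the same row or column, so the centre patch stays at zero. Applying `layer` twice makes the centre patch non-zero, which is the multi-layer argument above in miniature.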

What do you think?

dongzhuoyao commented 3 months ago

Hi, I think your idea intuitively makes sense,

but I cannot see a clear way to tackle the downside you mentioned ("The only downside is incorporating inputs from two patches, which is not standard Mamba practice, but architecturally this should not be difficult to implement."), as Mamba only supports one-dimensional token sequences by default.
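
The constraint shows up in how a patch grid is typically fed to a sequence model: the grid is flattened into a single ordering, so each token has exactly one predecessor. A small sketch, assuming a serpentine (row-wise back-and-forth) ordering chosen purely for illustration; it is not taken from the ZigMa code, and `serpentine_order` and `predecessors` are hypothetical helpers.

```python
def serpentine_order(rows, cols):
    """Flatten a rows x cols grid into one path, reversing every other row."""
    order = []
    for i in range(rows):
        js = range(cols) if i % 2 == 0 else range(cols - 1, -1, -1)
        order.extend((i, j) for j in js)
    return order


def predecessors(order):
    """Map each patch to the single patch that precedes it in the 1D order."""
    prev = {order[0]: None}
    for before, after in zip(order, order[1:]):
        prev[after] = before
    return prev
```

In a 3x3 grid flattened this way, patch (1, 1) has the single predecessor (1, 2). The two-input proposal would instead need two predecessors per patch, for example (1, 0) and (0, 1), which one 1D ordering cannot express.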