feizc / Dimba

Transformer-Mamba Diffusion Models

Question about the Dimba architecture #6

Status: Open. Hongjiew opened this issue 2 months ago.

Hongjiew commented 2 months ago

Thank you for your excellent work! However, after reading the paper and the code, I have some questions about the Dimba architecture.

TL;DR: the code seems to implement Dimba as Self-Attention -> Cross-Attention -> FFN -> BiMamba, which differs from Fig. 2 in the paper, where the order is Self-Attention -> Cross-Attention -> BiMamba -> FFN. In addition, the number of SA, CA and FFN blocks in the code always seems to be equal to or larger than the number of BiMamba blocks, whereas Fig. 2 suggests that several BiMamba blocks can be stacked with a single SA-CA block.

In the definition of Dimba, the blocks are applied in the following loop (dimba.py L388-394):

```python
for i in range(self.depth):
    x = auto_grad_checkpoint(self.blocks[i], x, y, t0, y_lens)  # (N, T, D) #support grad checkpoint
    if self.gap == 1:
        x, residual = auto_grad_checkpoint(self.mamba_blocks[i], x, residual)
    else:
        if i % self.gap == 1:
            x, residual = auto_grad_checkpoint(self.mamba_blocks[i // self.gap], x, residual)
```

self.blocks is a list of DimbaBlock modules (dimba.py L306-315), and each DimbaBlock applies Self-Attention, Cross-Attention and FFN in the order SA -> CA -> FFN (dimba.py L84-87). self.mamba_blocks is built with the function create_block() and is in fact a list of Block modules (dimba.py L171-201), each of which ends with a Mamba block. The code above seems to indicate that:

  1. Each block in self.mamba_blocks always runs after a block in self.blocks. This suggests the architecture is SA -> CA -> FFN -> BiMamba, rather than SA -> CA -> BiMamba -> FFN as shown in Fig. 2 of the paper.
  2. If self.gap > 1, some blocks in self.blocks are not followed by a Mamba block, so there are more SA, CA and FFN blocks than Mamba blocks. Even when self.gap = 1, each DimbaBlock is followed by only one Mamba block. The Mamba blocks therefore never outnumber the SA, CA and FFN blocks, unlike the configuration suggested by Fig. 2 in the paper.
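To make the two points above concrete, the control flow of the quoted loop can be replayed with plain labels instead of modules. This is a minimal sketch, not code from the repository: layer_schedule and counts are hypothetical helpers, "attn[i]" stands for self.blocks[i] (a DimbaBlock, SA -> CA -> FFN), "mamba[j]" stands for self.mamba_blocks[j], and the (x, residual) tensors are ignored.

```python
def layer_schedule(depth: int, gap: int) -> list[str]:
    """Replay the branching of the quoted loop and record which block runs.

    Hypothetical labels: "attn[i]" = self.blocks[i], "mamba[j]" = self.mamba_blocks[j].
    """
    order = []
    for i in range(depth):
        order.append(f"attn[{i}]")  # a DimbaBlock runs on every iteration
        if gap == 1:
            order.append(f"mamba[{i}]")  # one Mamba block after every DimbaBlock
        elif i % gap == 1:
            order.append(f"mamba[{i // gap}]")  # a Mamba block only on some iterations
    return order


def counts(depth: int, gap: int) -> tuple[int, int]:
    """Return (number of attention blocks, number of Mamba blocks) in the schedule."""
    sched = layer_schedule(depth, gap)
    n_attn = sum(s.startswith("attn") for s in sched)
    n_mamba = sum(s.startswith("mamba") for s in sched)
    return n_attn, n_mamba


for gap in (1, 2, 3):
    print(gap, counts(12, gap))
# 1 (12, 12)
# 2 (12, 6)
# 3 (12, 4)
```

In every schedule a Mamba entry only ever appears right after an attention entry, and the attention count is never smaller than the Mamba count, which is what both points describe.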

Thank you so much for your attention to this. It would be great if you could clarify the architecture of Dimba. Please let me know if I have misread the code.

feizc commented 2 months ago

Hi, thanks for your careful review. We will correct the order of the layers in the figure in the paper.

As for the number of layers, the current code is set up for subsequent ablation experiments. I think simply swapping the definitions of the Mamba and attention layers would support a higher ratio of Mamba layers :)
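A rough sketch of the suggested swap, assuming it simply mirrors the quoted loop with the roles of the two block lists exchanged (swapped_schedule is a hypothetical helper, not part of the Dimba repository). It only models the ordering and the layer counts, not the (x, residual) pair that the quoted loop threads through the Mamba blocks, which a real implementation would still have to handle.

```python
def swapped_schedule(depth: int, gap: int) -> list[str]:
    """Hypothetical schedule with the Mamba blocks driving the loop.

    Mirrors the quoted loop's branching, but a Mamba block now runs on every
    iteration and an attention block (SA -> CA -> FFN) only on some of them.
    """
    order = []
    for i in range(depth):
        order.append(f"mamba[{i}]")  # a Mamba block runs on every iteration
        if gap == 1:
            order.append(f"attn[{i}]")
        elif i % gap == 1:
            order.append(f"attn[{i // gap}]")  # attention only every `gap` iterations
    return order


sched = swapped_schedule(12, 3)
print(sum(s.startswith("mamba") for s in sched), sum(s.startswith("attn") for s in sched))
# 12 4
```

With gap > 1 the Mamba blocks now outnumber the attention blocks, i.e. several Mamba blocks share one SA-CA-FFN block, which is the ratio the question asked about.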