beichenzbc / Long-CLIP

[ECCV 2024] official code for "Long-CLIP: Unlocking the Long-Text Capability of CLIP"
Apache License 2.0

RoPE for Long-CLIP - perfect match for Flux? #67

Open zer0int opened 3 weeks ago

zer0int commented 3 weeks ago

First of all, in case anybody sees this thread and thinks "Oh, I want to use Long-CLIP with Flux!": I made a ComfyUI custom node for it, and you can find it here: https://github.com/zer0int/ComfyUI-Long-CLIP


Now, about the actual topic. The developers of Flux1 stated in their announcement that, besides the model being a gigantic 12B-parameter diffusion transformer, its superior performance is also due to rotary positional embeddings (RoPE).

And indeed, the "red cube on top of a blue cube" problem doesn't exist for Flux. It can even accurately generate this:

"a red cube on top of a blue cube, with a green kitten sitting on top of the red cube, the cat is holding a sign that says 'rotary positional embeddings', and in the background there are many tiny pink spheres and yellow triangles"

In terms of general spatial prompt following, Flux1 can do this (albeit details can be problematic).

Flux1 uses T5 + CLIP ViT-L/14 as Text Encoders. Besides Long-CLIP nicely complementing the maximum sequence length of T5, I naturally also wondered: What if CLIP had RoPE?

My previous MLP modification was Geometric Parametrization (GmP), which "splits" the .weight into .theta and .r, and thus preserves the learned information. Not a big deal.
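To make the "split" concrete, here is a minimal sketch in plain Python (no PyTorch needed to run it). It assumes that GmP stores, for each output row of a weight matrix, its Euclidean norm as `r` and its unit direction as `theta`, so that `r * theta` reproduces the original row. The function names `to_gmp` and `from_gmp` are made up for this illustration and are not the actual GeometricLinear code.

```python
import math

def to_gmp(weight):
    """Split each row of `weight` into a magnitude r and a unit direction theta."""
    r, theta = [], []
    for row in weight:
        norm = math.sqrt(sum(w * w for w in row))
        r.append(norm)
        # Assumes no all-zero rows; a real layer would guard against norm == 0.
        theta.append([w / norm for w in row])
    return r, theta

def from_gmp(r, theta):
    """Rebuild the effective weight matrix as r * theta, row by row."""
    return [[r_i * t for t in row] for r_i, row in zip(r, theta)]

weight = [[3.0, 4.0], [0.5, -1.2]]
r, theta = to_gmp(weight)
rebuilt = from_gmp(r, theta)
```

Because `rebuilt` matches `weight` up to floating-point error, converting a pretrained layer this way leaves its outputs unchanged at the moment of conversion, which is what "preserves the learned information" refers to.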

However, RoPE changes how the attention mechanism works. So it needs to "learn how to see" again after this change. Nevertheless, I tried it! 🤓
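What changes, concretely: with RoPE the query-key dot product depends on the relative offset between two tokens rather than on their absolute positions. Below is a small plain-Python sketch of that property. It assumes the half-split channel pairing and the base of 10000 that are common for RoPE; the helper names `rope` and `dot` exist only for this illustration.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate `vec` (even length) by position `pos`, pairing element i with i + half."""
    half = len(vec) // 2
    out = [0.0] * len(vec)
    for i in range(half):
        # One rotation frequency per channel pair
        angle = pos * base ** (-2.0 * i / len(vec))
        c, s = math.cos(angle), math.sin(angle)
        out[i] = vec[i] * c - vec[i + half] * s
        out[i + half] = vec[i] * s + vec[i + half] * c
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.0, 0.7, 0.2]
k = [1.1, 0.4, -0.5, 0.9]

# Same offset (3 positions apart) at two different absolute positions
score_a = dot(rope(q, 5), rope(k, 2))
score_b = dot(rope(q, 45), rope(k, 42))
# A different offset (1 position apart)
score_c = dot(rope(q, 5), rope(k, 4))
```

`score_a` and `score_b` agree up to floating-point error, while `score_c` differs. A model pretrained with learned absolute positional embeddings has never seen attention scores structured this way, which is why it has to re-learn.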

First, I fine-tuned with the COCO-SPRIGHT-40k spatial labels dataset I used previously, with labels <77 tokens, to compare CLIP vs. Long-CLIP. The models show similar learning patterns, but their validation accuracy and F1 remain very poor after this. They ARE improving, albeit slowly.

So I continued fine-tuning on CC12M SPRIGHT, using the original long captions (>77 tokens) for Long-CLIP.

Now, after ~150,000 text-image pairs (split into two separate runs), latest:

Fine-tuned model accuracy on MVT ImageNet/ObjectNet: 0.79064, down from 0.81134 (your model).

So I am hoping that "just" 1-5 million text-image pairs of re-training may be enough for CLIP to learn how to "see with its new attention".

I am using GmP, label smoothing, and RoPE. GmP is necessary because I am "GPU poor": I am training on a single RTX 4090. RoPE further increases the model's VRAM requirements, so I have to train with a batch size of 26 (!), which is definitely NOT ideal for CLIP!
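For readers unfamiliar with label smoothing in a contrastive setup: instead of treating the matching caption as the only correct target, a small share `eps` of the target probability is spread over all captions in the batch. Here is a minimal plain-Python sketch for one row of image-to-text logits. The function name and the `eps=0.1` value are illustrative assumptions, not settings taken from the training run described here.

```python
import math

def smoothed_cross_entropy(logits, target_index, eps=0.1):
    """Cross-entropy of one logit row against a label-smoothed one-hot target."""
    n = len(logits)
    # Numerically stable log-softmax
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    # Target: (1 - eps) on the true pair, plus eps spread uniformly over all n entries
    target = [eps / n + (1.0 - eps if j == target_index else 0.0) for j in range(n)]
    return -sum(t * lp for t, lp in zip(target, log_probs))

row = [4.0, 1.0, 0.5]  # similarities of one image to 3 captions; index 0 is the true pair
hard = smoothed_cross_entropy(row, 0, eps=0.0)
soft = smoothed_cross_entropy(row, 0, eps=0.1)
```

With a batch of only 26 pairs, each row has few negatives, so softening the targets is one way to keep the model from becoming overconfident on such a small contrastive batch.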

I would love to get your feedback on this idea (RoPE for CLIP in general). While GmP has been shown to "just work well" empirically, I am uncertain about RoPE. So, even if you have negative feedback or criticism, or if you think RoPE is a very bad idea, I would very much appreciate hearing it!

Any feedback is welcome! Thank you!


Supplementary images.

Long-Rope-CLIP-vs-others

long-gmp-clustering

RoPE-CATS-compare

rope-vs-gmp-heli1

rope-vs-gmp-heli2

zer0int commented 3 weeks ago

Adding the relevant / changed parts of the Long-CLIP code for RoPE.


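from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class RoPE:
    # Alternative "rotate_half" formulation; not called by the code below.
    # It expects cos and sin with the full head_dim width, unlike get_rotary_embeddings(),
    # which returns tables of width head_dim/2.
    @staticmethod
    def apply_rotary_pos_emb(q, k, cos, sin):
        q_ = (q * cos) + (RoPE.rotate_half(q) * sin)
        k_ = (k * cos) + (RoPE.rotate_half(k) * sin)
        return q_, k_

    @staticmethod
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)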

def get_rotary_embeddings(head_dim, seq_len, device):
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim {head_dim} must be even to apply RoPE.")

    # Generate sin and cos tables of width head_dim/2 (one angle per rotated pair of channels)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    position_ids = torch.arange(seq_len, dtype=torch.float, device=device)
    sinusoid_inp = torch.einsum("i,j->ij", position_ids, inv_freq)

    sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

    # Expand dimensions to match (batch_size, n_head, seq_len, head_dim/2)
    sin = sin.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim/2)
    cos = cos.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim/2)

    return sin, cos

def apply_rotary_pos_emb(q, k, cos, sin):
    # Expects q and k of shape (batch_size, n_head, seq_len, head_dim) and
    # cos and sin of shape (1, 1, seq_len, head_dim/2), so the rotation angle varies along seq_len.

    # Split q and k into two halves along the last dimension (each half has head_dim/2 channels)
    q1, q2 = q.split(q.shape[-1] // 2, dim=-1)
    k1, k2 = k.split(k.shape[-1] // 2, dim=-1)

    # Rotate each (first-half, second-half) pair of channels by its position-dependent angle
    q_ = torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1)
    k_ = torch.cat([k1 * cos - k2 * sin, k1 * sin + k2 * cos], dim=-1)

    return q_, k_

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, seq_len: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", GeometricLinear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", GeometricLinear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.seq_len = seq_len
        self.d_model = d_model
        self.n_head = n_head

    def attention(self, x: torch.Tensor):
        # x arrives already layer-normed by forward(), in CLIP's (seq_len, batch_size, d_model)
        # layout, because nn.MultiheadAttention defaults to batch_first=False.
        seq_len, batch_size, _ = x.shape
        head_dim = self.d_model // self.n_head

        attn_mask = None
        if self.attn_mask is not None:
            attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)

        # Project to q, k, v exactly once, reusing the pretrained in-projection weight and bias
        q, k, v = F.linear(x, self.attn.in_proj_weight, self.attn.in_proj_bias).chunk(3, dim=-1)

        # (seq_len, batch_size, d_model) -> (batch_size, n_head, seq_len, head_dim)
        def split_heads(t: torch.Tensor):
            return t.reshape(seq_len, batch_size, self.n_head, head_dim).permute(1, 2, 0, 3)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Rotate q and k along the sequence dimension; sin and cos are (1, 1, seq_len, head_dim/2)
        sin, cos = get_rotary_embeddings(head_dim, seq_len, x.device)
        q, k = apply_rotary_pos_emb(q, k, cos.to(q.dtype), sin.to(q.dtype))

        # Scaled dot-product attention on the rotated q and k (requires PyTorch >= 2.0).
        # Calling self.attn(q, k, v) here instead would apply the in-projection a second time,
        # after the rotation, which destroys the relative-position structure RoPE provides.
        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        # (batch_size, n_head, seq_len, head_dim) -> (seq_len, batch_size, d_model)
        attn_output = attn_output.permute(2, 0, 1, 3).reshape(seq_len, batch_size, self.d_model)

        # Output projection with the pretrained weights
        return self.attn.out_proj(attn_output)

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, seq_len: int = 248):
        super().__init__()
        self.width = width
        self.layers = layers
        self.seq_len = seq_len
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, seq_len, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int, 
                 load_from_clip: bool
                 ):
        super().__init__()

        self.context_length = 248

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            seq_len=self.context_length  # Pass the context length for RoPE
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)

        if not load_from_clip:
            self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
            self.positional_embedding_res = nn.Parameter(torch.empty(248, transformer_width))

        else:
            self.positional_embedding = nn.Parameter(torch.empty(77, transformer_width))

        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()
        self.mask1 = torch.zeros([248, 1])
        self.mask1[:20, :] = 1
        self.mask2 = torch.zeros([248, 1])
        self.mask2[20:, :] = 1

Note, in case you think it's interesting and want to try it (with a more suitable batch size, perhaps): since SPRIGHT-T2I is AI-labeled (GPT-4V etc.), I noticed that the dataset contains a few "glitch labels", where the label is just a repetition such as "a small airplane and a large airplane. a small airplane and a large airplane. [etc] [etc]", exceeding 248 tokens. I recommend running the entire dataset through the tokenizer and simply deleting any examples that cause an error. I have found about 10 in 100,000 so far, and they often belong to low-quality images (e.g. images full of text), so I think it's best to just auto-delete them.
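A rough sketch of such a filtering pass, in plain Python. `filter_tokenizable` and `toy_tokenize` are hypothetical names for this example; `toy_tokenize` is only a whitespace stand-in so the snippet runs on its own. With Long-CLIP you would pass its real tokenizer instead. Whether that tokenizer raises or silently truncates depends on how it is called, so the sketch treats both an exception and an over-long result as a reason to drop the caption.

```python
def filter_tokenizable(captions, tokenize, max_tokens=248):
    """Return (kept, dropped), dropping captions that fail to tokenize within max_tokens."""
    kept, dropped = [], []
    for caption in captions:
        try:
            tokens = tokenize(caption)
        except (RuntimeError, ValueError):
            dropped.append(caption)
            continue
        if len(tokens) > max_tokens:
            dropped.append(caption)
        else:
            kept.append(caption)
    return kept, dropped

# Stand-in tokenizer for demonstration only: one token per whitespace-separated word,
# raising like a strict tokenizer would when the limit is exceeded.
def toy_tokenize(text, limit=248):
    tokens = text.split()
    if len(tokens) > limit:
        raise RuntimeError("input is too long for the context length")
    return tokens

captions = [
    "a red cube on top of a blue cube",
    "a small airplane and a large airplane. " * 60,  # repeated "glitch label"
]
kept, dropped = filter_tokenizable(captions, toy_tokenize)
```

Writing out the `dropped` list before deleting anything makes it easy to spot-check that only glitch labels were removed.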