Open zer0int opened 3 weeks ago
Below are the relevant (changed) parts of the Long-CLIP code for RoPE.
class RoPE:
    """Rotary position embedding helpers in the "rotate half" formulation.

    ``cos`` and ``sin`` passed to :meth:`apply_rotary_pos_emb` must broadcast
    against the full last dimension of ``q``/``k``. For a true rotation, the
    value at index ``i`` must equal the value at ``i + dim // 2``, because
    :meth:`rotate_half` pairs those two elements.
    """

    @staticmethod
    def rotate_half(x):
        """Map the two halves ``[a, b]`` of the last dim to ``[-b, a]``."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat((second.neg(), first), dim=-1)

    @staticmethod
    def apply_rotary_pos_emb(q, k, cos, sin):
        """Rotate ``q`` and ``k`` by the angles encoded in ``cos``/``sin``.

        Returns:
            Tuple ``(q_rotated, k_rotated)``, each ``t * cos + rotate_half(t) * sin``.
        """
        def _rotate(t):
            return t * cos + RoPE.rotate_half(t) * sin

        return _rotate(q), _rotate(k)
def get_rotary_embeddings(head_dim, seq_len, device, base=10000.0, dtype=None):
    """Build RoPE sine/cosine tables for positions ``0 .. seq_len - 1``.

    Each table has ``head_dim // 2`` frequencies, one per rotated feature
    pair. They are meant for the split-halves ``apply_rotary_pos_emb``, which
    rotates ``(x[..., i], x[..., i + head_dim // 2])`` by the i-th frequency.

    Args:
        head_dim: Per-head feature size; must be even.
        seq_len: Number of positions to generate.
        device: Device the tables are created on.
        base: Frequency base (theta). The default 10000 reproduces the
            original behaviour.
        dtype: Optional dtype to cast the tables to, e.g. to match fp16
            q/k. The tables are always computed in float32 first.

    Returns:
        ``(sin, cos)``, in that order, each of shape
        ``(1, 1, seq_len, head_dim // 2)``.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim {head_dim} must be even to apply RoPE.")
    # One inverse frequency per feature pair, so head_dim // 2 of them.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float, device=device)
    # Outer product: angle[p, i] = position p * inv_freq[i].
    angles = positions[:, None] * inv_freq[None, :]
    sin, cos = torch.sin(angles), torch.cos(angles)
    if dtype is not None:
        sin, cos = sin.to(dtype), cos.to(dtype)
    # Leading singleton dims broadcast over (batch, n_head, seq_len, head_dim // 2).
    return sin[None, None], cos[None, None]
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to ``q`` and ``k`` (split-halves form).

    The last dimension of each tensor is split into halves ``x1``/``x2``.
    Each pair ``(x1[i], x2[i])`` is rotated by the angle whose cosine/sine
    are ``cos[..., i]``/``sin[..., i]``. ``cos`` and ``sin`` must therefore
    broadcast against a tensor with last dimension ``q.shape[-1] // 2``.

    Returns:
        ``(q_rotated, k_rotated)`` with the same shapes as the inputs.

    Raises:
        ValueError: If the last dimension of ``q`` or ``k`` is odd.
    """
    def _rotate(x):
        dim = x.shape[-1]
        if dim % 2 != 0:
            raise ValueError(f"last dimension {dim} must be even to apply RoPE.")
        half = dim // 2
        x1, x2 = x[..., :half], x[..., half:]
        # 2-D rotation of each (x1, x2) pair.
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    return _rotate(q), _rotate(k)
class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block whose self-attention applies RoPE to q/k.

    ``self.attn`` (an ``nn.MultiheadAttention``) is kept as the container for
    ``in_proj_weight``, ``in_proj_bias`` and ``out_proj``, so existing
    checkpoints still load. Attention itself is computed by hand because RoPE
    has to sit between the input projection and the q·kᵀ product.

    Inputs are sequence-first, ``(seq_len, batch, d_model)``. That is the
    layout ``nn.MultiheadAttention`` (default ``batch_first=False``)
    consumed in this block, and the layout a ``(seq_len, seq_len)``
    ``attn_mask`` is applied against.
    """

    def __init__(self, d_model: int, n_head: int, seq_len: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", GeometricLinear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", GeometricLinear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Additive mask (cast to the activation dtype at use time), or None.
        self.attn_mask = attn_mask
        # Stored for interface compatibility; RoPE tables are built for the
        # actual input length in attention().
        self.seq_len = seq_len
        self.d_model = d_model
        self.n_head = n_head

    def attention(self, x: torch.Tensor):
        """Multi-head self-attention with RoPE over the sequence axis.

        Args:
            x: Already layer-normalized input of shape
                ``(seq_len, batch, d_model)``.

        Returns:
            Attention output of shape ``(seq_len, batch, d_model)``, after
            ``out_proj``.
        """
        seq_len, batch_size, _ = x.shape
        head_dim = self.d_model // self.n_head

        # Single fused input projection, including the bias. The earlier
        # version dropped in_proj_bias and then let nn.MultiheadAttention
        # project q/k/v a second time.
        q, k, v = F.linear(x, self.attn.in_proj_weight, self.attn.in_proj_bias).chunk(3, dim=-1)

        def split_heads(t):
            # (L, N, D) -> (N, H, L, head_dim)
            return t.reshape(seq_len, batch_size, self.n_head, head_dim).permute(1, 2, 0, 3)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Positions index the sequence axis (dim 0 of x), not the batch axis.
        sin, cos = get_rotary_embeddings(head_dim, seq_len, x.device)
        q, k = apply_rotary_pos_emb(q, k, cos.to(q.dtype), sin.to(q.dtype))

        # Scaled dot-product attention, matching nn.MultiheadAttention's math.
        scores = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
        if self.attn_mask is not None:
            scores = scores + self.attn_mask.to(dtype=scores.dtype, device=scores.device)
        weights = scores.softmax(dim=-1)
        weights = F.dropout(weights, p=self.attn.dropout, training=self.training)
        out = weights @ v  # (N, H, L, head_dim)

        # (N, H, L, head_dim) -> (L, N, D), then the output projection.
        out = out.permute(2, 0, 1, 3).reshape(seq_len, batch_size, self.d_model)
        return self.attn.out_proj(out)

    def forward(self, x: torch.Tensor):
        """Pre-norm residual block: ``x + attn(ln_1(x))``, then ``x + mlp(ln_2(x))``."""
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class Transformer(nn.Module):
    """A stack of ``layers`` residual attention blocks applied one after another.

    Args:
        width: Model width (``d_model``) of every block.
        layers: Number of blocks.
        heads: Attention heads per block.
        attn_mask: Optional mask shared by all blocks.
        seq_len: Sequence length forwarded to each block (default 248).
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, seq_len: int = 248):
        super().__init__()
        self.width = width
        self.layers = layers
        self.seq_len = seq_len
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, seq_len, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Run ``x`` through every block in order."""
        out = self.resblocks(x)
        return out
class CLIP(nn.Module):
    # NOTE(review): only __init__ appears in this excerpt. build_attention_mask
    # and initialize_parameters, which it calls, are defined elsewhere.
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 load_from_clip: bool
                 ):
        """Construct the Long-CLIP model with a RoPE text transformer.

        Args:
            embed_dim: Joint embedding size. It is the visual tower's
                ``output_dim`` and the second dim of ``text_projection``.
            image_resolution, vision_layers, vision_width, vision_patch_size:
                Visual tower config. A tuple/list ``vision_layers`` selects
                ``ModifiedResNet``; anything else selects ``VisionTransformer``.
            context_length: Accepted for signature compatibility but not used.
                The text context length is hard-coded to 248 below.
            vocab_size, transformer_width, transformer_heads, transformer_layers:
                Text tower config.
            load_from_clip: If ``load_from_clip == False``, 248-position
                ``positional_embedding`` and ``positional_embedding_res`` are
                allocated. Otherwise only a 77-position ``positional_embedding``
                is allocated. Presumably this matches original CLIP checkpoints;
                confirm against the loading code.
        """
        super().__init__()
        # Fixed Long-CLIP text length; the context_length argument is ignored.
        self.context_length = 248
        if isinstance(vision_layers, (tuple, list)):
            # Sequence of stage depths -> ResNet visual tower.
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            # Integer layer count -> ViT visual tower with 64-dim heads.
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )
        # NOTE(review): ResidualAttentionBlock stores seq_len but builds its RoPE
        # tables from the runtime input length, so this value is informational.
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            seq_len=self.context_length # Pass the context length for RoPE
        )
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        if load_from_clip == False:
            # Full 248-token absolute positional tables (main + "res" variant).
            self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
            self.positional_embedding_res = nn.Parameter(torch.empty(248, transformer_width))
        else:
            # 77-token table only; positional_embedding_res is NOT created here.
            self.positional_embedding = nn.Parameter(torch.empty(77, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.initialize_parameters()
        # (248, 1) position masks: mask1 selects positions 0-19, mask2 selects 20-247.
        # NOTE(review): these are plain tensor attributes, not registered buffers,
        # so they do not follow module.to(device) / state_dict.
        self.mask1 = torch.zeros([248, 1])
        self.mask1[:20, :] = 1
        self.mask2 = torch.zeros([248, 1])
        self.mask2[20:, :] = 1
Note, in case you find this interesting and want to try it yourself (ideally with a more suitable batch size): because SPRIGHT-T2I was labeled by GPT-4V and other AI models, the dataset contains a few "glitch labels". In these, the label is just a repetition such as "a small airplane and a large airplane. a small airplane and a large airplane. [etc.]", exceeding 248 tokens. I recommend running the entire dataset through the tokenizer and simply deleting any examples that cause an error. So far I have found about 10 such examples per 100,000, and they are often low-quality images (e.g. images full of text), so I think it's best to delete them automatically.
First of all, in case anybody sees this thread and thinks "Oh, I want to use Long-CLIP with Flux!": I made a ComfyUI custom node for it, and you can find it here: https://github.com/zer0int/ComfyUI-Long-CLIP
Now, about the actual topic. The developers of Flux1 stated in their announcement that their model's superior performance is due not only to its sheer size (a gigantic 12B-parameter diffusion transformer) but also to rotary positional embeddings (RoPE).
And indeed, the "red cube on top of a blue cube" problem doesn't exist for Flux. It can even accurately generate this:
"a red cube on top of a blue cube, with a green kitten sitting on top of the red cube, the cat is holding a sign that says 'rotary positional embeddings', and in the background there are many tiny pink spheres and yellow triangles"
In terms of general spatial prompt following, Flux1 can do this (albeit details can be problematic).
Flux1 uses T5 + CLIP ViT-L/14 as text encoders. Besides Long-CLIP nicely complementing T5's maximum sequence length, I naturally also wondered: what if CLIP had RoPE?
My previous MLP modification was Geometric Parametrization (GmP), which "splits" each .weight into .theta and .r and thus preserves the learned information. That was not a big change for the model.
However, RoPE changes how the attention mechanism works. So it needs to "learn how to see" again after this change. Nevertheless, I tried it! 🤓
First, I fine-tuned on the COCO-SPRIGHT-40k spatial-labels dataset I used previously, with labels under 77 tokens, to compare CLIP vs. Long-CLIP. Both models show similar learning patterns. Their validation accuracy and F1 remain very poor after this, BUT they are improving, albeit slowly.
So I continued fine-tuning on CC12M SPRIGHT, using the original long captions (>77 tokens) for Long-CLIP.
Now, after ~150,000 text-image pairs (split across two separate runs), the latest results are:
Start: Validation Acc: 0.0880, Validation F1: 0.0616
End: Validation Acc: 0.1537, Validation F1: 0.1492
Fine-tuned model accuracy on MVT ImageNet/ObjectNet: 0.79064 (down from 0.81134 for your model)
So I am hoping that "just" 1–5 million text-image pairs of re-training may be enough for CLIP to learn how to "see with its new attention".
I am using GmP, label smoothing, and RoPE. GmP is necessary because I am "GPU poor": I am training on a single RTX 4090. RoPE further increases the model's VRAM requirements, so I have to train with a batch size of 26 (!), which is definitely NOT ideal for CLIP!
I would love to get your feedback on this idea (RoPE for CLIP in general). While GmP has been shown empirically to "just work well", I am uncertain about RoPE. So even if your feedback is negative, or you think RoPE is a very bad idea, I would very much appreciate hearing it!
Any feedback is welcome! Thank you!
Supplementary images.