IDEA-Research / T-Rex

[ECCV2024] API code for T-Rex2: Towards Generic Object Detection via Text-Visual Prompt Synergy
https://deepdataspace.com/blog/T-Rex

About Position Encoding #84

Open pisiguiii opened 3 months ago

pisiguiii commented 3 months ago

Hi!

I want to ask: did you try using a learnable layer for the position encoding (PE) instead of the sine position encoder? If so, how did it behave?

Also, as I understand from the paper, in the final version of visual prompt processing you concatenate the encoded boxes and the content embeddings. So if the encoded box embedding is 256-d and the content embedding is 256-d, CAT(B, C) will be 512-d. Did you try summing these embeddings instead, like Q = Linear(B + C) with d = 256?
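A small sketch of the two variants being compared, using random tensors and illustrative sizes (k = 3 prompts is an assumption). Both end at d = 256 once the projection is applied. Concatenation followed by a linear layer is itself a sum of two separately projected terms, W_C·C + W_B·B + b, while summing first (variant b) forces C and B through the same projection. Per the maintainers' replies in this thread, T-Rex2 instead adds the content and position embeddings during attention, so neither variant is what the model actually does.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, k = 256, 3                  # embedding size and a hypothetical number of box prompts
content = torch.randn(k, d)    # C: content embeddings
box_pos = torch.randn(k, d)    # B: encoded boxes

# (a) Concatenate along channels, then project 2d -> d.
proj_cat = nn.Linear(2 * d, d)
q_cat = proj_cat(torch.cat([content, box_pos], dim=-1))

# (b) Sum first, then project d -> d.
proj_sum = nn.Linear(d, d)
q_sum = proj_sum(content + box_pos)

# (a) is also a sum: split the weight into the half that sees C and the half that sees B.
w_c, w_b = proj_cat.weight[:, :d], proj_cat.weight[:, d:]
q_cat_as_sum = content @ w_c.T + box_pos @ w_b.T + proj_cat.bias

print(q_cat.shape, q_sum.shape)  # torch.Size([3, 256]) torch.Size([3, 256])
print(torch.allclose(q_cat, q_cat_as_sum, atol=1e-5))  # True
```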

Mountchicken commented 3 months ago

Hi @pisiguiii, we didn't try a learnable position embedding, but I think it should perform similarly to the sin-cos position embedding; this has been verified in the DETR series. Specifically, DINO uses a sin-cos position embedding, and we directly follow their implementation.
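For comparison, a learnable box position embedding could be as simple as an MLP applied directly to the normalized box coordinates. This is a hypothetical sketch (the class name and layer sizes are illustrative), not something the maintainers report trying. As far as I know, DINO-style decoders take a middle route: they pass the fixed sine embedding through a small learnable MLP (ref_point_head) to get the query position.

```python
import torch
import torch.nn as nn


class LearnableBoxPosEmbed(nn.Module):
    """Hypothetical learnable box encoder: normalized (cx, cy, w, h) -> d_model via a 2-layer MLP."""

    def __init__(self, d_model: int = 256, hidden: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(4, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, d_model),
        )

    def forward(self, boxes: torch.Tensor) -> torch.Tensor:
        # boxes: [bs, K, 4], coordinates normalized to [0, 1]
        return self.mlp(boxes)


pos = LearnableBoxPosEmbed()(torch.rand(2, 5, 4))
print(pos.shape)  # torch.Size([2, 5, 256])
```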

The content embedding and position embedding are added instead of concatenated.

yu-xi-wang commented 2 months ago

Hi @Mountchicken, thanks for the great work! I also have some questions related to this issue.

  1. In the paper, the description of the prompt encoder says:

    These content embeddings are concatenated with position embeddings along the channel dimension, and a linear layer is applied for projection, thereby constructing the input query embedding Q

But does the implementation actually first compute [C;C'] + [B;B'], giving a tensor of shape [k + 1, 256], and then feed it into a linear layer that keeps the shape, so that Q = linear([C;C'] + [B;B']) has shape [k + 1, 256]?

  2. In the next step, Q is used to extract and aggregate target regions by performing MSDeformAttn with the encoded features: MSDeformAttn(Q_j, b_j, {f_i}), where b_j is the j-th box, which should serve as the reference points, and Q_j is the j-th prompt query embedding. If I understand question 1 correctly, Q_j already contains position embedding information because of the '+'. But in the DINO code I frequently see:
    src2 = MSDeformAttn(self.with_pos_embed(src, pos), reference_points, src, ...)
    src = src + self.dropout1(src2)
    src = self.norm1(src)

    This with_pos_embed simply returns src + pos. So I'm confused about which of the following implementations is correct:

  Option 1:
    Q = linear([C;C'] + [B;B'])
    src2 = MSDeformAttn(self.with_pos_embed(Q, B), B, f, ...)
    Q = Q + self.dropout1(src2)
    Q = self.norm1(Q)

  Option 2:
    Q = linear([C;C'] + [B;B'])
    src2 = MSDeformAttn(Q, B, f, ...)
    [C;C'] = [C;C'] + self.dropout1(src2)
    [C;C'] = self.norm1([C;C'])

Option 1 seems to add the position embedding twice, and it doesn't seem to make sense for the position embedding to remain in the final V. Could you help me understand this? Thank you!

Mountchicken commented 2 months ago

Hi @yu-xi-wang, sorry for the late reply. I checked the code again and found an issue with the description of the prompt encoder in our paper. The correct calculation is:

Q = [C; C']
Position = [B; B']
src2 = MSDeformAttn(self.with_pos_embed(Q, Position), box_or_point_coordinates, f, ...)
Q = Q + self.dropout1(src2)
Q = self.norm1(Q)
Visual_prompt_embedding = Q[:, -1]

The content embedding and position embedding are not concatenated; they are added during attention.
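The corrected computation above can be sketched as a single layer. This is a hedged, runnable approximation, not the T-Rex2 code: torch.nn.MultiheadAttention over flattened features stands in for MSDeformAttn, so the box/point coordinates are not used as sampling reference points here, and PromptEncoderLayerSketch is a hypothetical name. What it does show is the point raised in Option 1: the position embedding is added only to the attention query, while the residual update starts from the content Q, so the position embedding is not carried into the output.

```python
import torch
import torch.nn as nn


def with_pos_embed(tensor, pos):
    # Helper pattern used in Deformable-DETR/DINO layers: add pos only if it is given.
    return tensor if pos is None else tensor + pos


class PromptEncoderLayerSketch(nn.Module):
    """Hypothetical sketch; nn.MultiheadAttention is a stand-in for MSDeformAttn."""

    def __init__(self, d_model=256, n_heads=8, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

    def forward(self, content, position, feats):
        # content:  [bs, k + 1, d] = [C; C'] (last row is the global content query C')
        # position: [bs, k + 1, d] = [B; B'] (embedding of each box, plus the global box)
        # feats:    [bs, n, d]     flattened image features f
        src2, _ = self.attn(with_pos_embed(content, position), feats, feats)
        q = content + self.dropout1(src2)  # residual starts from the content, not content + position
        return self.norm1(q)


torch.manual_seed(0)
bs, k, d, n = 2, 3, 256, 50
content = torch.randn(bs, k + 1, d)
position = torch.randn(bs, k + 1, d)
feats = torch.randn(bs, n, d)
q = PromptEncoderLayerSketch(d)(content, position, feats)
visual_prompt_embedding = q[:, -1]  # the global token, as in Q[:, -1]
print(q.shape, visual_prompt_embedding.shape)  # torch.Size([2, 4, 256]) torch.Size([2, 256])
```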

yu-xi-wang commented 2 months ago

Hi @Mountchicken, thank you so much for the reply! Yes, it makes sense to me now!

VilisovEvgeny commented 2 months ago

Hi @Mountchicken, hi @yu-xi-wang!

I tried to implement the positional encoding as I understood it from the paper, but I ran into a problem: all my encoded boxes had almost identical embeddings. Could you help me understand what I'm missing?

This is my code:


def _boxes_embed(self, x):
    bs, K, D = x.size()  # K is the number of bounding boxes, D should be 4
    pe = torch.zeros(bs, K, D, self.num_pos_feats, device=x.device)

    # Create the scaling factor for positional encoding
    dim_t = torch.arange(self.num_pos_feats * 4, dtype=torch.float32, device=x.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / (self.num_pos_feats * 4))

    for i in range(D):
        pos = x[:, :, i].unsqueeze(-1)  # Shape: [bs, K, 1]
        scaled_pos = pos / dim_t[self.num_pos_feats * i:self.num_pos_feats * (i + 1)]  # Shape: [bs, K, num_pos_feats]
        pe[:, :, i, 0::2] = torch.sin(scaled_pos[:, :, 0::2])  # Apply sine to even indices
        pe[:, :, i, 1::2] = torch.cos(scaled_pos[:, :, 1::2])  # Apply cosine to odd indices

    pe = pe.view(bs, K, -1)  # Concatenate per-coordinate embeddings: [bs, K, D * num_pos_feats] (256 when num_pos_feats = 64)
    return pe
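A quick check of the frequency ladder in _boxes_embed, as a sketch (near_constant_fraction and the 0.1 rad threshold are illustrative choices, not from the thread). With x in [0, 1], a channel sin(s·x / dim_t) or cos(s·x / dim_t) can rotate by at most s / dim_t radians. _boxes_embed uses s = 1 and gives coordinate i only the i-th quarter of one shared 256-entry ladder, so the third and fourth coordinates (w and h in cxcywh) only see dim_t of 100 or more and barely change from box to box. gen_sineembed_for_position, posted in the next reply, uses s = 2π and the full 128-entry ladder for every coordinate.

```python
import math

import torch


def near_constant_fraction(dim_t, scale, thresh=0.1):
    # Fraction of channels whose phase scale * x / dim_t stays below `thresh` rad for all x in [0, 1].
    return (scale / dim_t < thresh).float().mean().item()


npf = 64  # num_pos_feats in _boxes_embed (4 * 64 = 256 output dims)

# Ladder used by _boxes_embed: one shared ladder of 4 * npf entries, sliced per coordinate, no 2*pi.
shared = torch.arange(npf * 4, dtype=torch.float32)
shared = 10000 ** (2 * (shared // 2) / (npf * 4))
user_frac = [near_constant_fraction(shared[npf * i:npf * (i + 1)], 1.0) for i in range(4)]

# Ladder used by gen_sineembed_for_position: the full 128-entry ladder per coordinate, with 2*pi.
per_coord = torch.arange(128, dtype=torch.float32)
per_coord = 10000 ** (2 * (per_coord // 2) / 128)
dino_frac = near_constant_fraction(per_coord, 2 * math.pi)

print("_boxes_embed, per coordinate:", user_frac)
print("gen_sineembed_for_position, each coordinate:", dino_frac)
```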
Mountchicken commented 2 months ago

@VilisovEvgeny Here is the code that I use to get the position embedding:

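import math

import torch


def gen_sineembed_for_position(pos_tensor):
    # pos_tensor: [n_query, bs, 2] points (x, y) or [n_query, bs, 4] boxes (cx, cy, w, h), normalized to [0, 1]
    # returns: [n_query, bs, 256] for points or [n_query, bs, 512] for boxes (128 channels per coordinate,
    # concatenated in the order y, x[, w, h])
    scale = 2 * math.pi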
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000**(2 * (dim_t // 2) / 128)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()),
                        dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()),
                        dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()),
                            dim=3).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()),
                            dim=3).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(
            pos_tensor.size(-1)))
    return pos
VilisovEvgeny commented 2 months ago

@Mountchicken, thanks for the provided solution!

But I'm a little confused: why does the cosine similarity between the obtained position embeddings never drop below 0.7? Is this common behavior?

[image: cosine similarity between the position embeddings]

Mountchicken commented 2 months ago

I'm not sure if this is normal. Did you normalize the box coordinates to 0-1 before computing the position embedding?

VilisovEvgeny commented 2 months ago

Yes, I did. I also checked that the boxes were in cxcywh format. I saw similar behavior in all my sine-encoding implementations.
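One likely contributor to the 0.7 floor, offered as an explanation rather than something confirmed in this thread: in gen_sineembed_for_position each sin channel and the following cos channel share one frequency (dim_t is built from dim_t // 2), so for a single coordinate the dot product of two embeddings collapses to Σ_m cos(2π·Δ / dim_t_m), where Δ is the difference between the two coordinate values. For the low-frequency pairs (large dim_t_m) that cosine stays near 1 for any Δ in [-1, 1], so every pair of boxes shares a large common component; bounding each term crudely gives a cosine similarity of at least about 0.58 for any two normalized 4-d boxes. The sketch re-implements the same per-coordinate embedding compactly (sine_embed_1d and box_sine_embed are hypothetical names) and measures it on random boxes.

```python
import math

import torch
import torch.nn.functional as F


def sine_embed_1d(coord, num_feats=128, temperature=10000):
    # coord: [...] in [0, 1] -> [..., num_feats], interleaved sin/cos as in gen_sineembed_for_position
    dim_t = torch.arange(num_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)
    pos = coord[..., None] * 2 * math.pi / dim_t
    return torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1).flatten(-2)


def box_sine_embed(boxes):
    # boxes: [N, 4] normalized (cx, cy, w, h); same (y, x, w, h) concatenation order -> [N, 512]
    cx, cy, w, h = boxes.unbind(-1)
    return torch.cat([sine_embed_1d(t) for t in (cy, cx, w, h)], dim=-1)


torch.manual_seed(0)
boxes = torch.rand(200, 4)
emb = box_sine_embed(boxes)
emb_n = F.normalize(emb, dim=-1)
sim = emb_n @ emb_n.T
off_diag = sim[~torch.eye(len(boxes), dtype=torch.bool)]
print("min / mean pairwise cosine:", off_diag.min().item(), off_diag.mean().item())

# Channels that barely move across 200 random boxes (the shared low-frequency component).
spread = emb.max(0).values - emb.min(0).values
print((spread < 0.05).sum().item(), "of", emb.shape[1], "channels are nearly constant")
```

Since this floor comes from the encoding itself, a high raw similarity between position embeddings is not by itself evidence of a bug; comparing similarities after subtracting the mean embedding is one way to see the box-specific part.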

VilisovEvgeny commented 1 month ago

I ran some tests on a small part of the LVIS dataset, checking all embeddings (not only the last embedding [:, -1]), and this is what I got. The labels read as: (index of the image in the dataset)(class name)(index of the unique box "visual prompt", or the global visual prompt with box [0.5, 0.5, 1.0, 1.0]).

Here I compared the final visual prompt embeddings: [image: similarity matrix]

Here I compared the position-encoded box embeddings obtained from gen_sineembed_for_position(pos_tensor): [image: similarity matrix]

Compared with the DINOv repo, the prompt encoding part looks like MSDeformAttnTransformerEncoderLayer; I literally copy-pasted this class. The provided matrices show that a global embedding (class token) is similar to the global embeddings from other images and has the lowest similarity to the visual prompt embeddings of its own sample. How is this possible? Have you run into problems like this?

VilisovEvgeny commented 1 month ago

@Mountchicken, could you help me with the issue I described previously? The final global embeddings are much more similar to other global embeddings than to the final embeddings of their own classes. I'm following the paper and using GroundingDINO's DeformAttnDecoderLayer as a base.

Mountchicken commented 1 month ago

Hi @VilisovEvgeny Sorry for the late reply. Based on your visualization, if I understand correctly, you visualized the similarity between global content embeddings of different categories across different images and found that they are quite similar. Since the global content embeddings are the same before entering the visual prompt encoder, the final outputs might also be quite similar. This is indeed an issue, and we haven't discussed this problem before.

VilisovEvgeny commented 1 month ago

Thanks for your reply, @Mountchicken!

I visualized not only the global embedding but all embeddings from the final output (one embedding per unique sample per class per image, plus one global embedding). The main problem is that my global embeddings from one image for different classes are too similar to each other, so when I try to fit my pipeline it doesn't even pass a sanity check on a small amount of data.

This is what the similarity between global embeddings of different classes from the same image and from different images looks like:

[image: similarity matrix of global embeddings]

I understand that this is a lot to ask, but I would be very grateful if you could tell me the average similarity between global embeddings obtained for different classes from the same image and from different images. Following both the paper and your advice in the issues, I still can't tell whether I'm implementing the Visual Prompt Encoder correctly or not.
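For the two averages asked about here, a minimal sketch of how they could be computed so that numbers from different implementations are comparable. mean_pairwise_similarity is a hypothetical helper; emb is assumed to hold one global embedding per (image, class) pair and image_ids the source image of each row. Self-pairs are excluded from the same-image average; if every image contributes only one embedding, that selection is empty and the mean is NaN.

```python
import torch
import torch.nn.functional as F


def mean_pairwise_similarity(emb, image_ids):
    """Average cosine similarity of global embeddings: (same image, different images).

    emb: [N, d] global visual prompt embeddings, one per (image, class) pair.
    image_ids: [N] integer id of the image each row came from.
    """
    e = F.normalize(emb, dim=-1)
    sim = e @ e.T
    same_img = image_ids[:, None] == image_ids[None, :]
    not_self = ~torch.eye(len(emb), dtype=torch.bool)
    same = sim[same_img & not_self]   # different classes, same image
    diff = sim[~same_img]             # pairs drawn from different images
    return same.mean().item(), diff.mean().item()


# Usage with random placeholder data (3 images x 4 classes, d = 256).
torch.manual_seed(0)
emb = torch.randn(12, 256)
image_ids = torch.arange(3).repeat_interleave(4)
print(mean_pairwise_similarity(emb, image_ids))
```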