IDEA-Research / T-Rex

[ECCV2024] API code for T-Rex2: Towards Generic Object Detection via Text-Visual Prompt Synergy
https://deepdataspace.com/blog/T-Rex

About the training process of T-rex2 #62

Closed hengseuer closed 5 months ago

hengseuer commented 6 months ago

Hello,

I have a question about the training process of T-rex2. Does T-rex2 first train the text prompts and then train both the text and visual prompts in successive iterations?

Thank you!

Mountchicken commented 6 months ago

Hi @hengseuer Yes. The training of the text prompt branch needs more data and a longer time to converge, so we train the text prompt branch first.

hengseuer commented 6 months ago

Thank you for your response.

I have another question: When training text and visual prompts simultaneously, do the negative samples for the visual prompts come from the image itself, the current batch, or is there a maintained pool of negative samples?

Mountchicken commented 6 months ago

The negative samples for visual prompts are sampled from the current mini-batch.

hengseuer commented 6 months ago

Thanks a lot.

Are all the samples in the current mini-batch from the same dataset?

If, during the current iteration, all the samples across the GPUs are from the same dataset and we sample negative examples from within the entire batch, similar to the approach used in DINOv, would this result in better performance?

Mountchicken commented 6 months ago

Our implementation only samples negative prompts from the current GPU. Using the sampling strategy in DINOv might bring a further performance boost.
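
For readers skimming the thread, here is a minimal sketch (not the released T-Rex2 code) of what per-GPU negative sampling from the current mini-batch could look like; the names `sample_negative_prompts`, `prompt_embeds`, and `num_negatives` are illustrative assumptions.

```python
import torch

def sample_negative_prompts(prompt_embeds, image_idx, num_negatives=10):
    """Pick negative visual-prompt embeddings for one image from the other
    images in the local (per-GPU) mini-batch.

    prompt_embeds: list of (N_i x C) tensors, one per image in the mini-batch.
    image_idx: index of the image whose negatives are being sampled.
    """
    # Pool prompt embeddings from every *other* image on this GPU.
    pool = torch.cat(
        [p for i, p in enumerate(prompt_embeds) if i != image_idx], dim=0
    )
    # Randomly keep at most `num_negatives` of them as negative classes.
    keep = torch.randperm(pool.shape[0])[:num_negatives]
    return pool[keep]

# Toy usage: 4 images, each contributing a few 256-d prompt embeddings.
prompt_embeds = [torch.randn(n, 256) for n in (3, 5, 2, 4)]
negatives = sample_negative_prompts(prompt_embeds, image_idx=0)
print(negatives.shape)  # up to (10, 256)
```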

hengseuer commented 6 months ago

Got it. Thanks.

VilisovEvgeny commented 3 months ago

Hi, @Mountchicken!

Do I understand correctly that the loss is calculated between the processed batch's images and the prompts obtained from that same batch? So the contrastive loss is calculated like

ContrastiveLoss(normalize(predicted_boxes_queries) @ normalize(encoded_prompt_queries), int(0 or 1))

Thanks!

Mountchicken commented 3 months ago

@VilisovEvgeny Almost the same. We are using sigmoid focal loss here as the classification loss because the label can be 0, 1, 2, ..., depending on the number of classes in the current batch.
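
A minimal sketch of how that classification term could be set up, assuming (not confirmed by the authors) 900 decoder queries scored against one prompt embedding per class in the batch, with one-hot targets built from matched class indices:

```python
import torch
from torchvision.ops import sigmoid_focal_loss

num_queries, embed_dim, num_classes = 900, 256, 5     # assumed sizes

query_embeds = torch.randn(num_queries, embed_dim)    # decoder output queries
prompt_embeds = torch.randn(num_classes, embed_dim)   # one embedding per class in the batch

logits = query_embeds @ prompt_embeds.t()             # (900, num_classes) raw logits

# Assumed matching result: a class index (0, 1, 2, ...) per query;
# here every query gets a random label just to make the example run.
assigned = torch.randint(0, num_classes, (num_queries,))
targets = torch.zeros_like(logits)
targets[torch.arange(num_queries), assigned] = 1.0

loss = sigmoid_focal_loss(logits, targets, reduction="mean")
print(loss)
```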

VilisovEvgeny commented 3 months ago

sigmoid focal loss

Like this?:

torchvision.ops.sigmoid_focal_loss((normalize(predicted_boxes_queries) @ normalize(encoded_prompt_queries)).flatten(), [int(0 or 1)]*len(logits_num))

Mountchicken commented 3 months ago

yes

pisiguiii commented 3 months ago

@Mountchicken hi! I'm trying to fit my model using normalized embeddings, as in this example, and it looks like using cosine similarity (values between -1 and 1) is not the correct way to use sigmoid focal loss. Did you use the dot product of the embeddings without normalization, or the normalized embeddings?

torchvision.ops.sigmoid_focal_loss((normalize(predicted_boxes_queries) @ normalize(encoded_prompt_queries)).flatten(), [int(0 or 1)]*len(logits_num))

Mountchicken commented 3 months ago

here is an example:

```python
def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss
```
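
For reference, a quick (assumed-shape) way to exercise that function; it also needs `torch.nn.functional` imported as `F`, and it returns an unreduced tensor that you average yourself:

```python
import torch
import torch.nn.functional as F  # used inside sigmoid_focal_loss above

logits = torch.randn(900, 3)    # raw, un-normalized query-vs-prompt logits
targets = torch.zeros_like(logits)
targets[:, 0] = 1.0             # toy labels: class 0 positive for every query
loss = sigmoid_focal_loss(logits, targets).mean()
print(loss)
```
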
pisiguiii commented 3 months ago

@Mountchicken I understand how sigmoid focal loss works. I was asking about fitting with normalized versus unnormalized embeddings: when we normalize the embeddings and then compute the cosine similarity between them, we obtain values/logits between -1 and 1. After applying a sigmoid to these values we get scores between ~0.27 and ~0.73, so the sigmoid focal loss output will never reach 0. Your answer didn't mention normalization. Could you please help me with this?

Mountchicken commented 3 months ago

@pisiguiii Sorry for the confusion. Given a visual prompt embedding (1xC) and the final 900 detection queries (900xC), we don't normalize them but directly take the dot product to get the final logits (900x1). The logits are then passed through a sigmoid to get the final scores in the range between 0 and 1.
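
A minimal sketch of that scoring step (shapes as described above, variable names assumed):

```python
import torch

C = 256                                   # assumed embedding dim
prompt_embed = torch.randn(1, C)          # visual prompt embedding, NOT normalized
detection_queries = torch.randn(900, C)   # final 900 decoder queries, NOT normalized

logits = detection_queries @ prompt_embed.t()   # (900, 1) raw dot-product logits
scores = logits.sigmoid()                       # (900, 1) final scores in (0, 1)
```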

pisiguiii commented 3 months ago

Thanks! Understood.