Closed: hengseuer closed this issue 5 months ago

Hello,
I have a question about the training process of T-Rex2: does T-Rex2 first train the text prompts and then train both the text and visual prompts in successive iterations?
Thank you!
Hi @hengseuer. Yes. The training of the text prompt branch needs more data and a longer time to converge, so we train the text prompt branch first.
Thank you for your response.
I have another question: When training text and visual prompts simultaneously, do the negative samples for the visual prompts come from the image itself, the current batch, or is there a maintained pool of negative samples?
The negative samples for visual prompts are sampled from the current mini-batch.
Thanks a lot.
Are all the samples in the current mini-batch from the same dataset?
If, during the current iteration, all the samples across the GPUs are from the same dataset and we sample negative examples from within the entire batch, similar to the approach used in DINOv, would this result in better performance?
Our implementation only samples negative prompts from the current GPU. Using the sampling strategy from DINOv might bring additional performance gains.
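To make this concrete, here is a minimal sketch of per-GPU negative prompt sampling under my own assumptions about shapes; the names `sample_negative_prompts`, `prompt_embeds`, and `num_negatives` are illustrative and not from the T-Rex2 codebase. A DINOv-style cross-GPU variant would first all-gather `prompt_embeds` and `labels` across GPUs before sampling.

```python
import torch

def sample_negative_prompts(prompt_embeds, labels, positive_label, num_negatives=8):
    """Sample negative visual prompt embeddings from the current mini-batch.

    prompt_embeds:  (N, C) prompt embeddings pooled from all images on this GPU
    labels:         (N,)   category id of each prompt embedding
    positive_label: category id serving as the positive class for this image
    """
    # Negatives are prompts from the same per-GPU mini-batch whose
    # category differs from the positive category.
    neg_pool = prompt_embeds[labels != positive_label]
    if neg_pool.shape[0] == 0:
        return neg_pool  # no negatives available in this batch
    # Sample with replacement so small batches still yield enough negatives.
    idx = torch.randint(0, neg_pool.shape[0], (num_negatives,))
    return neg_pool[idx]
```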
Got it. Thanks.
Hi, @Mountchicken!
Do I understand correctly that the loss is calculated between the processed batch's images and prompts obtained from the same batch? So the contrastive loss is calculated like:
```python
ContrastiveLoss(normalize(predicted_boxes_queries) @ normalize(encoded_prompt_queries), int(0 or 1))
```
Thanks!
@VilisovEvgeny Almost the same. We are using sigmoid focal loss here as the classification loss, because the label can be 0, 1, 2, ..., depending on the number of classes in the current batch.
> sigmoid focal loss

Like this?
```python
torchvision.ops.sigmoid_focal_loss(
    (normalize(predicted_boxes_queries) @ normalize(encoded_prompt_queries).T).flatten(),
    targets.flatten(),  # float 0/1 labels, same shape as the flattened logits
)
```
yes
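To spell out the exchange above: with multiple categories in the current batch, the targets become a one-hot matrix over (query, category) pairs and the focal loss is applied element-wise. A minimal sketch with made-up shapes and a dummy query-to-class assignment (a real model would use the matcher's assignments instead):

```python
import torch
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss

num_queries, num_classes, C = 900, 3, 256
query_embeds = torch.randn(num_queries, C)   # decoder detection queries
prompt_embeds = torch.randn(num_classes, C)  # one prompt embedding per category in the batch

# Similarity logits between every query and every category prompt.
logits = query_embeds @ prompt_embeds.T      # (900, 3)

# One-hot targets: 1 where a query is assigned a ground-truth box of that
# category. Random assignments here, purely for illustration.
matched_class = torch.randint(0, num_classes, (num_queries,))
targets = F.one_hot(matched_class, num_classes).float()

loss = sigmoid_focal_loss(logits, targets, reduction="mean")
```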
@Mountchicken hi! I'm trying to fit my model using normalized embeddings, as in this example, and it looks like cosine similarity (values between -1 and 1) is not the right input for sigmoid focal loss. Did you use the raw dot product of the embeddings without normalization, or normalized embeddings?
> `torchvision.ops.sigmoid_focal_loss((normalize(predicted_boxes_queries) @ normalize(encoded_prompt_queries).T).flatten(), targets.flatten())`
here is an example:

```python
import torch.nn.functional as F

def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape.
            The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
            positive vs negative examples. Default = 0.25; set to -1
            to disable weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
            balance easy vs hard examples.

    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss
```
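As a quick sanity check, the function takes raw (pre-sigmoid) logits and float 0/1 targets; dummy values for illustration:

```python
import torch

logits = torch.tensor([2.0, -1.5, 0.3])   # raw, unnormalized logits
targets = torch.tensor([1.0, 0.0, 1.0])   # float binary labels
print(sigmoid_focal_loss(logits, targets).mean())
```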
@Mountchicken I understand how sigmoid focal loss works. I was asking about fitting with normalized versus unnormalized embeddings. When we normalize our embeddings and then compute the cosine similarity between them, we obtain values/logits between -1 and 1. After we apply the sigmoid to these values, we obtain probabilities between ~0.27 and ~0.73, so the sigmoid focal loss output can never reach 0. Your answer didn't mention normalization. Could you please help me with this?
@pisiguiii Sorry for the confusion. Given a visual prompt embedding (1xC) and the final 900 detection queries (900xC), we don't normalize them; we directly take the dot product to get the final logits (900x1). The logits are then passed through a sigmoid to get the final scores in the range between 0 and 1.
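A minimal sketch of the scoring step just described (the tensor names are illustrative):

```python
import torch

C = 256
query_embeds = torch.randn(900, C)   # final detection queries (900 x C)
prompt_embed = torch.randn(1, C)     # one visual prompt embedding (1 x C)

# No normalization: the raw dot product gives the classification logits.
logits = query_embeds @ prompt_embed.T   # (900, 1)

# Sigmoid maps the logits to per-query scores in (0, 1).
scores = logits.sigmoid()
```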
Thanks! Understood.